mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-04 06:32:13 +08:00
* fix: improve JSON repair handling for markdown code blocks * unified import path * compress_crawl_udf * fix * reverse
1435 lines
59 KiB
Python
1435 lines
59 KiB
Python
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
from functools import partial
|
|
from typing import Annotated, Any, Literal
|
|
|
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
|
from langchain_core.runnables import RunnableConfig
|
|
from langchain_core.tools import tool
|
|
from langchain_mcp_adapters.client import MultiServerMCPClient
|
|
from langgraph.errors import GraphRecursionError
|
|
from langgraph.types import Command, interrupt
|
|
|
|
from src.agents import create_agent
|
|
from src.citations import extract_citations_from_messages, merge_citations
|
|
from src.config.agents import AGENT_LLM_MAP
|
|
from src.config.configuration import Configuration
|
|
from src.llms.llm import get_llm_by_type, get_llm_token_limit_by_type
|
|
from src.prompts.planner_model import Plan
|
|
from src.prompts.template import apply_prompt_template, get_system_prompt_template
|
|
from src.tools import (
|
|
crawl_tool,
|
|
get_retriever_tool,
|
|
get_web_search_tool,
|
|
python_repl_tool,
|
|
)
|
|
from src.tools.search import LoggedTavilySearch
|
|
from src.utils.context_manager import ContextManager, validate_message_content
|
|
from src.utils.json_utils import repair_json_output, sanitize_tool_response
|
|
|
|
from ..config import SELECTED_SEARCH_ENGINE, SearchEngine
|
|
from .types import State
|
|
from .utils import (
|
|
build_clarified_topic_from_history,
|
|
get_message_content,
|
|
is_user_message,
|
|
reconstruct_clarification_history,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@tool
def handoff_to_planner(
    research_topic: Annotated[str, "The topic of the research task to be handed off."],
    locale: Annotated[str, "The user's detected language locale (e.g., en-US, zh-CN)."],
):
    """Handoff to planner agent to do plan."""
    # NOTE: the docstring above is surfaced to the LLM as the tool description
    # (langchain `@tool` behavior), so its wording is part of runtime behavior.
    # This tool is not returning anything: we're just using it
    # as a way for LLM to signal that it needs to hand off to planner agent
    return
|
|
|
|
|
|
@tool
def handoff_after_clarification(
    locale: Annotated[str, "The user's detected language locale (e.g., en-US, zh-CN)."],
    research_topic: Annotated[
        str, "The clarified research topic based on all clarification rounds."
    ],
):
    """Handoff to planner after clarification rounds are complete. Pass all clarification history to planner for analysis."""
    # Signal-only tool: the coordinator node inspects `response.tool_calls` for
    # this name to decide routing; the body intentionally does nothing.
    # The docstring doubles as the tool description shown to the LLM.
    return
|
|
|
|
|
|
@tool
def direct_response(
    message: Annotated[str, "The response message to send directly to user."],
    locale: Annotated[str, "The user's detected language locale (e.g., en-US, zh-CN)."],
):
    """Respond directly to user for greetings, small talk, or polite rejections. Do NOT use this for research questions - use handoff_to_planner instead."""
    # Signal-only tool: coordinator_node reads the `message` arg out of the
    # tool call and appends it to the conversation; the body itself is a no-op.
    # The docstring doubles as the tool description shown to the LLM.
    return
|
|
|
|
|
|
def needs_clarification(state: dict) -> bool:
    """Decide whether the clarification loop should keep running.

    Centralizes the routing condition: clarification must be enabled, at
    least one round must have started, the loop must not be marked complete,
    and the round count must not have passed the configured maximum.

    Args:
        state: Workflow state mapping; read with defaults, never mutated.

    Returns:
        True when another clarification round is still expected.
    """
    if not state.get("enable_clarification", False):
        return False

    if state.get("is_clarification_complete", False):
        return False

    rounds = state.get("clarification_rounds", 0)
    # <= (not <) because after asking the Nth question we still have to wait
    # for the user's Nth answer before handing off.
    return 0 < rounds <= state.get("max_clarification_rounds", 3)
|
|
|
|
|
|
def preserve_state_meta_fields(state: State) -> dict:
    """Snapshot meta/config fields that must survive state transitions.

    Every Command.update dict should spread this snapshot in, otherwise
    LangGraph would let these fields fall back to their channel defaults
    between nodes.

    Args:
        state: Current workflow state (read-only here).

    Returns:
        A fresh dict of the meta fields with their documented defaults.
    """
    # Field name -> default used when the state does not carry the field yet.
    # The dict literal is rebuilt per call so the list fallbacks are fresh
    # objects, safe for callers to mutate.
    defaults = {
        "locale": "en-US",
        "research_topic": "",
        "clarified_research_topic": "",
        "clarification_history": [],
        "enable_clarification": False,
        "max_clarification_rounds": 3,
        "clarification_rounds": 0,
        "resources": [],
    }
    return {field: state.get(field, fallback) for field, fallback in defaults.items()}
|
|
|
|
|
|
def validate_and_fix_plan(plan: dict, enforce_web_search: bool = False, enable_web_search: bool = True) -> dict:
    """Validate a planner-generated plan and patch it in place.

    Two repair passes: fill in missing ``step_type`` fields, then (optionally)
    guarantee that at least one step performs a web search.

    Args:
        plan: The plan dict to validate (mutated in place).
        enforce_web_search: If True, ensure at least one step has need_search=true.
        enable_web_search: If False, skip web search enforcement (takes precedence).

    Returns:
        The same plan object, validated/fixed. Non-dict input is returned untouched.
    """
    if not isinstance(plan, dict):
        return plan

    step_list = plan.get("steps", [])

    # ------------------------------------------------------------
    # Pass 1: repair missing/empty step_type fields (Issue #650 fix)
    # ------------------------------------------------------------
    for i, entry in enumerate(step_list):
        if not isinstance(entry, dict):
            continue
        if entry.get("step_type"):
            # Already has a non-empty step_type; nothing to repair.
            continue
        # Infer from need_search; non-search work defaults to "analysis"
        # (Issue #677: not all processing needs code).
        guessed = "research" if entry.get("need_search", False) else "analysis"
        entry["step_type"] = guessed
        logger.info(
            f"Repaired missing step_type for step {i} ({entry.get('title', 'Untitled')}): "
            f"inferred as '{guessed}' based on need_search={entry.get('need_search', False)}"
        )

    # ------------------------------------------------------------
    # Pass 2: enforce web search requirements.
    # Skipped entirely when web search is disabled (enable_web_search=False
    # takes precedence over enforce_web_search).
    # ------------------------------------------------------------
    if enforce_web_search and enable_web_search:
        # Only dict-shaped steps count when checking for an existing search step.
        search_present = any(
            entry.get("need_search", False)
            for entry in step_list
            if isinstance(entry, dict)
        )

        if not search_present and step_list:
            # Prefer flipping need_search on the first research step.
            patched = False
            for i, entry in enumerate(step_list):
                if isinstance(entry, dict) and entry.get("step_type") == "research":
                    entry["need_search"] = True
                    logger.info(f"Enforced web search on research step at index {i}")
                    patched = True
                    break
            if not patched and isinstance(step_list[0], dict):
                # Fallback: no research step exists, so convert the first step
                # into a research step with web search enabled. This ensures at
                # least one step will perform a web search as required.
                step_list[0]["step_type"] = "research"
                step_list[0]["need_search"] = True
                logger.info(
                    "Converted first step to research with web search enforcement"
                )
        elif not search_present and not step_list:
            # Empty plan: synthesize a single default research step.
            logger.warning("Plan has no steps. Adding default research step.")
            plan["steps"] = [
                {
                    "need_search": True,
                    "title": "Initial Research",
                    "description": "Gather information about the topic",
                    "step_type": "research",
                }
            ]

    return plan
|
|
|
|
|
|
def background_investigation_node(state: State, config: RunnableConfig):
    """Run a pre-planning web search on the (clarified) research topic.

    Returns a state update with ``background_investigation_results`` as a
    string. NOTE(review): the legacy Tavily list branch returns a
    markdown-joined string while every other path returns a JSON-dumped
    payload — downstream consumers appear to accept either; confirm.
    """
    logger.info("background investigation node is running.")
    configurable = Configuration.from_runnable_config(config)

    # Background investigation relies on web search; skip entirely when web search is disabled
    if not configurable.enable_web_search:
        logger.info("Web search is disabled, skipping background investigation.")
        # Empty JSON array keeps the field's type consistent for consumers.
        return {"background_investigation_results": json.dumps([], ensure_ascii=False)}

    # Prefer the clarified topic when the clarification flow produced one.
    query = state.get("clarified_research_topic") or state.get("research_topic")
    background_investigation_results = []

    if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
        searched_content = LoggedTavilySearch(
            max_results=configurable.max_search_results
        ).invoke(query)
        # check if the searched_content is a tuple, then we need to unpack it
        if isinstance(searched_content, tuple):
            searched_content = searched_content[0]

        # Handle string JSON response (new format from fixed Tavily tool)
        if isinstance(searched_content, str):
            try:
                parsed = json.loads(searched_content)
                if isinstance(parsed, dict) and "error" in parsed:
                    # Tavily signals failures as {"error": ...}; degrade to empty results.
                    logger.error(f"Tavily search error: {parsed['error']}")
                    background_investigation_results = []
                elif isinstance(parsed, list):
                    # Render each hit as a small markdown section.
                    background_investigation_results = [
                        f"## {elem.get('title', 'Untitled')}\n\n{elem.get('content', 'No content')}"
                        for elem in parsed
                    ]
                else:
                    logger.error(f"Unexpected Tavily response format: {searched_content}")
                    background_investigation_results = []
            except json.JSONDecodeError:
                logger.error(f"Failed to parse Tavily response as JSON: {searched_content}")
                background_investigation_results = []
        # Handle legacy list format
        elif isinstance(searched_content, list):
            background_investigation_results = [
                f"## {elem['title']}\n\n{elem['content']}" for elem in searched_content
            ]
            # Early return: legacy path emits a markdown string rather than the
            # JSON-dumped payload used by the other paths (see docstring note).
            return {
                "background_investigation_results": "\n\n".join(
                    background_investigation_results
                )
            }
        else:
            logger.error(
                f"Tavily search returned malformed response: {searched_content}"
            )
            background_investigation_results = []
    else:
        # Non-Tavily engines: delegate to the generic web-search tool.
        background_investigation_results = get_web_search_tool(
            configurable.max_search_results
        ).invoke(query)

    return {
        "background_investigation_results": json.dumps(
            background_investigation_results, ensure_ascii=False
        )
    }
|
|
|
|
|
|
def planner_node(
    state: State, config: RunnableConfig
) -> Command[Literal["human_feedback", "reporter"]]:
    """Generate the full research plan with the planner LLM.

    Routing: plans with ``has_enough_context`` go straight to ``reporter``;
    otherwise the raw response goes to ``human_feedback`` for review. Invalid
    JSON responses route to ``reporter`` when a previous plan iteration
    exists, else end the graph.
    """
    logger.info("Planner generating full plan with locale: %s", state.get("locale", "en-US"))
    configurable = Configuration.from_runnable_config(config)
    # Normalizes falsy values (0 / missing) to 0.
    plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0

    # For clarification feature: use the clarified research topic (complete history)
    if state.get("enable_clarification", False) and state.get(
        "clarified_research_topic"
    ):
        # Modify state to use clarified research topic instead of full conversation
        modified_state = state.copy()
        modified_state["messages"] = [
            {"role": "user", "content": state["clarified_research_topic"]}
        ]
        modified_state["research_topic"] = state["clarified_research_topic"]
        messages = apply_prompt_template("planner", modified_state, configurable, state.get("locale", "en-US"))

        logger.info(
            f"Clarification mode: Using clarified research topic: {state['clarified_research_topic']}"
        )
    else:
        # Normal mode: use full conversation history
        messages = apply_prompt_template("planner", state, configurable, state.get("locale", "en-US"))

    # Feed prior background-investigation output to the planner as extra context.
    if state.get("enable_background_investigation") and state.get(
        "background_investigation_results"
    ):
        messages += [
            {
                "role": "user",
                "content": (
                    "background investigation results of user query:\n"
                    + state["background_investigation_results"]
                    + "\n"
                ),
            }
        ]

    # LLM selection: deep-thinking overrides the configured planner model.
    if configurable.enable_deep_thinking:
        llm = get_llm_by_type("reasoning")
    elif AGENT_LLM_MAP["planner"] == "basic":
        llm = get_llm_by_type("basic")
    else:
        llm = get_llm_by_type(AGENT_LLM_MAP["planner"])

    # if the plan iterations is greater than the max plan iterations, return the reporter node
    if plan_iterations >= configurable.max_plan_iterations:
        return Command(
            update=preserve_state_meta_fields(state),
            goto="reporter"
        )

    full_response = ""
    if AGENT_LLM_MAP["planner"] == "basic" and not configurable.enable_deep_thinking:
        # Basic models use structured output: prefer the pydantic JSON dump.
        response = llm.invoke(messages)
        if hasattr(response, "model_dump_json"):
            full_response = response.model_dump_json(indent=4, exclude_none=True)
        else:
            full_response = get_message_content(response) or ""
    else:
        # Reasoning models are streamed and the chunks concatenated.
        response = llm.stream(messages)
        for chunk in response:
            full_response += chunk.content
    logger.debug(f"Current state messages: {state['messages']}")
    logger.info(f"Planner response: {full_response}")

    # Clean the response first to handle markdown code blocks (```json, ```ts, etc.)
    cleaned_response = repair_json_output(full_response)

    # Validate explicitly that response content is valid JSON before proceeding to parse it
    if not cleaned_response.strip().startswith('{') and not cleaned_response.strip().startswith('['):
        logger.warning("Planner response does not appear to be valid JSON after cleanup")
        if plan_iterations > 0:
            # We already have an earlier plan: fall through to reporting.
            return Command(
                update=preserve_state_meta_fields(state),
                goto="reporter"
            )
        else:
            return Command(
                update=preserve_state_meta_fields(state),
                goto="__end__"
            )

    try:
        curr_plan = json.loads(cleaned_response)
        # Need to extract the plan from the full_response
        curr_plan_content = extract_plan_content(curr_plan)
        # load the current_plan
        curr_plan = json.loads(repair_json_output(curr_plan_content))
    except json.JSONDecodeError:
        logger.warning("Planner response is not a valid JSON")
        if plan_iterations > 0:
            return Command(
                update=preserve_state_meta_fields(state),
                goto="reporter"
            )
        else:
            return Command(
                update=preserve_state_meta_fields(state),
                goto="__end__"
            )

    # Validate and fix plan to ensure web search requirements are met
    if isinstance(curr_plan, dict):
        curr_plan = validate_and_fix_plan(curr_plan, configurable.enforce_web_search, configurable.enable_web_search)

    if isinstance(curr_plan, dict) and curr_plan.get("has_enough_context"):
        # Planner is confident: validate into a Plan model and go report.
        logger.info("Planner response has enough context.")
        new_plan = Plan.model_validate(curr_plan)
        return Command(
            update={
                "messages": [AIMessage(content=full_response, name="planner")],
                "current_plan": new_plan,
                **preserve_state_meta_fields(state),
            },
            goto="reporter",
        )
    # Otherwise keep the raw response string as current_plan; human_feedback
    # re-parses it after review.
    return Command(
        update={
            "messages": [AIMessage(content=full_response, name="planner")],
            "current_plan": full_response,
            **preserve_state_meta_fields(state),
        },
        goto="human_feedback",
    )
|
|
|
|
|
|
def extract_plan_content(plan_data: str | dict | Any) -> str:
    """Normalize plan data of various shapes into a plain string.

    Handles raw strings, AIMessage-like objects exposing a string
    ``.content``, dicts (with or without a ``content`` field), and falls back
    to ``str()`` for anything else.

    Args:
        plan_data: The plan data which can be a string, AIMessage, or dict.

    Returns:
        str: The plan content (JSON string for dict inputs, extracted or
        original string otherwise).
    """
    # Plain strings pass straight through.
    if isinstance(plan_data, str):
        return plan_data

    # Message-like objects (e.g. AIMessage) carry a string .content attribute.
    maybe_content = getattr(plan_data, "content", None)
    if isinstance(maybe_content, str):
        logger.debug(f"Extracting plan content from message object of type {type(plan_data).__name__}")
        return maybe_content

    if isinstance(plan_data, dict):
        if "content" not in plan_data:
            # Ordinary plan dict: serialize the whole thing.
            logger.debug("Converting plan dictionary to JSON string")
            return json.dumps(plan_data)
        # AIMessage-like dict: unwrap the content field.
        inner = plan_data["content"]
        if isinstance(inner, str):
            logger.debug("Extracting plan content from dict with content field")
            return inner
        if isinstance(inner, dict):
            logger.debug("Converting content field dict to JSON string")
            return json.dumps(inner, ensure_ascii=False)
        logger.warning(f"Unexpected type for 'content' field in plan_data dict: {type(inner).__name__}, converting to string")
        return str(inner)

    # Last resort: stringify whatever we were given.
    logger.warning(f"Unexpected plan data type {type(plan_data).__name__}, attempting to convert to string")
    return str(plan_data)
|
|
|
|
|
|
def human_feedback_node(
    state: State, config: RunnableConfig
) -> Command[Literal["planner", "research_team", "reporter", "__end__"]]:
    """Collect human review of the current plan and route accordingly.

    Unless ``auto_accepted_plan`` is set, interrupts the graph and waits for
    feedback: ``[EDIT_PLAN]...`` routes back to ``planner`` with the feedback
    message, ``[ACCEPTED]`` proceeds, anything else (including empty feedback)
    re-plans. Accepted plans are repaired/parsed into a ``Plan`` and execution
    moves to ``research_team``; unparseable plans go to ``reporter`` (if a
    prior iteration exists) or end the graph.
    """
    current_plan = state.get("current_plan", "")
    # check if the plan is auto accepted
    auto_accepted_plan = state.get("auto_accepted_plan", False)
    if not auto_accepted_plan:
        feedback = interrupt("Please Review the Plan.")

        # Handle None or empty feedback
        if not feedback:
            logger.warning(f"Received empty or None feedback: {feedback}. Returning to planner for new plan.")
            return Command(
                update=preserve_state_meta_fields(state),
                goto="planner"
            )

        # Normalize feedback string (prefix matching is case-insensitive)
        feedback_normalized = str(feedback).strip().upper()

        # if the feedback is not accepted, return the planner node
        if feedback_normalized.startswith("[EDIT_PLAN]"):
            logger.info(f"Plan edit requested by user: {feedback}")
            return Command(
                update={
                    "messages": [
                        HumanMessage(content=feedback, name="feedback"),
                    ],
                    **preserve_state_meta_fields(state),
                },
                goto="planner",
            )
        elif feedback_normalized.startswith("[ACCEPTED]"):
            logger.info("Plan is accepted by user.")
        else:
            logger.warning(f"Unsupported feedback format: {feedback}. Please use '[ACCEPTED]' to accept or '[EDIT_PLAN]' to edit.")
            return Command(
                update=preserve_state_meta_fields(state),
                goto="planner"
            )

    # if the plan is accepted, run the following node
    plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
    goto = "research_team"
    try:
        # Keep the untouched value for diagnostics before any reassignment below.
        original_plan = current_plan

        # Repair the JSON output (strips markdown fences etc.)
        current_plan = repair_json_output(current_plan)
        # parse the plan to dict
        current_plan = json.loads(current_plan)
        # Safely extract plan content from different types (string, AIMessage, dict)
        current_plan_content = extract_plan_content(current_plan)

        # increment the plan iterations
        plan_iterations += 1
        # parse the plan
        new_plan = json.loads(repair_json_output(current_plan_content))
        # Validate and fix plan to ensure web search requirements are met
        configurable = Configuration.from_runnable_config(config)
        new_plan = validate_and_fix_plan(new_plan, configurable.enforce_web_search, configurable.enable_web_search)
    except (json.JSONDecodeError, AttributeError) as e:
        logger.warning(f"Failed to parse plan: {str(e)}. Plan data type: {type(current_plan).__name__}")
        # Bug fix: this diagnostic previously tested isinstance(current_plan, dict)
        # but did the "content" membership check on original_plan — a substring
        # scan when original_plan is a string, and it never fired when only
        # original_plan was the dict. Check original_plan consistently: it is the
        # value as received, before the repair/parse reassignments above.
        if isinstance(original_plan, dict) and "content" in original_plan:
            logger.warning("Plan appears to be an AIMessage object with content field")
        if plan_iterations > 1:  # the plan_iterations is increased before this check
            return Command(
                update=preserve_state_meta_fields(state),
                goto="reporter"
            )
        else:
            return Command(
                update=preserve_state_meta_fields(state),
                goto="__end__"
            )

    # Build update dict with safe locale handling
    update_dict = {
        "current_plan": Plan.model_validate(new_plan),
        "plan_iterations": plan_iterations,
        **preserve_state_meta_fields(state),
    }

    # Only override locale if new_plan provides a valid value, otherwise use preserved locale
    if new_plan.get("locale"):
        update_dict["locale"] = new_plan["locale"]

    return Command(
        update=update_dict,
        goto=goto,
    )
|
|
|
|
|
|
def coordinator_node(
    state: State, config: RunnableConfig
) -> Command[Literal["planner", "background_investigator", "coordinator", "__end__"]]:
    """Coordinator node that communicate with customers and handle clarification.

    Two modes:
    - Legacy (clarification disabled): one LLM call bound to
      ``handoff_to_planner``/``direct_response``; research questions hand off
      to the planner, greetings end the graph.
    - Clarification enabled: rebuilds clarification history from prior
      messages, asks follow-up questions (early-returning with an
      ``__interrupt__`` payload) until the LLM hands off or the round cap is
      hit, then routes to the planner with the clarified topic.
    Both modes fall through to a shared tail that processes tool calls and
    builds the final Command.
    """
    logger.info("Coordinator talking.")
    configurable = Configuration.from_runnable_config(config)

    # Check if clarification is enabled
    enable_clarification = state.get("enable_clarification", False)
    initial_topic = state.get("research_topic", "")
    clarified_topic = initial_topic
    # ============================================================
    # BRANCH 1: Clarification DISABLED (Legacy Mode)
    # ============================================================
    if not enable_clarification:
        # Use normal prompt with explicit instruction to skip clarification
        messages = apply_prompt_template("coordinator", state, locale=state.get("locale", "en-US"))
        messages.append(
            {
                "role": "system",
                "content": "Clarification is DISABLED. For research questions, use handoff_to_planner. For greetings or small talk, use direct_response. Do NOT ask clarifying questions.",
            }
        )

        # Bind both handoff_to_planner and direct_response tools
        tools = [handoff_to_planner, direct_response]
        response = (
            get_llm_by_type(AGENT_LLM_MAP["coordinator"])
            .bind_tools(tools)
            .invoke(messages)
        )

        goto = "__end__"
        locale = state.get("locale", "en-US")
        logger.info(f"Coordinator locale: {locale}")
        research_topic = state.get("research_topic", "")

        # Process tool calls for legacy mode
        if response.tool_calls:
            try:
                for tool_call in response.tool_calls:
                    tool_name = tool_call.get("name", "")
                    tool_args = tool_call.get("args", {})

                    if tool_name == "handoff_to_planner":
                        logger.info("Handing off to planner")
                        goto = "planner"

                        # Extract research_topic if provided
                        if tool_args.get("research_topic"):
                            research_topic = tool_args.get("research_topic")
                        break
                    elif tool_name == "direct_response":
                        logger.info("Direct response to user (greeting/small talk)")
                        goto = "__end__"
                        # Append direct message to messages list instead of overwriting response
                        if tool_args.get("message"):
                            messages.append(AIMessage(content=tool_args.get("message"), name="coordinator"))
                        break

            except Exception as e:
                # Defensive: malformed tool calls should not strand the graph.
                logger.error(f"Error processing tool calls: {e}")
                goto = "planner"

        # Do not return early - let code flow to unified return logic below
        # Set clarification variables for legacy mode
        clarification_rounds = 0
        clarification_history = []
        clarified_topic = research_topic

    # ============================================================
    # BRANCH 2: Clarification ENABLED (New Feature)
    # ============================================================
    else:
        # Load clarification state
        clarification_rounds = state.get("clarification_rounds", 0)
        clarification_history = list(state.get("clarification_history", []) or [])
        clarification_history = [item for item in clarification_history if item]
        max_clarification_rounds = state.get("max_clarification_rounds", 3)

        # Prepare the messages for the coordinator
        state_messages = list(state.get("messages", []))
        messages = apply_prompt_template("coordinator", state, locale=state.get("locale", "en-US"))

        # Rebuild history from the raw message list so resumed runs stay consistent.
        clarification_history = reconstruct_clarification_history(
            state_messages, clarification_history, initial_topic
        )
        clarified_topic, clarification_history = build_clarified_topic_from_history(
            clarification_history
        )
        logger.debug("Clarification history rebuilt: %s", clarification_history)

        # History convention: first entry is the original topic, last is the
        # newest user response.
        if clarification_history:
            initial_topic = clarification_history[0]
            latest_user_content = clarification_history[-1]
        else:
            latest_user_content = ""

        # Add clarification status for first round
        if clarification_rounds == 0:
            messages.append(
                {
                    "role": "system",
                    "content": "Clarification mode is ENABLED. Follow the 'Clarification Process' guidelines in your instructions.",
                }
            )

        current_response = latest_user_content or "No response"
        logger.info(
            "Clarification round %s/%s | topic: %s | current user response: %s",
            clarification_rounds,
            max_clarification_rounds,
            clarified_topic or initial_topic,
            current_response,
        )

        clarification_context = f"""Continuing clarification (round {clarification_rounds}/{max_clarification_rounds}):

User's latest response: {current_response}

Ask for remaining missing dimensions. Do NOT repeat questions or start new topics."""

        messages.append({"role": "system", "content": clarification_context})

        # Bind both clarification tools - let LLM choose the appropriate one
        tools = [handoff_to_planner, handoff_after_clarification]

        # Check if we've already reached max rounds
        if clarification_rounds >= max_clarification_rounds:
            # Max rounds reached - force handoff by adding system instruction
            logger.warning(
                f"Max clarification rounds ({max_clarification_rounds}) reached. Forcing handoff to planner. Using prepared clarified topic: {clarified_topic}"
            )
            # Add system instruction to force handoff - let LLM choose the right tool
            messages.append(
                {
                    "role": "system",
                    "content": f"MAX ROUNDS REACHED. You MUST call handoff_after_clarification (not handoff_to_planner) with the appropriate locale based on the user's language and research_topic='{clarified_topic}'. Do not ask any more questions.",
                }
            )

        response = (
            get_llm_by_type(AGENT_LLM_MAP["coordinator"])
            .bind_tools(tools)
            .invoke(messages)
        )
        logger.debug(f"Current state messages: {state['messages']}")

        # Initialize response processing variables
        goto = "__end__"
        locale = state.get("locale", "en-US")
        research_topic = (
            clarification_history[0]
            if clarification_history
            else state.get("research_topic", "")
        )
        if not clarified_topic:
            clarified_topic = research_topic

        # --- Process LLM response ---
        # No tool calls - LLM is asking a clarifying question
        if not response.tool_calls and response.content:
            # Check if we've reached max rounds - if so, force handoff to planner
            if clarification_rounds >= max_clarification_rounds:
                logger.warning(
                    f"Max clarification rounds ({max_clarification_rounds}) reached. "
                    "LLM didn't call handoff tool, forcing handoff to planner."
                )
                goto = "planner"
                # Continue to final section instead of early return
            else:
                # Continue clarification process
                clarification_rounds += 1
                # Do NOT add LLM response to clarification_history - only user responses
                logger.info(
                    f"Clarification response: {clarification_rounds}/{max_clarification_rounds}: {response.content}"
                )

                # Append coordinator's question to messages
                # NOTE(review): appended as a HumanMessage despite being the
                # coordinator's own question — presumably so downstream prompt
                # templates treat it as conversation input; confirm.
                updated_messages = list(state_messages)
                if response.content:
                    updated_messages.append(
                        HumanMessage(content=response.content, name="coordinator")
                    )

                # Early return: surface the question via __interrupt__ and wait
                # for the user's next answer.
                return Command(
                    update={
                        "messages": updated_messages,
                        "locale": locale,
                        "research_topic": research_topic,
                        "resources": configurable.resources,
                        "clarification_rounds": clarification_rounds,
                        "clarification_history": clarification_history,
                        "clarified_research_topic": clarified_topic,
                        "is_clarification_complete": False,
                        "goto": goto,
                        "citations": state.get("citations", []),
                        "__interrupt__": [("coordinator", response.content)],
                    },
                    goto=goto,
                )
        else:
            # LLM called a tool (handoff) or has no content - clarification complete
            if response.tool_calls:
                logger.info(
                    f"Clarification completed after {clarification_rounds} rounds. LLM called handoff tool."
                )
            else:
                logger.warning("LLM response has no content and no tool calls.")
            # goto will be set in the final section based on tool calls

    # ============================================================
    # Final: Build and return Command
    # ============================================================
    messages = list(state.get("messages", []) or [])
    if response.content:
        messages.append(HumanMessage(content=response.content, name="coordinator"))

    # Process tool calls for BOTH branches (legacy and clarification)
    if response.tool_calls:
        try:
            for tool_call in response.tool_calls:
                tool_name = tool_call.get("name", "")
                tool_args = tool_call.get("args", {})

                if tool_name in ["handoff_to_planner", "handoff_after_clarification"]:
                    logger.info("Handing off to planner")
                    goto = "planner"

                    # Legacy mode trusts the tool-call topic; clarification mode
                    # keeps the topic assembled from the clarification history.
                    if not enable_clarification and tool_args.get("research_topic"):
                        research_topic = tool_args["research_topic"]

                    if enable_clarification:
                        logger.info(
                            "Using prepared clarified topic: %s",
                            clarified_topic or research_topic,
                        )
                    else:
                        logger.info(
                            "Using research topic for handoff: %s", research_topic
                        )
                    break

        except Exception as e:
            logger.error(f"Error processing tool calls: {e}")
            goto = "planner"
    else:
        # No tool calls detected
        if enable_clarification:
            # BRANCH 2: Fallback to planner to ensure research proceeds
            logger.warning(
                "LLM didn't call any tools. This may indicate tool calling issues with the model. "
                "Falling back to planner to ensure research proceeds."
            )
            logger.debug(f"Coordinator response content: {response.content}")
            logger.debug(f"Coordinator response object: {response}")
            goto = "planner"
        else:
            # BRANCH 1: No tool calls means end workflow gracefully (e.g., greeting handled)
            logger.info("No tool calls in legacy mode - ending workflow gracefully")

    # Apply background_investigation routing if enabled (unified logic)
    if goto == "planner" and state.get("enable_background_investigation"):
        goto = "background_investigator"

    # Set default values for state variables (in case they're not defined in legacy mode)
    if not enable_clarification:
        clarification_rounds = 0
        clarification_history = []

    clarified_research_topic_value = clarified_topic or research_topic

    # clarified_research_topic: Complete clarified topic with all clarification rounds
    return Command(
        update={
            "messages": messages,
            "locale": locale,
            "research_topic": research_topic,
            "clarified_research_topic": clarified_research_topic_value,
            "resources": configurable.resources,
            "clarification_rounds": clarification_rounds,
            "clarification_history": clarification_history,
            "is_clarification_complete": goto != "coordinator",
            "goto": goto,
            "citations": state.get("citations", []),
        },
        goto=goto,
    )
|
|
|
|
|
|
def reporter_node(state: State, config: RunnableConfig):
    """Reporter node that writes the final report.

    Builds the reporter prompt from the current plan, then appends:
      1. the collected source citations (so the LLM can emit a References
         section), and
      2. the research observations, compressed to fit the reporter LLM's
         token limit,
    and invokes the reporter LLM once.

    Args:
        state: Current workflow state; reads ``current_plan``, ``locale``,
            ``observations`` and ``citations``.
        config: Runnable config used to resolve the workflow Configuration.

    Returns:
        dict: State update with ``final_report`` and the pass-through
        ``citations`` list.
    """
    logger.info("Reporter write final report")
    configurable = Configuration.from_runnable_config(config)
    current_plan = state.get("current_plan")
    input_ = {
        "messages": [
            HumanMessage(
                f"# Research Requirements\n\n## Task\n\n{current_plan.title}\n\n## Description\n\n{current_plan.thought}"
            )
        ],
        "locale": state.get("locale", "en-US"),
    }
    invoke_messages = apply_prompt_template("reporter", input_, configurable, input_.get("locale", "en-US"))
    observations = state.get("observations", [])

    # Get collected citations for the report
    citations = state.get("citations", [])

    # If we have collected citations, provide them to the reporter
    if citations:
        citation_list = "\n\n## Available Source References (use these in References section):\n\n"
        for i, citation in enumerate(citations, 1):
            title = citation.get("title", "Untitled")
            url = citation.get("url", "")
            domain = citation.get("domain", "")
            description = citation.get("description", "")
            desc_truncated = description[:150] if description else ""
            citation_list += f"{i}. **{title}**\n - URL: {url}\n - Domain: {domain}\n"
            if desc_truncated:
                # Only show an ellipsis when the description was actually
                # truncated (previously "..." was appended unconditionally,
                # even for short descriptions).
                suffix = "..." if len(description) > 150 else ""
                citation_list += f" - Summary: {desc_truncated}{suffix}\n"
            citation_list += "\n"

        logger.info(f"Providing {len(citations)} collected citations to reporter")

        invoke_messages.append(
            HumanMessage(
                content=citation_list,
                name="system",
            )
        )

    observation_messages = []
    for observation in observations:
        observation_messages.append(
            HumanMessage(
                content=f"Below are some observations for the research task:\n\n{observation}",
                name="observation",
            )
        )

    # Context compression: keep the observation history within the reporter
    # LLM's token budget before invoking it.
    llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP["reporter"])
    compressed_state = ContextManager(llm_token_limit).compress_messages(
        {"messages": observation_messages}
    )
    invoke_messages += compressed_state.get("messages", [])

    logger.debug(f"Current invoke messages: {invoke_messages}")
    response = get_llm_by_type(AGENT_LLM_MAP["reporter"]).invoke(invoke_messages)
    response_content = response.content
    logger.info(f"reporter response: {response_content}")

    return {
        "final_report": response_content,
        "citations": citations,  # Pass citations through to final state
    }
|
|
|
|
|
|
def research_team_node(state: State):
    """Research team node that collaborates on tasks.

    This node performs no state update and implicitly returns ``None``; it
    appears to act only as a logging/routing waypoint, with the actual
    hand-off to researcher/coder/analyst handled by the graph's edges.
    NOTE(review): inferred from the empty body — confirm against the graph
    builder.

    Args:
        state: Current workflow state (unused beyond logging).
    """
    logger.info("Research team is collaborating on tasks.")
    logger.debug("Entering research_team_node - coordinating research and coder agents")
    # Removed a redundant trailing `pass`: the function already contains
    # statements, so `pass` was dead code.
|
|
|
|
|
|
def validate_web_search_usage(messages: list, agent_name: str = "agent") -> bool:
    """
    Validate if the agent has used the web search tool during execution.

    Three signals are checked, in order, for each message:

    1. A ``ToolMessage`` named ``"web_search"`` (a tool *result*).
    2. A message whose ``tool_calls`` include a ``web_search`` call
       (a tool *request*).
    3. Any other message carrying ``name == "web_search"``.

    The original flag-and-double-break control flow is replaced with early
    returns; all detection conditions and log messages are unchanged.

    Args:
        messages: List of messages from the agent execution
        agent_name: Name of the agent (for logging purposes)

    Returns:
        bool: True if web search tool was used, False otherwise
    """
    for message in messages:
        # Signal 1: a ToolMessage result produced by the web_search tool.
        if isinstance(message, ToolMessage) and message.name == "web_search":
            logger.info(f"[VALIDATION] {agent_name} received ToolMessage from web_search tool")
            return True

        # Signal 2: an AI message that requested the web_search tool.
        # tool_calls entries are dicts with a 'name' key.
        for tool_call in getattr(message, "tool_calls", None) or []:
            if tool_call.get("name") == "web_search":
                logger.info(f"[VALIDATION] {agent_name} called web_search tool")
                return True

        # Signal 3: any other message type tagged with the web_search name.
        if getattr(message, "name", None) == "web_search":
            logger.info(f"[VALIDATION] {agent_name} used web_search tool")
            return True

    logger.warning(f"[VALIDATION] {agent_name} did not use web_search tool")
    return False
|
|
|
|
|
|
async def _handle_recursion_limit_fallback(
    messages: list,
    agent_name: str,
    current_step,
    state: State,
) -> list:
    """Handle GraphRecursionError with a graceful, tool-free fallback.

    When an agent hits its recursion limit, ask the same LLM (without any
    tool bindings) to produce a final answer from the observations already
    gathered, and record that answer on ``current_step``.

    Args:
        messages: Messages accumulated before the limit was hit.
        agent_name: Name of the agent that hit the limit.
        current_step: The plan step being executed; its ``execution_res``
            is set to the fallback summary.
        state: Current workflow state (used for locale).

    Returns:
        list: The trimmed accumulated messages plus the fallback summary
        as a final AIMessage.

    Raises:
        Exception: Propagated if the fallback LLM call itself fails.
    """
    logger.warning(
        f"Recursion limit reached for {agent_name} agent. "
        f"Attempting graceful fallback with {len(messages)} accumulated messages."
    )

    if not messages:
        return messages

    # Drop trailing system messages so the fallback prompts come last.
    trimmed = list(messages)
    while trimmed and trimmed[-1].type == "system":
        trimmed.pop()

    # Minimal state needed by the prompt templates.
    prompt_state = {
        "locale": state.get("locale", "en-US"),
    }
    locale = prompt_state.get("locale", "en-US")

    # Re-issue the agent's own system prompt plus the dedicated
    # recursion_fallback instructions.
    agent_prompt = get_system_prompt_template(agent_name, prompt_state, None, locale)
    limit_prompt = get_system_prompt_template("recursion_fallback", prompt_state, None, locale)
    prompt_messages = [
        *trimmed,
        SystemMessage(content=agent_prompt),
        SystemMessage(content=limit_prompt),
    ]

    # Plain LLM — deliberately no tools bound for the summary pass.
    summary = get_llm_by_type(AGENT_LLM_MAP[agent_name]).invoke(prompt_messages).content

    logger.info(
        f"Graceful fallback succeeded for {agent_name} agent. "
        f"Generated summary of {len(summary)} characters."
    )

    # Strip stray tokens / truncate before persisting.
    summary = sanitize_tool_response(str(summary))

    # Record the fallback result on the current plan step.
    current_step.execution_res = summary

    return [*trimmed, AIMessage(content=summary, name=agent_name)]
|
|
|
|
|
|
async def _execute_agent_step(
    state: State, agent, agent_name: str, config: RunnableConfig = None
) -> Command[Literal["research_team"]]:
    """Execute the first unexecuted plan step with the specified agent.

    Flow:
      1. Locate the first step in ``current_plan.steps`` whose
         ``execution_res`` is empty; earlier steps are summarized as
         "<finding>" context.
      2. Build the agent input (research topic, completed-step findings,
         current step, locale), with extra resource/citation reminders for
         the researcher agent.
      3. Validate and compress the input against the agent LLM's token
         limit, then stream the agent; on ``GraphRecursionError``
         optionally fall back to a tool-free LLM summary.
      4. Sanitize the final response, record it on the step, and return a
         Command routing back to ``research_team`` with messages,
         observations, and merged citations.

    Args:
        state: Current workflow state.
        agent: The runnable agent to execute (must support ``astream``;
            MCP tools only support async invocation).
        agent_name: Agent key used for LLM lookup and logging
            (e.g. "researcher", "coder", "analyst").
        config: Optional runnable config used to resolve Configuration.

    Returns:
        Command: State update with ``goto="research_team"``.
    """
    logger.debug(f"[_execute_agent_step] Starting execution for agent: {agent_name}")

    current_plan = state.get("current_plan")
    plan_title = current_plan.title
    observations = state.get("observations", [])
    logger.debug(f"[_execute_agent_step] Plan title: {plan_title}, observations count: {len(observations)}")

    # Find the first unexecuted step; everything before it is treated as
    # completed context for the prompt.
    current_step = None
    completed_steps = []
    for idx, step in enumerate(current_plan.steps):
        if not step.execution_res:
            current_step = step
            logger.debug(f"[_execute_agent_step] Found unexecuted step at index {idx}: {step.title}")
            break
        else:
            completed_steps.append(step)

    if not current_step:
        # Nothing left to do — hand control back to research_team unchanged.
        logger.warning(f"[_execute_agent_step] No unexecuted step found in {len(current_plan.steps)} total steps")
        return Command(
            update=preserve_state_meta_fields(state),
            goto="research_team"
        )

    logger.info(f"[_execute_agent_step] Executing step: {current_step.title}, agent: {agent_name}")
    logger.debug(f"[_execute_agent_step] Completed steps so far: {len(completed_steps)}")

    # Format completed steps information as <finding> blocks for the prompt.
    completed_steps_info = ""
    if completed_steps:
        completed_steps_info = "# Completed Research Steps\n\n"
        for i, step in enumerate(completed_steps):
            completed_steps_info += f"## Completed Step {i + 1}: {step.title}\n\n"
            completed_steps_info += f"<finding>\n{step.execution_res}\n</finding>\n\n"

    # Prepare the input for the agent with completed steps info
    agent_input = {
        "messages": [
            HumanMessage(
                content=f"# Research Topic\n\n{plan_title}\n\n{completed_steps_info}# Current Step\n\n## Title\n\n{current_step.title}\n\n## Description\n\n{current_step.description}\n\n## Locale\n\n{state.get('locale', 'en-US')}"
            )
        ]
    }

    # Researcher-only extras: resource-file reminder and citation-format rules.
    if agent_name == "researcher":
        if state.get("resources"):
            resources_info = "**The user mentioned the following resource files:**\n\n"
            for resource in state.get("resources"):
                resources_info += f"- {resource.title} ({resource.description})\n"

            agent_input["messages"].append(
                HumanMessage(
                    content=resources_info
                    + "\n\n"
                    + "You MUST use the **local_search_tool** to retrieve the information from the resource files.",
                )
            )

        agent_input["messages"].append(
            HumanMessage(
                content="IMPORTANT: DO NOT include inline citations in the text. Instead, track all sources and include a References section at the end using link reference format. Include an empty line between each citation for better readability. Use this format for each reference:\n- [Source Title](URL)\n\n- [Another Source](URL)",
                name="system",
            )
        )

    # Resolve the recursion limit from the environment, falling back to the
    # default on missing, non-numeric, or non-positive values.
    default_recursion_limit = 25
    try:
        env_value_str = os.getenv("AGENT_RECURSION_LIMIT", str(default_recursion_limit))
        parsed_limit = int(env_value_str)

        if parsed_limit > 0:
            recursion_limit = parsed_limit
            logger.info(f"Recursion limit set to: {recursion_limit}")
        else:
            logger.warning(
                f"AGENT_RECURSION_LIMIT value '{env_value_str}' (parsed as {parsed_limit}) is not positive. "
                f"Using default value {default_recursion_limit}."
            )
            recursion_limit = default_recursion_limit
    except ValueError:
        raw_env_value = os.getenv("AGENT_RECURSION_LIMIT")
        logger.warning(
            f"Invalid AGENT_RECURSION_LIMIT value: '{raw_env_value}'. "
            f"Using default value {default_recursion_limit}."
        )
        recursion_limit = default_recursion_limit

    logger.info(f"Agent input: {agent_input}")

    # Validate message content before invoking agent; validation failures are
    # logged but non-fatal (the original messages are kept).
    try:
        validated_messages = validate_message_content(agent_input["messages"])
        agent_input["messages"] = validated_messages
    except Exception as validation_error:
        logger.error(f"Error validating agent input messages: {validation_error}")

    # Apply context compression to prevent token overflow (Issue #721).
    # Token counts below are rough whitespace-split estimates, for logging only.
    llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP[agent_name])
    if llm_token_limit:
        token_count_before = sum(
            len(str(msg.content).split()) for msg in agent_input.get("messages", []) if hasattr(msg, "content")
        )
        compressed_state = ContextManager(llm_token_limit, preserve_prefix_message_count=3).compress_messages(
            {"messages": agent_input["messages"]}
        )
        agent_input["messages"] = compressed_state.get("messages", [])
        token_count_after = sum(
            len(str(msg.content).split()) for msg in agent_input.get("messages", []) if hasattr(msg, "content")
        )
        logger.info(
            f"Context compression for {agent_name}: {len(compressed_state.get('messages', []))} messages, "
            f"estimated tokens before: ~{token_count_before}, after: ~{token_count_after}"
        )

    try:
        # Use astream (async) from the start to capture messages in real-time.
        # This allows us to retrieve accumulated messages even if recursion limit is hit.
        # NOTE: astream is required for MCP tools which only support async invocation.
        # accumulated_messages is intentionally bound before the loop so the
        # GraphRecursionError handler below can read whatever was captured.
        accumulated_messages = []
        async for chunk in agent.astream(
            input=agent_input,
            config={"recursion_limit": recursion_limit},
            stream_mode="values",
        ):
            if isinstance(chunk, dict) and "messages" in chunk:
                accumulated_messages = chunk["messages"]

        # If we get here, execution completed successfully
        result = {"messages": accumulated_messages}
    except GraphRecursionError:
        # Check if recursion fallback is enabled
        configurable = Configuration.from_runnable_config(config) if config else Configuration()

        if configurable.enable_recursion_fallback:
            try:
                # Call fallback with accumulated messages (function returns list of messages)
                response_messages = await _handle_recursion_limit_fallback(
                    messages=accumulated_messages,
                    agent_name=agent_name,
                    current_step=current_step,
                    state=state,
                )

                # Build the same result shape as the success path so the
                # post-processing below runs unchanged.
                result = {"messages": response_messages}
            except Exception as fallback_error:
                # If fallback fails, log and re-raise into the generic
                # Exception handler below (standard error handling).
                logger.error(
                    f"Recursion fallback failed for {agent_name} agent: {fallback_error}. "
                    "Falling back to standard error handling."
                )
                raise
        else:
            # Fallback disabled, let error propagate to standard handler
            logger.info(
                f"Recursion limit reached but graceful fallback is disabled. "
                "Using standard error handling."
            )
            raise
    except Exception as e:
        import traceback

        error_traceback = traceback.format_exc()
        error_message = f"Error executing {agent_name} agent for step '{current_step.title}': {str(e)}"
        logger.exception(error_message)
        logger.error(f"Full traceback:\n{error_traceback}")

        # Enhanced error diagnostics for content-related errors
        # (pydantic-style "Field required" failures on message content).
        if "Field required" in str(e) and "content" in str(e):
            logger.error(f"Message content validation error detected")
            for i, msg in enumerate(agent_input.get('messages', [])):
                logger.error(f"Message {i}: type={type(msg).__name__}, "
                             f"has_content={hasattr(msg, 'content')}, "
                             f"content_type={type(msg.content).__name__ if hasattr(msg, 'content') else 'N/A'}, "
                             f"content_len={len(str(msg.content)) if hasattr(msg, 'content') and msg.content else 0}")

        # Record the failure on the step so the workflow can continue past it.
        detailed_error = f"[ERROR] {agent_name.capitalize()} Agent Error\n\nStep: {current_step.title}\n\nError Details:\n{str(e)}\n\nPlease check the logs for more information."
        current_step.execution_res = detailed_error

        return Command(
            update={
                "messages": [
                    HumanMessage(
                        content=detailed_error,
                        name=agent_name,
                    )
                ],
                "observations": observations + [detailed_error],
                **preserve_state_meta_fields(state),
            },
            goto="research_team",
        )

    response_messages = result["messages"]

    # Process the result: the last message is treated as the agent's final
    # answer. NOTE(review): assumes the agent produced at least one message —
    # an empty stream would raise IndexError here; confirm upstream guarantees.
    response_content = response_messages[-1].content

    # Sanitize response to remove extra tokens and truncate if needed
    response_content = sanitize_tool_response(str(response_content))

    logger.debug(f"{agent_name.capitalize()} full response: {response_content}")

    # Validate web search usage for researcher agent if enforcement is enabled
    web_search_validated = True
    should_validate = agent_name == "researcher"
    validation_info = ""

    if should_validate:
        # Check if enforcement is enabled in configuration
        configurable = Configuration.from_runnable_config(config) if config else Configuration()
        # Skip validation if web search is disabled (user intentionally disabled it)
        if configurable.enforce_researcher_search and configurable.enable_web_search:
            web_search_validated = validate_web_search_usage(result["messages"], agent_name)

            # If web search was not used, append a warning that will be stored
            # alongside the observation (not shown to the agent itself).
            if not web_search_validated:
                logger.warning(f"[VALIDATION] Researcher did not use web_search tool. Adding reminder to response.")
                # Add validation information to observations
                validation_info = (
                    "\n\n[WARNING] This research was completed without using the web_search tool. "
                    "Please verify that the information provided is accurate and up-to-date."
                    "\n\n[VALIDATION WARNING] Researcher did not use the web_search tool as recommended."
                )

    # Update the step with the execution result
    current_step.execution_res = response_content
    logger.info(f"Step '{current_step.title}' execution completed by {agent_name}")

    # Include all messages from agent result to preserve intermediate tool calls/results.
    # This ensures multiple web_search calls all appear in the stream, not just the final result.
    agent_messages = result.get("messages", [])
    logger.debug(
        f"{agent_name.capitalize()} returned {len(agent_messages)} messages. "
        f"Message types: {[type(msg).__name__ for msg in agent_messages]}"
    )

    # Count tool messages for logging
    tool_message_count = sum(1 for msg in agent_messages if isinstance(msg, ToolMessage))
    if tool_message_count > 0:
        logger.info(
            f"{agent_name.capitalize()} agent made {tool_message_count} tool calls. "
            f"All tool results will be preserved and streamed to frontend."
        )

    # Extract citations from tool call results (web_search, crawl) and merge
    # them with citations already accumulated in state.
    existing_citations = state.get("citations", [])
    new_citations = extract_citations_from_messages(agent_messages)
    merged_citations = merge_citations(existing_citations, new_citations)

    if new_citations:
        logger.info(
            f"Extracted {len(new_citations)} new citations from {agent_name} agent. "
            f"Total citations: {len(merged_citations)}"
        )

    return Command(
        update={
            **preserve_state_meta_fields(state),
            "messages": agent_messages,
            "observations": observations + [response_content + validation_info],
            "citations": merged_citations,  # Store merged citations based on existing state and new tool results
        },
        goto="research_team",
    )
|
|
|
|
|
|
async def _setup_and_execute_agent_step(
    state: State,
    config: RunnableConfig,
    agent_type: str,
    default_tools: list,
) -> Command[Literal["research_team"]]:
    """Helper function to set up an agent with appropriate tools and execute a step.

    This function handles the common logic for researcher_node, coder_node and
    analyst_node:
    1. Configures MCP servers and tools based on agent type
    2. Creates an agent with the appropriate tools (default tools plus any
       enabled MCP tools)
    3. Executes the agent on the current step via ``_execute_agent_step``

    Args:
        state: The current state
        config: The runnable config
        agent_type: The type of agent ("researcher", "coder" or "analyst")
        default_tools: The default tools to add to the agent

    Returns:
        Command to update state and go to research_team
    """
    configurable = Configuration.from_runnable_config(config)
    mcp_servers = {}
    enabled_tools = {}
    # Copy so MCP tools appended below never mutate the caller's list.
    loaded_tools = default_tools[:]

    # Get locale from workflow state to pass to agent creation
    # This fixes issue #743 where locale was not correctly retrieved in agent prompt
    locale = state.get("locale", "en-US")

    # Extract MCP server configuration for this agent type
    if configurable.mcp_settings:
        for server_name, server_config in configurable.mcp_settings["servers"].items():
            if (
                server_config["enabled_tools"]
                and agent_type in server_config["add_to_agents"]
            ):
                # Keep only the connection-related keys the MCP client accepts.
                mcp_servers[server_name] = {
                    k: v
                    for k, v in server_config.items()
                    if k in ("transport", "command", "args", "url", "env", "headers")
                }
                for tool_name in server_config["enabled_tools"]:
                    enabled_tools[tool_name] = server_name

    # Create and execute agent with MCP tools if available
    if mcp_servers:
        # Add MCP tools to loaded tools if MCP servers are configured.
        # Loop variable renamed from `tool` to `mcp_tool` so it no longer
        # shadows the `tool` decorator imported from langchain_core.tools.
        client = MultiServerMCPClient(mcp_servers)
        all_tools = await client.get_tools()
        for mcp_tool in all_tools:
            if mcp_tool.name in enabled_tools:
                mcp_tool.description = (
                    f"Powered by '{enabled_tools[mcp_tool.name]}'.\n{mcp_tool.description}"
                )
                loaded_tools.append(mcp_tool)

    # Compress context before each model call to stay within the token limit.
    llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP[agent_type])
    pre_model_hook = partial(ContextManager(llm_token_limit, 3).compress_messages)
    agent = create_agent(
        agent_type,
        agent_type,
        loaded_tools,
        agent_type,
        pre_model_hook,
        interrupt_before_tools=configurable.interrupt_before_tools,
        locale=locale,
    )
    return await _execute_agent_step(state, agent, agent_type, config)
|
|
|
|
|
|
async def researcher_node(
    state: State, config: RunnableConfig
) -> Command[Literal["research_team"]]:
    """Run the researcher agent on the current plan step.

    Assembles the researcher's tool set from the configuration — web search
    and crawling when enabled, plus a local retriever (first priority) when
    resources exist — then delegates to the shared agent-step helper.
    """
    logger.info("Researcher node is researching.")
    logger.debug("[researcher_node] Starting researcher agent")

    configurable = Configuration.from_runnable_config(config)
    logger.debug(f"[researcher_node] Max search results: {configurable.max_search_results}")

    tools = []

    # Web search and crawl are only wired in when web search is enabled.
    if configurable.enable_web_search:
        tools.append(get_web_search_tool(configurable.max_search_results))
        tools.append(crawl_tool)
    else:
        logger.info("[researcher_node] Web search is disabled, using only local RAG")

    # A retriever over user-provided resources always takes first priority.
    retriever_tool = get_retriever_tool(state.get("resources", []))
    if retriever_tool:
        logger.debug("[researcher_node] Adding retriever tool to tools list")
        tools.insert(0, retriever_tool)

    # With web search off and no resources, the agent runs on reasoning alone.
    if not tools:
        logger.warning("[researcher_node] No tools available (web search disabled, no resources). "
                       "Researcher will operate in pure reasoning mode.")

    logger.info(f"[researcher_node] Researcher tools count: {len(tools)}")
    tool_names = [t.name if hasattr(t, 'name') else str(t) for t in tools]
    logger.debug(f"[researcher_node] Researcher tools: {tool_names}")
    logger.info(f"[researcher_node] enforce_researcher_search={configurable.enforce_researcher_search}, "
                f"enable_web_search={configurable.enable_web_search}")

    return await _setup_and_execute_agent_step(state, config, "researcher", tools)
|
|
|
|
|
|
async def coder_node(
    state: State, config: RunnableConfig
) -> Command[Literal["research_team"]]:
    """Run the coder agent, whose only tool is the Python REPL."""
    logger.info("Coder node is coding.")
    logger.debug("[coder_node] Starting coder agent with python_repl_tool")

    return await _setup_and_execute_agent_step(state, config, "coder", [python_repl_tool])
|
|
|
|
|
|
async def analyst_node(
    state: State, config: RunnableConfig
) -> Command[Literal["research_team"]]:
    """Run the analyst agent: pure LLM reasoning, no tools.

    Suited to tasks that need thinking rather than execution, e.g.:
    - Cross-validating information from multiple sources
    - Synthesizing research findings
    - Comparative analysis
    - Pattern recognition and trend analysis
    - General reasoning tasks that don't require code
    """
    logger.info("Analyst node is analyzing.")
    logger.debug("[analyst_node] Starting analyst agent for reasoning/analysis tasks")

    # Empty tool list: the analyst works by reasoning alone.
    return await _setup_and_execute_agent_step(state, config, "analyst", [])
|