feat(nodes): add background investigation node

Change-Id: I96e08e22fc7c52647edbf9be4f385a8fae9b449a
This commit is contained in:
Zhao Longjie
2025-04-27 20:15:42 +08:00
parent ada5e34eeb
commit 899438eca0
7 changed files with 90 additions and 8 deletions

View File

@@ -13,6 +13,7 @@ from .nodes import (
researcher_node,
coder_node,
human_feedback_node,
background_investigation_node,
)
@@ -21,6 +22,7 @@ def _build_base_graph():
builder = StateGraph(State)
builder.add_edge(START, "coordinator")
builder.add_node("coordinator", coordinator_node)
builder.add_node("background_investigator", background_investigation_node)
builder.add_node("planner", planner_node)
builder.add_node("reporter", reporter_node)
builder.add_node("research_team", research_team_node)

View File

@@ -13,6 +13,7 @@ from langchain_mcp_adapters.client import MultiServerMCPClient
from src.agents.agents import coder_agent, research_agent, create_agent
from src.tools.search import LoggedTavilySearch
from src.tools import (
crawl_tool,
web_search_tool,
@@ -27,6 +28,7 @@ from src.prompts.template import apply_prompt_template
from src.utils.json_utils import repair_json_output
from .types import State
from ..config import SEARCH_MAX_RESULTS
logger = logging.getLogger(__name__)
@@ -42,13 +44,55 @@ def handoff_to_planner(
return
def background_investigation_node(state: State) -> Command[Literal["planner"]]:
logger.info("background investigation node is running.")
searched_content = LoggedTavilySearch(max_results=SEARCH_MAX_RESULTS).invoke(
{"query": state["messages"][-1].content}
)
background_investigation_results = None
if isinstance(searched_content, list):
background_investigation_results = [
{"title": elem["title"], "content": elem["content"]}
for elem in searched_content
]
else:
logger.error(f"Tavily search returned malformed response: {searched_content}")
return Command(
update={
"background_investigation_results": json.dumps(
background_investigation_results, ensure_ascii=False
)
},
goto="planner",
)
def planner_node(
state: State, config: RunnableConfig
) -> Command[Literal["human_feedback", "reporter"]]:
"""Planner node that generate the full plan."""
logger.info("Planner generating full plan")
configurable = Configuration.from_runnable_config(config)
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
messages = apply_prompt_template("planner", state, configurable)
if (
plan_iterations == 0
and state.get("enable_background_investigation")
and state.get("background_investigation_results")
):
messages += [
{
"role": "user",
"content": (
"background investigation results of user query:\n"
+ state["background_investigation_results"]
+ "\n"
),
}
]
if AGENT_LLM_MAP["planner"] == "basic":
llm = get_llm_by_type(AGENT_LLM_MAP["planner"]).with_structured_output(
Plan,
@@ -56,7 +100,6 @@ def planner_node(
)
else:
llm = get_llm_by_type(AGENT_LLM_MAP["planner"])
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
# if the plan iterations is greater than the max plan iterations, return the reporter node
if plan_iterations >= configurable.max_plan_iterations:
@@ -134,7 +177,9 @@ def human_feedback_node(
)
def coordinator_node(state: State) -> Command[Literal["planner", "__end__"]]:
def coordinator_node(
state: State,
) -> Command[Literal["planner", "background_investigator", "__end__"]]:
"""Coordinator node that communicate with customers."""
logger.info("Coordinator talking.")
messages = apply_prompt_template("coordinator", state)
@@ -150,6 +195,9 @@ def coordinator_node(state: State) -> Command[Literal["planner", "__end__"]]:
if len(response.tool_calls) > 0:
goto = "planner"
if state.get("enable_background_investigation"):
# if the search_before_planning is True, add the web search tool to the planner agent
goto = "background_investigator"
try:
for tool_call in response.tool_calls:
if tool_call.get("name", "") != "handoff_to_planner":
@@ -160,9 +208,7 @@ def coordinator_node(state: State) -> Command[Literal["planner", "__end__"]]:
except Exception as e:
logger.error(f"Error processing tool calls: {e}")
return Command(
update={
"locale": locale
},
update={"locale": locale},
goto=goto,
)

View File

@@ -19,3 +19,5 @@ class State(MessagesState):
current_plan: Plan | str = None
final_report: str = ""
auto_accepted_plan: bool = False
enable_background_investigation: bool = True
background_investigation_results: str = None