From 899438eca08875bcd3951dabd7144ddcfe0c3afb Mon Sep 17 00:00:00 2001 From: Zhao Longjie Date: Sun, 27 Apr 2025 20:15:42 +0800 Subject: [PATCH] feat(nodes): add background investigation node Change-Id: I96e08e22fc7c52647edbf9be4f385a8fae9b449a --- main.py | 27 ++++++++++++++++-- src/graph/builder.py | 2 ++ src/graph/nodes.py | 56 ++++++++++++++++++++++++++++++++++---- src/graph/types.py | 2 ++ src/server/app.py | 5 +++- src/server/chat_request.py | 3 ++ src/workflow.py | 3 ++ 7 files changed, 90 insertions(+), 8 deletions(-) diff --git a/main.py b/main.py index dc00c48..38a7460 100644 --- a/main.py +++ b/main.py @@ -13,7 +13,13 @@ from src.workflow import run_agent_workflow_async from src.config.questions import BUILT_IN_QUESTIONS, BUILT_IN_QUESTIONS_ZH_CN -def ask(question, debug=False, max_plan_iterations=1, max_step_num=3): +def ask( + question, + debug=False, + max_plan_iterations=1, + max_step_num=3, + enable_background_investigation=True, +): """Run the agent workflow with the given question. Args: @@ -21,6 +27,7 @@ def ask(question, debug=False, max_plan_iterations=1, max_step_num=3): debug: If True, enables debug level logging max_plan_iterations: Maximum number of plan iterations max_step_num: Maximum number of steps in a plan + enable_background_investigation: If True, performs web search before planning to enhance context """ asyncio.run( run_agent_workflow_async( @@ -28,14 +35,21 @@ def ask(question, debug=False, max_plan_iterations=1, max_step_num=3): debug=debug, max_plan_iterations=max_plan_iterations, max_step_num=max_step_num, + enable_background_investigation=enable_background_investigation, ) ) -def main(debug=False, max_plan_iterations=1, max_step_num=3): +def main( + debug=False, + max_plan_iterations=1, + max_step_num=3, + enable_background_investigation=True, +): """Interactive mode with built-in questions. 
Args: + enable_background_investigation: If True, performs web search before planning to enhance context debug: If True, enables debug level logging max_plan_iterations: Maximum number of plan iterations max_step_num: Maximum number of steps in a plan @@ -77,6 +91,7 @@ def main(debug=False, max_plan_iterations=1, max_step_num=3): debug=debug, max_plan_iterations=max_plan_iterations, max_step_num=max_step_num, + enable_background_investigation=enable_background_investigation, ) @@ -102,6 +117,12 @@ if __name__ == "__main__": help="Maximum number of steps in a plan (default: 3)", ) parser.add_argument("--debug", action="store_true", help="Enable debug logging") + parser.add_argument( + "--no-background-investigation", + action="store_false", + dest="enable_background_investigation", + help="Disable background investigation before planning", + ) args = parser.parse_args() @@ -111,6 +132,7 @@ if __name__ == "__main__": debug=args.debug, max_plan_iterations=args.max_plan_iterations, max_step_num=args.max_step_num, + enable_background_investigation=args.enable_background_investigation, ) else: # Parse user input from command line arguments or user input @@ -125,4 +147,5 @@ if __name__ == "__main__": debug=args.debug, max_plan_iterations=args.max_plan_iterations, max_step_num=args.max_step_num, + enable_background_investigation=args.enable_background_investigation, ) diff --git a/src/graph/builder.py b/src/graph/builder.py index 3aebc89..dc939fa 100644 --- a/src/graph/builder.py +++ b/src/graph/builder.py @@ -13,6 +13,7 @@ from .nodes import ( researcher_node, coder_node, human_feedback_node, + background_investigation_node, ) @@ -21,6 +22,7 @@ def _build_base_graph(): builder = StateGraph(State) builder.add_edge(START, "coordinator") builder.add_node("coordinator", coordinator_node) + builder.add_node("background_investigator", background_investigation_node) builder.add_node("planner", planner_node) builder.add_node("reporter", reporter_node) 
builder.add_node("research_team", research_team_node) diff --git a/src/graph/nodes.py b/src/graph/nodes.py index 0ffa9a8..e5cc896 100644 --- a/src/graph/nodes.py +++ b/src/graph/nodes.py @@ -13,6 +13,7 @@ from langchain_mcp_adapters.client import MultiServerMCPClient from src.agents.agents import coder_agent, research_agent, create_agent +from src.tools.search import LoggedTavilySearch from src.tools import ( crawl_tool, web_search_tool, @@ -27,6 +28,7 @@ from src.prompts.template import apply_prompt_template from src.utils.json_utils import repair_json_output from .types import State +from ..config import SEARCH_MAX_RESULTS logger = logging.getLogger(__name__) @@ -42,13 +44,55 @@ def handoff_to_planner( return +def background_investigation_node(state: State) -> Command[Literal["planner"]]: + + logger.info("background investigation node is running.") + searched_content = LoggedTavilySearch(max_results=SEARCH_MAX_RESULTS).invoke( + {"query": state["messages"][-1].content} + ) + background_investigation_results = None + if isinstance(searched_content, list): + background_investigation_results = [ + {"title": elem["title"], "content": elem["content"]} + for elem in searched_content + ] + else: + logger.error(f"Tavily search returned malformed response: {searched_content}") + return Command( + update={ + "background_investigation_results": json.dumps( + background_investigation_results, ensure_ascii=False + ) + }, + goto="planner", + ) + + def planner_node( state: State, config: RunnableConfig ) -> Command[Literal["human_feedback", "reporter"]]: """Planner node that generate the full plan.""" logger.info("Planner generating full plan") configurable = Configuration.from_runnable_config(config) + plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0 messages = apply_prompt_template("planner", state, configurable) + + if ( + plan_iterations == 0 + and state.get("enable_background_investigation") + and 
state.get("background_investigation_results") + ): + messages += [ + { + "role": "user", + "content": ( + "background investigation results of user query:\n" + + state["background_investigation_results"] + + "\n" + ), + } + ] + if AGENT_LLM_MAP["planner"] == "basic": llm = get_llm_by_type(AGENT_LLM_MAP["planner"]).with_structured_output( Plan, @@ -56,7 +100,6 @@ def planner_node( ) else: llm = get_llm_by_type(AGENT_LLM_MAP["planner"]) - plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0 # if the plan iterations is greater than the max plan iterations, return the reporter node if plan_iterations >= configurable.max_plan_iterations: @@ -134,7 +177,9 @@ def human_feedback_node( ) -def coordinator_node(state: State) -> Command[Literal["planner", "__end__"]]: +def coordinator_node( + state: State, +) -> Command[Literal["planner", "background_investigator", "__end__"]]: """Coordinator node that communicate with customers.""" logger.info("Coordinator talking.") messages = apply_prompt_template("coordinator", state) @@ -150,6 +195,9 @@ def coordinator_node(state: State) -> Command[Literal["planner", "__end__"]]: if len(response.tool_calls) > 0: goto = "planner" + if state.get("enable_background_investigation"): + # when background investigation is enabled, route to the background_investigator node before planning + goto = "background_investigator" try: for tool_call in response.tool_calls: if tool_call.get("name", "") != "handoff_to_planner": @@ -160,9 +208,7 @@ def coordinator_node(state: State) -> Command[Literal["planner", "__end__"]]: except Exception as e: logger.error(f"Error processing tool calls: {e}") return Command( - update={ - "locale": locale - }, + update={"locale": locale}, goto=goto, ) diff --git a/src/graph/types.py b/src/graph/types.py index ccbbedf..5ba9cf7 100644 --- a/src/graph/types.py +++ b/src/graph/types.py @@ -19,3 +19,5 @@ class State(MessagesState): current_plan: Plan | str = None final_report: str = "" auto_accepted_plan: bool
= False + enable_background_investigation: bool = True + background_investigation_results: str = None diff --git a/src/server/app.py b/src/server/app.py index 9b4b0bd..ea2156a 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -63,6 +63,7 @@ async def chat_stream(request: ChatRequest): request.auto_accepted_plan, request.interrupt_feedback, request.mcp_settings, + request.enable_background_investigation, ), media_type="text/event-stream", ) @@ -76,6 +77,7 @@ async def _astream_workflow_generator( auto_accepted_plan: bool, interrupt_feedback: str, mcp_settings: dict, + enable_background_investigation, ): input_ = { "messages": messages, @@ -84,12 +86,13 @@ async def _astream_workflow_generator( "current_plan": None, "observations": [], "auto_accepted_plan": auto_accepted_plan, + "enable_background_investigation": enable_background_investigation, } if not auto_accepted_plan and interrupt_feedback: resume_msg = f"[{interrupt_feedback}]" # add the last message to the resume message if messages: - resume_msg += f" {messages[-1]["content"]}" + resume_msg += f" {messages[-1]['content']}" input_ = Command(resume=resume_msg) async for agent, _, event_data in graph.astream( input_, diff --git a/src/server/chat_request.py b/src/server/chat_request.py index c8850f5..e2abcde 100644 --- a/src/server/chat_request.py +++ b/src/server/chat_request.py @@ -47,6 +47,9 @@ class ChatRequest(BaseModel): mcp_settings: Optional[dict] = Field( None, description="MCP settings for the chat request" ) + enable_background_investigation: Optional[bool] = Field( + True, description="Whether to get background investigation before plan" + ) class TTSRequest(BaseModel): diff --git a/src/workflow.py b/src/workflow.py index ce88e65..fddc8e1 100644 --- a/src/workflow.py +++ b/src/workflow.py @@ -28,6 +28,7 @@ async def run_agent_workflow_async( debug: bool = False, max_plan_iterations: int = 1, max_step_num: int = 3, + enable_background_investigation: bool = True, ): """Run the agent workflow 
asynchronously with the given user input. @@ -36,6 +37,7 @@ async def run_agent_workflow_async( debug: If True, enables debug level logging max_plan_iterations: Maximum number of plan iterations max_step_num: Maximum number of steps in a plan + enable_background_investigation: If True, performs web search before planning to enhance context Returns: The final state after the workflow completes @@ -51,6 +53,7 @@ async def run_agent_workflow_async( # Runtime Variables "messages": [{"role": "user", "content": user_input}], "auto_accepted_plan": True, + "enable_background_investigation": enable_background_investigation, } config = { "configurable": {