From 447e427fd30780ed87024cd1975bdd01545b950f Mon Sep 17 00:00:00 2001 From: DanielWalnut <45447813+hetaoBackend@users.noreply.github.com> Date: Wed, 11 Jun 2025 11:10:02 +0800 Subject: [PATCH] refactor: refine the background check logic (#306) --- src/graph/nodes.py | 24 +++++++++++++++--------- src/graph/types.py | 1 + src/server/app.py | 3 ++- tests/integration/test_nodes.py | 1 + 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/src/graph/nodes.py b/src/graph/nodes.py index 7715a26..ec2c31b 100644 --- a/src/graph/nodes.py +++ b/src/graph/nodes.py @@ -36,7 +36,7 @@ logger = logging.getLogger(__name__) @tool def handoff_to_planner( - task_title: Annotated[str, "The title of the task to be handed off."], + research_topic: Annotated[str, "The topic of the research task to be handed off."], locale: Annotated[str, "The user's detected language locale (e.g., en-US, zh-CN)."], ): """Handoff to planner agent to do plan.""" @@ -48,7 +48,7 @@ def handoff_to_planner( def background_investigation_node(state: State, config: RunnableConfig): logger.info("background investigation node is running.") configurable = Configuration.from_runnable_config(config) - query = state["messages"][-1].content + query = state.get("research_topic") background_investigation_results = None if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value: searched_content = LoggedTavilySearch( @@ -87,10 +87,8 @@ def planner_node( plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0 messages = apply_prompt_template("planner", state, configurable) - if ( - plan_iterations == 0 - and state.get("enable_background_investigation") - and state.get("background_investigation_results") + if state.get("enable_background_investigation") and state.get( + "background_investigation_results" ): messages += [ { @@ -221,6 +219,7 @@ def coordinator_node( goto = "__end__" locale = state.get("locale", "en-US") # Default locale if not specified + research_topic = 
state.get("research_topic", "") if len(response.tool_calls) > 0: goto = "planner" @@ -231,8 +230,11 @@ def coordinator_node( for tool_call in response.tool_calls: if tool_call.get("name", "") != "handoff_to_planner": continue - if tool_locale := tool_call.get("args", {}).get("locale"): - locale = tool_locale + if tool_call.get("args", {}).get("locale") and tool_call.get( + "args", {} + ).get("research_topic"): + locale = tool_call.get("args", {}).get("locale") + research_topic = tool_call.get("args", {}).get("research_topic") break except Exception as e: logger.error(f"Error processing tool calls: {e}") @@ -243,7 +245,11 @@ def coordinator_node( logger.debug(f"Coordinator response: {response}") return Command( - update={"locale": locale, "resources": configurable.resources}, + update={ + "locale": locale, + "research_topic": research_topic, + "resources": configurable.resources, + }, goto=goto, ) diff --git a/src/graph/types.py b/src/graph/types.py index fba8264..84b231b 100644 --- a/src/graph/types.py +++ b/src/graph/types.py @@ -12,6 +12,7 @@ class State(MessagesState): # Runtime Variables locale: str = "en-US" + research_topic: str = "" observations: list[str] = [] resources: list[Resource] = [] plan_iterations: int = 0 diff --git a/src/server/app.py b/src/server/app.py index 782f789..4a8865e 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -87,7 +87,7 @@ async def chat_stream(request: ChatRequest): async def _astream_workflow_generator( - messages: List[ChatMessage], + messages: List[dict], thread_id: str, resources: List[Resource], max_plan_iterations: int, @@ -107,6 +107,7 @@ async def _astream_workflow_generator( "observations": [], "auto_accepted_plan": auto_accepted_plan, "enable_background_investigation": enable_background_investigation, + "research_topic": messages[-1]["content"] if messages else "", } if not auto_accepted_plan and interrupt_feedback: resume_msg = f"[{interrupt_feedback}]" diff --git a/tests/integration/test_nodes.py 
b/tests/integration/test_nodes.py index 0d3602f..55c8e48 100644 --- a/tests/integration/test_nodes.py +++ b/tests/integration/test_nodes.py @@ -20,6 +20,7 @@ MOCK_SEARCH_RESULTS = [ def mock_state(): return { "messages": [HumanMessage(content="test query")], + "research_topic": "test query", "background_investigation_results": None, }