From 14005b42553ab2d75a6fbcab6e34317c3f12d27c Mon Sep 17 00:00:00 2001 From: hetao Date: Thu, 5 Jun 2025 12:32:33 +0800 Subject: [PATCH] refactor: refine the graph structure --- src/graph/builder.py | 23 ++++++++++++++++ src/graph/nodes.py | 47 +++++++++++---------------------- tests/integration/test_nodes.py | 28 +++++++++----------- 3 files changed, 51 insertions(+), 47 deletions(-) diff --git a/src/graph/builder.py b/src/graph/builder.py index dc939fa..8be3956 100644 --- a/src/graph/builder.py +++ b/src/graph/builder.py @@ -3,6 +3,7 @@ from langgraph.graph import StateGraph, START, END from langgraph.checkpoint.memory import MemorySaver +from src.prompts.planner_model import StepType from .types import State from .nodes import ( @@ -17,6 +18,22 @@ from .nodes import ( ) +def continue_to_running_research_step(state: State): + current_plan = state.get("current_plan") + if not current_plan or not current_plan.steps: + return "planner" + if all(step.execution_res for step in current_plan.steps): + return "planner" + for step in current_plan.steps: + if not step.execution_res: + break + if step.step_type and step.step_type == StepType.RESEARCH: + return "researcher" + if step.step_type and step.step_type == StepType.PROCESSING: + return "coder" + return "planner" + + def _build_base_graph(): """Build and return the base state graph with all nodes and edges.""" builder = StateGraph(State) @@ -29,6 +46,12 @@ def _build_base_graph(): builder.add_node("researcher", researcher_node) builder.add_node("coder", coder_node) builder.add_node("human_feedback", human_feedback_node) + builder.add_edge("background_investigator", "planner") + builder.add_conditional_edges( + "research_team", + continue_to_running_research_step, + ["planner", "researcher", "coder"], + ) builder.add_edge("reporter", END) return builder diff --git a/src/graph/nodes.py b/src/graph/nodes.py index c0eff53..8c3235d 100644 --- a/src/graph/nodes.py +++ b/src/graph/nodes.py @@ -24,7 +24,7 @@ from src.tools import ( from src.config.agents import AGENT_LLM_MAP from src.config.configuration import Configuration from src.llms.llm import get_llm_by_type -from src.prompts.planner_model import Plan, StepType +from src.prompts.planner_model import Plan from src.prompts.template import apply_prompt_template from src.utils.json_utils import repair_json_output @@ -45,22 +45,24 @@ def handoff_to_planner( return -def background_investigation_node( - state: State, config: RunnableConfig -) -> Command[Literal["planner"]]: +def background_investigation_node(state: State, config: RunnableConfig): logger.info("background investigation node is running.") configurable = Configuration.from_runnable_config(config) query = state["messages"][-1].content + background_investigation_results = None if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value: searched_content = LoggedTavilySearch( max_results=configurable.max_search_results ).invoke(query) - background_investigation_results = None if isinstance(searched_content, list): background_investigation_results = [ - {"title": elem["title"], "content": elem["content"]} - for elem in searched_content + f"## {elem['title']}\n\n{elem['content']}" for elem in searched_content ] + return { + "background_investigation_results": "\n\n".join( + background_investigation_results + ) + } else: logger.error( f"Tavily search returned malformed response: {searched_content}" @@ -69,14 +71,11 @@ def background_investigation_node( background_investigation_results = get_web_search_tool( configurable.max_search_results ).invoke(query) - return Command( - update={ - "background_investigation_results": json.dumps( - background_investigation_results, ensure_ascii=False - ) - }, - goto="planner", - ) + return { + "background_investigation_results": json.dumps( + background_investigation_results, ensure_ascii=False + ) + } def planner_node( @@ -287,24 +286,10 @@ def reporter_node(state: State): return {"final_report": response_content} -def research_team_node( - state: State, -) -> Command[Literal["planner", "researcher", "coder"]]: +def research_team_node(state: State): """Research team node that collaborates on tasks.""" logger.info("Research team is collaborating on tasks.") - current_plan = state.get("current_plan") - if not current_plan or not current_plan.steps: - return Command(goto="planner") - if all(step.execution_res for step in current_plan.steps): - return Command(goto="planner") - for step in current_plan.steps: - if not step.execution_res: - break - if step.step_type and step.step_type == StepType.RESEARCH: - return Command(goto="researcher") - if step.step_type and step.step_type == StepType.PROCESSING: - return Command(goto="coder") - return Command(goto="planner") + pass async def _execute_agent_step( diff --git a/tests/integration/test_nodes.py b/tests/integration/test_nodes.py index 14c14a8..0d3602f 100644 --- a/tests/integration/test_nodes.py +++ b/tests/integration/test_nodes.py @@ -82,27 +82,25 @@ def test_background_investigation_node_tavily( result = background_investigation_node(mock_state, mock_config) # Verify the result structure - assert isinstance(result, Command) - assert result.goto == "planner" + assert isinstance(result, dict) # Verify the update contains background_investigation_results - update = result.update - assert "background_investigation_results" in update + assert "background_investigation_results" in result # Parse and verify the JSON content - results = json.loads(update["background_investigation_results"]) - assert isinstance(results, list) + results = result["background_investigation_results"] if search_engine == SearchEngine.TAVILY.value: mock_tavily_search.return_value.invoke.assert_called_once_with("test query") - assert len(results) == 2 - assert results[0]["title"] == "Test Title 1" - assert results[0]["content"] == "Test Content 1" + assert ( + results + == "## Test Title 1\n\nTest Content 1\n\n## Test Title 2\n\nTest Content 2" + ) else: mock_web_search_tool.return_value.invoke.assert_called_once_with( "test query" ) - assert len(results) == 2 + assert len(json.loads(results)) == 2 def test_background_investigation_node_malformed_response( @@ -116,13 +114,11 @@ def test_background_investigation_node_malformed_response( result = background_investigation_node(mock_state, mock_config) # Verify the result structure - assert isinstance(result, Command) - assert result.goto == "planner" + assert isinstance(result, dict) # Verify the update contains background_investigation_results - update = result.update - assert "background_investigation_results" in update + assert "background_investigation_results" in result # Parse and verify the JSON content - results = json.loads(update["background_investigation_results"]) - assert results is None + results = result["background_investigation_results"] + assert json.loads(results) is None