mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-19 04:14:46 +08:00
refactor: refine the graph structure (#283)
This commit is contained in:
@@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
from langgraph.graph import StateGraph, START, END
|
from langgraph.graph import StateGraph, START, END
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
|
from src.prompts.planner_model import StepType
|
||||||
|
|
||||||
from .types import State
|
from .types import State
|
||||||
from .nodes import (
|
from .nodes import (
|
||||||
@@ -17,6 +18,22 @@ from .nodes import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def continue_to_running_research_team(state: State):
|
||||||
|
current_plan = state.get("current_plan")
|
||||||
|
if not current_plan or not current_plan.steps:
|
||||||
|
return "planner"
|
||||||
|
if all(step.execution_res for step in current_plan.steps):
|
||||||
|
return "planner"
|
||||||
|
for step in current_plan.steps:
|
||||||
|
if not step.execution_res:
|
||||||
|
break
|
||||||
|
if step.step_type and step.step_type == StepType.RESEARCH:
|
||||||
|
return "researcher"
|
||||||
|
if step.step_type and step.step_type == StepType.PROCESSING:
|
||||||
|
return "coder"
|
||||||
|
return "planner"
|
||||||
|
|
||||||
|
|
||||||
def _build_base_graph():
|
def _build_base_graph():
|
||||||
"""Build and return the base state graph with all nodes and edges."""
|
"""Build and return the base state graph with all nodes and edges."""
|
||||||
builder = StateGraph(State)
|
builder = StateGraph(State)
|
||||||
@@ -29,6 +46,12 @@ def _build_base_graph():
|
|||||||
builder.add_node("researcher", researcher_node)
|
builder.add_node("researcher", researcher_node)
|
||||||
builder.add_node("coder", coder_node)
|
builder.add_node("coder", coder_node)
|
||||||
builder.add_node("human_feedback", human_feedback_node)
|
builder.add_node("human_feedback", human_feedback_node)
|
||||||
|
builder.add_edge("background_investigator", "planner")
|
||||||
|
builder.add_conditional_edges(
|
||||||
|
"research_team",
|
||||||
|
continue_to_running_research_team,
|
||||||
|
["planner", "researcher", "coder"],
|
||||||
|
)
|
||||||
builder.add_edge("reporter", END)
|
builder.add_edge("reporter", END)
|
||||||
return builder
|
return builder
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from src.tools import (
|
|||||||
from src.config.agents import AGENT_LLM_MAP
|
from src.config.agents import AGENT_LLM_MAP
|
||||||
from src.config.configuration import Configuration
|
from src.config.configuration import Configuration
|
||||||
from src.llms.llm import get_llm_by_type
|
from src.llms.llm import get_llm_by_type
|
||||||
from src.prompts.planner_model import Plan, StepType
|
from src.prompts.planner_model import Plan
|
||||||
from src.prompts.template import apply_prompt_template
|
from src.prompts.template import apply_prompt_template
|
||||||
from src.utils.json_utils import repair_json_output
|
from src.utils.json_utils import repair_json_output
|
||||||
|
|
||||||
@@ -45,22 +45,24 @@ def handoff_to_planner(
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def background_investigation_node(
|
def background_investigation_node(state: State, config: RunnableConfig):
|
||||||
state: State, config: RunnableConfig
|
|
||||||
) -> Command[Literal["planner"]]:
|
|
||||||
logger.info("background investigation node is running.")
|
logger.info("background investigation node is running.")
|
||||||
configurable = Configuration.from_runnable_config(config)
|
configurable = Configuration.from_runnable_config(config)
|
||||||
query = state["messages"][-1].content
|
query = state["messages"][-1].content
|
||||||
|
background_investigation_results = None
|
||||||
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
|
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
|
||||||
searched_content = LoggedTavilySearch(
|
searched_content = LoggedTavilySearch(
|
||||||
max_results=configurable.max_search_results
|
max_results=configurable.max_search_results
|
||||||
).invoke(query)
|
).invoke(query)
|
||||||
background_investigation_results = None
|
|
||||||
if isinstance(searched_content, list):
|
if isinstance(searched_content, list):
|
||||||
background_investigation_results = [
|
background_investigation_results = [
|
||||||
{"title": elem["title"], "content": elem["content"]}
|
f"## {elem['title']}\n\n{elem['content']}" for elem in searched_content
|
||||||
for elem in searched_content
|
|
||||||
]
|
]
|
||||||
|
return {
|
||||||
|
"background_investigation_results": "\n\n".join(
|
||||||
|
background_investigation_results
|
||||||
|
)
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Tavily search returned malformed response: {searched_content}"
|
f"Tavily search returned malformed response: {searched_content}"
|
||||||
@@ -69,14 +71,11 @@ def background_investigation_node(
|
|||||||
background_investigation_results = get_web_search_tool(
|
background_investigation_results = get_web_search_tool(
|
||||||
configurable.max_search_results
|
configurable.max_search_results
|
||||||
).invoke(query)
|
).invoke(query)
|
||||||
return Command(
|
return {
|
||||||
update={
|
"background_investigation_results": json.dumps(
|
||||||
"background_investigation_results": json.dumps(
|
background_investigation_results, ensure_ascii=False
|
||||||
background_investigation_results, ensure_ascii=False
|
)
|
||||||
)
|
}
|
||||||
},
|
|
||||||
goto="planner",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def planner_node(
|
def planner_node(
|
||||||
@@ -287,24 +286,10 @@ def reporter_node(state: State):
|
|||||||
return {"final_report": response_content}
|
return {"final_report": response_content}
|
||||||
|
|
||||||
|
|
||||||
def research_team_node(
|
def research_team_node(state: State):
|
||||||
state: State,
|
|
||||||
) -> Command[Literal["planner", "researcher", "coder"]]:
|
|
||||||
"""Research team node that collaborates on tasks."""
|
"""Research team node that collaborates on tasks."""
|
||||||
logger.info("Research team is collaborating on tasks.")
|
logger.info("Research team is collaborating on tasks.")
|
||||||
current_plan = state.get("current_plan")
|
pass
|
||||||
if not current_plan or not current_plan.steps:
|
|
||||||
return Command(goto="planner")
|
|
||||||
if all(step.execution_res for step in current_plan.steps):
|
|
||||||
return Command(goto="planner")
|
|
||||||
for step in current_plan.steps:
|
|
||||||
if not step.execution_res:
|
|
||||||
break
|
|
||||||
if step.step_type and step.step_type == StepType.RESEARCH:
|
|
||||||
return Command(goto="researcher")
|
|
||||||
if step.step_type and step.step_type == StepType.PROCESSING:
|
|
||||||
return Command(goto="coder")
|
|
||||||
return Command(goto="planner")
|
|
||||||
|
|
||||||
|
|
||||||
async def _execute_agent_step(
|
async def _execute_agent_step(
|
||||||
|
|||||||
@@ -82,27 +82,25 @@ def test_background_investigation_node_tavily(
|
|||||||
result = background_investigation_node(mock_state, mock_config)
|
result = background_investigation_node(mock_state, mock_config)
|
||||||
|
|
||||||
# Verify the result structure
|
# Verify the result structure
|
||||||
assert isinstance(result, Command)
|
assert isinstance(result, dict)
|
||||||
assert result.goto == "planner"
|
|
||||||
|
|
||||||
# Verify the update contains background_investigation_results
|
# Verify the update contains background_investigation_results
|
||||||
update = result.update
|
assert "background_investigation_results" in result
|
||||||
assert "background_investigation_results" in update
|
|
||||||
|
|
||||||
# Parse and verify the JSON content
|
# Parse and verify the JSON content
|
||||||
results = json.loads(update["background_investigation_results"])
|
results = result["background_investigation_results"]
|
||||||
assert isinstance(results, list)
|
|
||||||
|
|
||||||
if search_engine == SearchEngine.TAVILY.value:
|
if search_engine == SearchEngine.TAVILY.value:
|
||||||
mock_tavily_search.return_value.invoke.assert_called_once_with("test query")
|
mock_tavily_search.return_value.invoke.assert_called_once_with("test query")
|
||||||
assert len(results) == 2
|
assert (
|
||||||
assert results[0]["title"] == "Test Title 1"
|
results
|
||||||
assert results[0]["content"] == "Test Content 1"
|
== "## Test Title 1\n\nTest Content 1\n\n## Test Title 2\n\nTest Content 2"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
mock_web_search_tool.return_value.invoke.assert_called_once_with(
|
mock_web_search_tool.return_value.invoke.assert_called_once_with(
|
||||||
"test query"
|
"test query"
|
||||||
)
|
)
|
||||||
assert len(results) == 2
|
assert len(json.loads(results)) == 2
|
||||||
|
|
||||||
|
|
||||||
def test_background_investigation_node_malformed_response(
|
def test_background_investigation_node_malformed_response(
|
||||||
@@ -116,13 +114,11 @@ def test_background_investigation_node_malformed_response(
|
|||||||
result = background_investigation_node(mock_state, mock_config)
|
result = background_investigation_node(mock_state, mock_config)
|
||||||
|
|
||||||
# Verify the result structure
|
# Verify the result structure
|
||||||
assert isinstance(result, Command)
|
assert isinstance(result, dict)
|
||||||
assert result.goto == "planner"
|
|
||||||
|
|
||||||
# Verify the update contains background_investigation_results
|
# Verify the update contains background_investigation_results
|
||||||
update = result.update
|
assert "background_investigation_results" in result
|
||||||
assert "background_investigation_results" in update
|
|
||||||
|
|
||||||
# Parse and verify the JSON content
|
# Parse and verify the JSON content
|
||||||
results = json.loads(update["background_investigation_results"])
|
results = result["background_investigation_results"]
|
||||||
assert results is None
|
assert json.loads(results) is None
|
||||||
|
|||||||
Reference in New Issue
Block a user