mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-16 03:14:45 +08:00
refactor: refine the graph structure (#283)
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from src.prompts.planner_model import StepType
|
||||
|
||||
from .types import State
|
||||
from .nodes import (
|
||||
@@ -17,6 +18,22 @@ from .nodes import (
|
||||
)
|
||||
|
||||
|
||||
def continue_to_running_research_team(state: State):
|
||||
current_plan = state.get("current_plan")
|
||||
if not current_plan or not current_plan.steps:
|
||||
return "planner"
|
||||
if all(step.execution_res for step in current_plan.steps):
|
||||
return "planner"
|
||||
for step in current_plan.steps:
|
||||
if not step.execution_res:
|
||||
break
|
||||
if step.step_type and step.step_type == StepType.RESEARCH:
|
||||
return "researcher"
|
||||
if step.step_type and step.step_type == StepType.PROCESSING:
|
||||
return "coder"
|
||||
return "planner"
|
||||
|
||||
|
||||
def _build_base_graph():
|
||||
"""Build and return the base state graph with all nodes and edges."""
|
||||
builder = StateGraph(State)
|
||||
@@ -29,6 +46,12 @@ def _build_base_graph():
|
||||
builder.add_node("researcher", researcher_node)
|
||||
builder.add_node("coder", coder_node)
|
||||
builder.add_node("human_feedback", human_feedback_node)
|
||||
builder.add_edge("background_investigator", "planner")
|
||||
builder.add_conditional_edges(
|
||||
"research_team",
|
||||
continue_to_running_research_team,
|
||||
["planner", "researcher", "coder"],
|
||||
)
|
||||
builder.add_edge("reporter", END)
|
||||
return builder
|
||||
|
||||
|
||||
Reference in New Issue
Block a user