refactor: refine the graph structure (#283)

This commit is contained in:
DanielWalnut
2025-06-04 21:47:17 -07:00
committed by GitHub
parent 73ac8ae45a
commit b5ec61bb9d
3 changed files with 51 additions and 47 deletions

View File

@@ -3,6 +3,7 @@
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from src.prompts.planner_model import StepType
from .types import State
from .nodes import (
@@ -17,6 +18,22 @@ from .nodes import (
)
def continue_to_running_research_team(state: State):
current_plan = state.get("current_plan")
if not current_plan or not current_plan.steps:
return "planner"
if all(step.execution_res for step in current_plan.steps):
return "planner"
for step in current_plan.steps:
if not step.execution_res:
break
if step.step_type and step.step_type == StepType.RESEARCH:
return "researcher"
if step.step_type and step.step_type == StepType.PROCESSING:
return "coder"
return "planner"
def _build_base_graph():
"""Build and return the base state graph with all nodes and edges."""
builder = StateGraph(State)
@@ -29,6 +46,12 @@ def _build_base_graph():
builder.add_node("researcher", researcher_node)
builder.add_node("coder", coder_node)
builder.add_node("human_feedback", human_feedback_node)
builder.add_edge("background_investigator", "planner")
builder.add_conditional_edges(
"research_team",
continue_to_running_research_team,
["planner", "researcher", "coder"],
)
builder.add_edge("reporter", END)
return builder

View File

@@ -24,7 +24,7 @@ from src.tools import (
from src.config.agents import AGENT_LLM_MAP
from src.config.configuration import Configuration
from src.llms.llm import get_llm_by_type
from src.prompts.planner_model import Plan, StepType
from src.prompts.planner_model import Plan
from src.prompts.template import apply_prompt_template
from src.utils.json_utils import repair_json_output
@@ -45,22 +45,24 @@ def handoff_to_planner(
return
def background_investigation_node(
state: State, config: RunnableConfig
) -> Command[Literal["planner"]]:
def background_investigation_node(state: State, config: RunnableConfig):
logger.info("background investigation node is running.")
configurable = Configuration.from_runnable_config(config)
query = state["messages"][-1].content
background_investigation_results = None
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
searched_content = LoggedTavilySearch(
max_results=configurable.max_search_results
).invoke(query)
background_investigation_results = None
if isinstance(searched_content, list):
background_investigation_results = [
{"title": elem["title"], "content": elem["content"]}
for elem in searched_content
f"## {elem['title']}\n\n{elem['content']}" for elem in searched_content
]
return {
"background_investigation_results": "\n\n".join(
background_investigation_results
)
}
else:
logger.error(
f"Tavily search returned malformed response: {searched_content}"
@@ -69,14 +71,11 @@ def background_investigation_node(
background_investigation_results = get_web_search_tool(
configurable.max_search_results
).invoke(query)
return Command(
update={
"background_investigation_results": json.dumps(
background_investigation_results, ensure_ascii=False
)
},
goto="planner",
)
return {
"background_investigation_results": json.dumps(
background_investigation_results, ensure_ascii=False
)
}
def planner_node(
@@ -287,24 +286,10 @@ def reporter_node(state: State):
return {"final_report": response_content}
def research_team_node(
state: State,
) -> Command[Literal["planner", "researcher", "coder"]]:
def research_team_node(state: State):
"""Research team node that collaborates on tasks."""
logger.info("Research team is collaborating on tasks.")
current_plan = state.get("current_plan")
if not current_plan or not current_plan.steps:
return Command(goto="planner")
if all(step.execution_res for step in current_plan.steps):
return Command(goto="planner")
for step in current_plan.steps:
if not step.execution_res:
break
if step.step_type and step.step_type == StepType.RESEARCH:
return Command(goto="researcher")
if step.step_type and step.step_type == StepType.PROCESSING:
return Command(goto="coder")
return Command(goto="planner")
pass
async def _execute_agent_step(