mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-14 18:54:46 +08:00
feat(nodes): add background investigation node
Change-Id: I96e08e22fc7c52647edbf9be4f385a8fae9b449a
This commit is contained in:
@@ -13,6 +13,7 @@ from .nodes import (
|
||||
researcher_node,
|
||||
coder_node,
|
||||
human_feedback_node,
|
||||
background_investigation_node,
|
||||
)
|
||||
|
||||
|
||||
@@ -21,6 +22,7 @@ def _build_base_graph():
|
||||
builder = StateGraph(State)
|
||||
builder.add_edge(START, "coordinator")
|
||||
builder.add_node("coordinator", coordinator_node)
|
||||
builder.add_node("background_investigator", background_investigation_node)
|
||||
builder.add_node("planner", planner_node)
|
||||
builder.add_node("reporter", reporter_node)
|
||||
builder.add_node("research_team", research_team_node)
|
||||
|
||||
@@ -13,6 +13,7 @@ from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
|
||||
from src.agents.agents import coder_agent, research_agent, create_agent
|
||||
|
||||
from src.tools.search import LoggedTavilySearch
|
||||
from src.tools import (
|
||||
crawl_tool,
|
||||
web_search_tool,
|
||||
@@ -27,6 +28,7 @@ from src.prompts.template import apply_prompt_template
|
||||
from src.utils.json_utils import repair_json_output
|
||||
|
||||
from .types import State
|
||||
from ..config import SEARCH_MAX_RESULTS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -42,13 +44,55 @@ def handoff_to_planner(
|
||||
return
|
||||
|
||||
|
||||
def background_investigation_node(state: State) -> Command[Literal["planner"]]:
|
||||
|
||||
logger.info("background investigation node is running.")
|
||||
searched_content = LoggedTavilySearch(max_results=SEARCH_MAX_RESULTS).invoke(
|
||||
{"query": state["messages"][-1].content}
|
||||
)
|
||||
background_investigation_results = None
|
||||
if isinstance(searched_content, list):
|
||||
background_investigation_results = [
|
||||
{"title": elem["title"], "content": elem["content"]}
|
||||
for elem in searched_content
|
||||
]
|
||||
else:
|
||||
logger.error(f"Tavily search returned malformed response: {searched_content}")
|
||||
return Command(
|
||||
update={
|
||||
"background_investigation_results": json.dumps(
|
||||
background_investigation_results, ensure_ascii=False
|
||||
)
|
||||
},
|
||||
goto="planner",
|
||||
)
|
||||
|
||||
|
||||
def planner_node(
|
||||
state: State, config: RunnableConfig
|
||||
) -> Command[Literal["human_feedback", "reporter"]]:
|
||||
"""Planner node that generate the full plan."""
|
||||
logger.info("Planner generating full plan")
|
||||
configurable = Configuration.from_runnable_config(config)
|
||||
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
|
||||
messages = apply_prompt_template("planner", state, configurable)
|
||||
|
||||
if (
|
||||
plan_iterations == 0
|
||||
and state.get("enable_background_investigation")
|
||||
and state.get("background_investigation_results")
|
||||
):
|
||||
messages += [
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"background investigation results of user query:\n"
|
||||
+ state["background_investigation_results"]
|
||||
+ "\n"
|
||||
),
|
||||
}
|
||||
]
|
||||
|
||||
if AGENT_LLM_MAP["planner"] == "basic":
|
||||
llm = get_llm_by_type(AGENT_LLM_MAP["planner"]).with_structured_output(
|
||||
Plan,
|
||||
@@ -56,7 +100,6 @@ def planner_node(
|
||||
)
|
||||
else:
|
||||
llm = get_llm_by_type(AGENT_LLM_MAP["planner"])
|
||||
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
|
||||
|
||||
# if the plan iterations is greater than the max plan iterations, return the reporter node
|
||||
if plan_iterations >= configurable.max_plan_iterations:
|
||||
@@ -134,7 +177,9 @@ def human_feedback_node(
|
||||
)
|
||||
|
||||
|
||||
def coordinator_node(state: State) -> Command[Literal["planner", "__end__"]]:
|
||||
def coordinator_node(
|
||||
state: State,
|
||||
) -> Command[Literal["planner", "background_investigator", "__end__"]]:
|
||||
"""Coordinator node that communicate with customers."""
|
||||
logger.info("Coordinator talking.")
|
||||
messages = apply_prompt_template("coordinator", state)
|
||||
@@ -150,6 +195,9 @@ def coordinator_node(state: State) -> Command[Literal["planner", "__end__"]]:
|
||||
|
||||
if len(response.tool_calls) > 0:
|
||||
goto = "planner"
|
||||
if state.get("enable_background_investigation"):
|
||||
# if the search_before_planning is True, add the web search tool to the planner agent
|
||||
goto = "background_investigator"
|
||||
try:
|
||||
for tool_call in response.tool_calls:
|
||||
if tool_call.get("name", "") != "handoff_to_planner":
|
||||
@@ -160,9 +208,7 @@ def coordinator_node(state: State) -> Command[Literal["planner", "__end__"]]:
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing tool calls: {e}")
|
||||
return Command(
|
||||
update={
|
||||
"locale": locale
|
||||
},
|
||||
update={"locale": locale},
|
||||
goto=goto,
|
||||
)
|
||||
|
||||
|
||||
@@ -19,3 +19,5 @@ class State(MessagesState):
|
||||
current_plan: Plan | str = None
|
||||
final_report: str = ""
|
||||
auto_accepted_plan: bool = False
|
||||
enable_background_investigation: bool = True
|
||||
background_investigation_results: str = None
|
||||
|
||||
Reference in New Issue
Block a user