mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-25 15:04:46 +08:00
feat(nodes): add background investigation node
Change-Id: I96e08e22fc7c52647edbf9be4f385a8fae9b449a
This commit is contained in:
27
main.py
27
main.py
@@ -13,7 +13,13 @@ from src.workflow import run_agent_workflow_async
|
|||||||
from src.config.questions import BUILT_IN_QUESTIONS, BUILT_IN_QUESTIONS_ZH_CN
|
from src.config.questions import BUILT_IN_QUESTIONS, BUILT_IN_QUESTIONS_ZH_CN
|
||||||
|
|
||||||
|
|
||||||
def ask(question, debug=False, max_plan_iterations=1, max_step_num=3):
|
def ask(
|
||||||
|
question,
|
||||||
|
debug=False,
|
||||||
|
max_plan_iterations=1,
|
||||||
|
max_step_num=3,
|
||||||
|
enable_background_investigation=True,
|
||||||
|
):
|
||||||
"""Run the agent workflow with the given question.
|
"""Run the agent workflow with the given question.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -21,6 +27,7 @@ def ask(question, debug=False, max_plan_iterations=1, max_step_num=3):
|
|||||||
debug: If True, enables debug level logging
|
debug: If True, enables debug level logging
|
||||||
max_plan_iterations: Maximum number of plan iterations
|
max_plan_iterations: Maximum number of plan iterations
|
||||||
max_step_num: Maximum number of steps in a plan
|
max_step_num: Maximum number of steps in a plan
|
||||||
|
enable_background_investigation: If True, performs web search before planning to enhance context
|
||||||
"""
|
"""
|
||||||
asyncio.run(
|
asyncio.run(
|
||||||
run_agent_workflow_async(
|
run_agent_workflow_async(
|
||||||
@@ -28,14 +35,21 @@ def ask(question, debug=False, max_plan_iterations=1, max_step_num=3):
|
|||||||
debug=debug,
|
debug=debug,
|
||||||
max_plan_iterations=max_plan_iterations,
|
max_plan_iterations=max_plan_iterations,
|
||||||
max_step_num=max_step_num,
|
max_step_num=max_step_num,
|
||||||
|
enable_background_investigation=enable_background_investigation,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def main(debug=False, max_plan_iterations=1, max_step_num=3):
|
def main(
|
||||||
|
debug=False,
|
||||||
|
max_plan_iterations=1,
|
||||||
|
max_step_num=3,
|
||||||
|
enable_background_investigation=True,
|
||||||
|
):
|
||||||
"""Interactive mode with built-in questions.
|
"""Interactive mode with built-in questions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
enable_background_investigation: If True, performs web search before planning to enhance context
|
||||||
debug: If True, enables debug level logging
|
debug: If True, enables debug level logging
|
||||||
max_plan_iterations: Maximum number of plan iterations
|
max_plan_iterations: Maximum number of plan iterations
|
||||||
max_step_num: Maximum number of steps in a plan
|
max_step_num: Maximum number of steps in a plan
|
||||||
@@ -77,6 +91,7 @@ def main(debug=False, max_plan_iterations=1, max_step_num=3):
|
|||||||
debug=debug,
|
debug=debug,
|
||||||
max_plan_iterations=max_plan_iterations,
|
max_plan_iterations=max_plan_iterations,
|
||||||
max_step_num=max_step_num,
|
max_step_num=max_step_num,
|
||||||
|
enable_background_investigation=enable_background_investigation,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -102,6 +117,12 @@ if __name__ == "__main__":
|
|||||||
help="Maximum number of steps in a plan (default: 3)",
|
help="Maximum number of steps in a plan (default: 3)",
|
||||||
)
|
)
|
||||||
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
|
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-background-investigation",
|
||||||
|
action="store_false",
|
||||||
|
dest="enable_background_investigation",
|
||||||
|
help="Disable background investigation before planning",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -111,6 +132,7 @@ if __name__ == "__main__":
|
|||||||
debug=args.debug,
|
debug=args.debug,
|
||||||
max_plan_iterations=args.max_plan_iterations,
|
max_plan_iterations=args.max_plan_iterations,
|
||||||
max_step_num=args.max_step_num,
|
max_step_num=args.max_step_num,
|
||||||
|
enable_background_investigation=args.enable_background_investigation,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Parse user input from command line arguments or user input
|
# Parse user input from command line arguments or user input
|
||||||
@@ -125,4 +147,5 @@ if __name__ == "__main__":
|
|||||||
debug=args.debug,
|
debug=args.debug,
|
||||||
max_plan_iterations=args.max_plan_iterations,
|
max_plan_iterations=args.max_plan_iterations,
|
||||||
max_step_num=args.max_step_num,
|
max_step_num=args.max_step_num,
|
||||||
|
enable_background_investigation=args.enable_background_investigation,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from .nodes import (
|
|||||||
researcher_node,
|
researcher_node,
|
||||||
coder_node,
|
coder_node,
|
||||||
human_feedback_node,
|
human_feedback_node,
|
||||||
|
background_investigation_node,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -21,6 +22,7 @@ def _build_base_graph():
|
|||||||
builder = StateGraph(State)
|
builder = StateGraph(State)
|
||||||
builder.add_edge(START, "coordinator")
|
builder.add_edge(START, "coordinator")
|
||||||
builder.add_node("coordinator", coordinator_node)
|
builder.add_node("coordinator", coordinator_node)
|
||||||
|
builder.add_node("background_investigator", background_investigation_node)
|
||||||
builder.add_node("planner", planner_node)
|
builder.add_node("planner", planner_node)
|
||||||
builder.add_node("reporter", reporter_node)
|
builder.add_node("reporter", reporter_node)
|
||||||
builder.add_node("research_team", research_team_node)
|
builder.add_node("research_team", research_team_node)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from langchain_mcp_adapters.client import MultiServerMCPClient
|
|||||||
|
|
||||||
from src.agents.agents import coder_agent, research_agent, create_agent
|
from src.agents.agents import coder_agent, research_agent, create_agent
|
||||||
|
|
||||||
|
from src.tools.search import LoggedTavilySearch
|
||||||
from src.tools import (
|
from src.tools import (
|
||||||
crawl_tool,
|
crawl_tool,
|
||||||
web_search_tool,
|
web_search_tool,
|
||||||
@@ -27,6 +28,7 @@ from src.prompts.template import apply_prompt_template
|
|||||||
from src.utils.json_utils import repair_json_output
|
from src.utils.json_utils import repair_json_output
|
||||||
|
|
||||||
from .types import State
|
from .types import State
|
||||||
|
from ..config import SEARCH_MAX_RESULTS
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -42,13 +44,55 @@ def handoff_to_planner(
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def background_investigation_node(state: State) -> Command[Literal["planner"]]:
|
||||||
|
|
||||||
|
logger.info("background investigation node is running.")
|
||||||
|
searched_content = LoggedTavilySearch(max_results=SEARCH_MAX_RESULTS).invoke(
|
||||||
|
{"query": state["messages"][-1].content}
|
||||||
|
)
|
||||||
|
background_investigation_results = None
|
||||||
|
if isinstance(searched_content, list):
|
||||||
|
background_investigation_results = [
|
||||||
|
{"title": elem["title"], "content": elem["content"]}
|
||||||
|
for elem in searched_content
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
logger.error(f"Tavily search returned malformed response: {searched_content}")
|
||||||
|
return Command(
|
||||||
|
update={
|
||||||
|
"background_investigation_results": json.dumps(
|
||||||
|
background_investigation_results, ensure_ascii=False
|
||||||
|
)
|
||||||
|
},
|
||||||
|
goto="planner",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def planner_node(
|
def planner_node(
|
||||||
state: State, config: RunnableConfig
|
state: State, config: RunnableConfig
|
||||||
) -> Command[Literal["human_feedback", "reporter"]]:
|
) -> Command[Literal["human_feedback", "reporter"]]:
|
||||||
"""Planner node that generate the full plan."""
|
"""Planner node that generate the full plan."""
|
||||||
logger.info("Planner generating full plan")
|
logger.info("Planner generating full plan")
|
||||||
configurable = Configuration.from_runnable_config(config)
|
configurable = Configuration.from_runnable_config(config)
|
||||||
|
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
|
||||||
messages = apply_prompt_template("planner", state, configurable)
|
messages = apply_prompt_template("planner", state, configurable)
|
||||||
|
|
||||||
|
if (
|
||||||
|
plan_iterations == 0
|
||||||
|
and state.get("enable_background_investigation")
|
||||||
|
and state.get("background_investigation_results")
|
||||||
|
):
|
||||||
|
messages += [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
"background investigation results of user query:\n"
|
||||||
|
+ state["background_investigation_results"]
|
||||||
|
+ "\n"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
if AGENT_LLM_MAP["planner"] == "basic":
|
if AGENT_LLM_MAP["planner"] == "basic":
|
||||||
llm = get_llm_by_type(AGENT_LLM_MAP["planner"]).with_structured_output(
|
llm = get_llm_by_type(AGENT_LLM_MAP["planner"]).with_structured_output(
|
||||||
Plan,
|
Plan,
|
||||||
@@ -56,7 +100,6 @@ def planner_node(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
llm = get_llm_by_type(AGENT_LLM_MAP["planner"])
|
llm = get_llm_by_type(AGENT_LLM_MAP["planner"])
|
||||||
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
|
|
||||||
|
|
||||||
# if the plan iterations is greater than the max plan iterations, return the reporter node
|
# if the plan iterations is greater than the max plan iterations, return the reporter node
|
||||||
if plan_iterations >= configurable.max_plan_iterations:
|
if plan_iterations >= configurable.max_plan_iterations:
|
||||||
@@ -134,7 +177,9 @@ def human_feedback_node(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def coordinator_node(state: State) -> Command[Literal["planner", "__end__"]]:
|
def coordinator_node(
|
||||||
|
state: State,
|
||||||
|
) -> Command[Literal["planner", "background_investigator", "__end__"]]:
|
||||||
"""Coordinator node that communicate with customers."""
|
"""Coordinator node that communicate with customers."""
|
||||||
logger.info("Coordinator talking.")
|
logger.info("Coordinator talking.")
|
||||||
messages = apply_prompt_template("coordinator", state)
|
messages = apply_prompt_template("coordinator", state)
|
||||||
@@ -150,6 +195,9 @@ def coordinator_node(state: State) -> Command[Literal["planner", "__end__"]]:
|
|||||||
|
|
||||||
if len(response.tool_calls) > 0:
|
if len(response.tool_calls) > 0:
|
||||||
goto = "planner"
|
goto = "planner"
|
||||||
|
if state.get("enable_background_investigation"):
|
||||||
|
# if the search_before_planning is True, add the web search tool to the planner agent
|
||||||
|
goto = "background_investigator"
|
||||||
try:
|
try:
|
||||||
for tool_call in response.tool_calls:
|
for tool_call in response.tool_calls:
|
||||||
if tool_call.get("name", "") != "handoff_to_planner":
|
if tool_call.get("name", "") != "handoff_to_planner":
|
||||||
@@ -160,9 +208,7 @@ def coordinator_node(state: State) -> Command[Literal["planner", "__end__"]]:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing tool calls: {e}")
|
logger.error(f"Error processing tool calls: {e}")
|
||||||
return Command(
|
return Command(
|
||||||
update={
|
update={"locale": locale},
|
||||||
"locale": locale
|
|
||||||
},
|
|
||||||
goto=goto,
|
goto=goto,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -19,3 +19,5 @@ class State(MessagesState):
|
|||||||
current_plan: Plan | str = None
|
current_plan: Plan | str = None
|
||||||
final_report: str = ""
|
final_report: str = ""
|
||||||
auto_accepted_plan: bool = False
|
auto_accepted_plan: bool = False
|
||||||
|
enable_background_investigation: bool = True
|
||||||
|
background_investigation_results: str = None
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ async def chat_stream(request: ChatRequest):
|
|||||||
request.auto_accepted_plan,
|
request.auto_accepted_plan,
|
||||||
request.interrupt_feedback,
|
request.interrupt_feedback,
|
||||||
request.mcp_settings,
|
request.mcp_settings,
|
||||||
|
request.enable_background_investigation,
|
||||||
),
|
),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
)
|
)
|
||||||
@@ -76,6 +77,7 @@ async def _astream_workflow_generator(
|
|||||||
auto_accepted_plan: bool,
|
auto_accepted_plan: bool,
|
||||||
interrupt_feedback: str,
|
interrupt_feedback: str,
|
||||||
mcp_settings: dict,
|
mcp_settings: dict,
|
||||||
|
enable_background_investigation,
|
||||||
):
|
):
|
||||||
input_ = {
|
input_ = {
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
@@ -84,12 +86,13 @@ async def _astream_workflow_generator(
|
|||||||
"current_plan": None,
|
"current_plan": None,
|
||||||
"observations": [],
|
"observations": [],
|
||||||
"auto_accepted_plan": auto_accepted_plan,
|
"auto_accepted_plan": auto_accepted_plan,
|
||||||
|
"enable_background_investigation": enable_background_investigation,
|
||||||
}
|
}
|
||||||
if not auto_accepted_plan and interrupt_feedback:
|
if not auto_accepted_plan and interrupt_feedback:
|
||||||
resume_msg = f"[{interrupt_feedback}]"
|
resume_msg = f"[{interrupt_feedback}]"
|
||||||
# add the last message to the resume message
|
# add the last message to the resume message
|
||||||
if messages:
|
if messages:
|
||||||
resume_msg += f" {messages[-1]["content"]}"
|
resume_msg += f" {messages[-1]['content']}"
|
||||||
input_ = Command(resume=resume_msg)
|
input_ = Command(resume=resume_msg)
|
||||||
async for agent, _, event_data in graph.astream(
|
async for agent, _, event_data in graph.astream(
|
||||||
input_,
|
input_,
|
||||||
|
|||||||
@@ -47,6 +47,9 @@ class ChatRequest(BaseModel):
|
|||||||
mcp_settings: Optional[dict] = Field(
|
mcp_settings: Optional[dict] = Field(
|
||||||
None, description="MCP settings for the chat request"
|
None, description="MCP settings for the chat request"
|
||||||
)
|
)
|
||||||
|
enable_background_investigation: Optional[bool] = Field(
|
||||||
|
True, description="Whether to get background investigation before plan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TTSRequest(BaseModel):
|
class TTSRequest(BaseModel):
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ async def run_agent_workflow_async(
|
|||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
max_plan_iterations: int = 1,
|
max_plan_iterations: int = 1,
|
||||||
max_step_num: int = 3,
|
max_step_num: int = 3,
|
||||||
|
enable_background_investigation: bool = True,
|
||||||
):
|
):
|
||||||
"""Run the agent workflow asynchronously with the given user input.
|
"""Run the agent workflow asynchronously with the given user input.
|
||||||
|
|
||||||
@@ -36,6 +37,7 @@ async def run_agent_workflow_async(
|
|||||||
debug: If True, enables debug level logging
|
debug: If True, enables debug level logging
|
||||||
max_plan_iterations: Maximum number of plan iterations
|
max_plan_iterations: Maximum number of plan iterations
|
||||||
max_step_num: Maximum number of steps in a plan
|
max_step_num: Maximum number of steps in a plan
|
||||||
|
enable_background_investigation: If True, performs web search before planning to enhance context
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The final state after the workflow completes
|
The final state after the workflow completes
|
||||||
@@ -51,6 +53,7 @@ async def run_agent_workflow_async(
|
|||||||
# Runtime Variables
|
# Runtime Variables
|
||||||
"messages": [{"role": "user", "content": user_input}],
|
"messages": [{"role": "user", "content": user_input}],
|
||||||
"auto_accepted_plan": True,
|
"auto_accepted_plan": True,
|
||||||
|
"enable_background_investigation": enable_background_investigation,
|
||||||
}
|
}
|
||||||
config = {
|
config = {
|
||||||
"configurable": {
|
"configurable": {
|
||||||
|
|||||||
Reference in New Issue
Block a user