mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-06 07:20:21 +08:00
feat(nodes): add background investigation node
Change-Id: I96e08e22fc7c52647edbf9be4f385a8fae9b449a
This commit is contained in:
27
main.py
27
main.py
@@ -13,7 +13,13 @@ from src.workflow import run_agent_workflow_async
|
||||
from src.config.questions import BUILT_IN_QUESTIONS, BUILT_IN_QUESTIONS_ZH_CN
|
||||
|
||||
|
||||
def ask(question, debug=False, max_plan_iterations=1, max_step_num=3):
|
||||
def ask(
|
||||
question,
|
||||
debug=False,
|
||||
max_plan_iterations=1,
|
||||
max_step_num=3,
|
||||
enable_background_investigation=True,
|
||||
):
|
||||
"""Run the agent workflow with the given question.
|
||||
|
||||
Args:
|
||||
@@ -21,6 +27,7 @@ def ask(question, debug=False, max_plan_iterations=1, max_step_num=3):
|
||||
debug: If True, enables debug level logging
|
||||
max_plan_iterations: Maximum number of plan iterations
|
||||
max_step_num: Maximum number of steps in a plan
|
||||
enable_background_investigation: If True, performs web search before planning to enhance context
|
||||
"""
|
||||
asyncio.run(
|
||||
run_agent_workflow_async(
|
||||
@@ -28,14 +35,21 @@ def ask(question, debug=False, max_plan_iterations=1, max_step_num=3):
|
||||
debug=debug,
|
||||
max_plan_iterations=max_plan_iterations,
|
||||
max_step_num=max_step_num,
|
||||
enable_background_investigation=enable_background_investigation,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def main(debug=False, max_plan_iterations=1, max_step_num=3):
|
||||
def main(
|
||||
debug=False,
|
||||
max_plan_iterations=1,
|
||||
max_step_num=3,
|
||||
enable_background_investigation=True,
|
||||
):
|
||||
"""Interactive mode with built-in questions.
|
||||
|
||||
Args:
|
||||
enable_background_investigation: If True, performs web search before planning to enhance context
|
||||
debug: If True, enables debug level logging
|
||||
max_plan_iterations: Maximum number of plan iterations
|
||||
max_step_num: Maximum number of steps in a plan
|
||||
@@ -77,6 +91,7 @@ def main(debug=False, max_plan_iterations=1, max_step_num=3):
|
||||
debug=debug,
|
||||
max_plan_iterations=max_plan_iterations,
|
||||
max_step_num=max_step_num,
|
||||
enable_background_investigation=enable_background_investigation,
|
||||
)
|
||||
|
||||
|
||||
@@ -102,6 +117,12 @@ if __name__ == "__main__":
|
||||
help="Maximum number of steps in a plan (default: 3)",
|
||||
)
|
||||
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
|
||||
parser.add_argument(
|
||||
"--no-background-investigation",
|
||||
action="store_false",
|
||||
dest="enable_background_investigation",
|
||||
help="Disable background investigation before planning",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -111,6 +132,7 @@ if __name__ == "__main__":
|
||||
debug=args.debug,
|
||||
max_plan_iterations=args.max_plan_iterations,
|
||||
max_step_num=args.max_step_num,
|
||||
enable_background_investigation=args.enable_background_investigation,
|
||||
)
|
||||
else:
|
||||
# Parse user input from command line arguments or user input
|
||||
@@ -125,4 +147,5 @@ if __name__ == "__main__":
|
||||
debug=args.debug,
|
||||
max_plan_iterations=args.max_plan_iterations,
|
||||
max_step_num=args.max_step_num,
|
||||
enable_background_investigation=args.enable_background_investigation,
|
||||
)
|
||||
|
||||
@@ -13,6 +13,7 @@ from .nodes import (
|
||||
researcher_node,
|
||||
coder_node,
|
||||
human_feedback_node,
|
||||
background_investigation_node,
|
||||
)
|
||||
|
||||
|
||||
@@ -21,6 +22,7 @@ def _build_base_graph():
|
||||
builder = StateGraph(State)
|
||||
builder.add_edge(START, "coordinator")
|
||||
builder.add_node("coordinator", coordinator_node)
|
||||
builder.add_node("background_investigator", background_investigation_node)
|
||||
builder.add_node("planner", planner_node)
|
||||
builder.add_node("reporter", reporter_node)
|
||||
builder.add_node("research_team", research_team_node)
|
||||
|
||||
@@ -13,6 +13,7 @@ from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
|
||||
from src.agents.agents import coder_agent, research_agent, create_agent
|
||||
|
||||
from src.tools.search import LoggedTavilySearch
|
||||
from src.tools import (
|
||||
crawl_tool,
|
||||
web_search_tool,
|
||||
@@ -27,6 +28,7 @@ from src.prompts.template import apply_prompt_template
|
||||
from src.utils.json_utils import repair_json_output
|
||||
|
||||
from .types import State
|
||||
from ..config import SEARCH_MAX_RESULTS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -42,13 +44,55 @@ def handoff_to_planner(
|
||||
return
|
||||
|
||||
|
||||
def background_investigation_node(state: State) -> Command[Literal["planner"]]:
|
||||
|
||||
logger.info("background investigation node is running.")
|
||||
searched_content = LoggedTavilySearch(max_results=SEARCH_MAX_RESULTS).invoke(
|
||||
{"query": state["messages"][-1].content}
|
||||
)
|
||||
background_investigation_results = None
|
||||
if isinstance(searched_content, list):
|
||||
background_investigation_results = [
|
||||
{"title": elem["title"], "content": elem["content"]}
|
||||
for elem in searched_content
|
||||
]
|
||||
else:
|
||||
logger.error(f"Tavily search returned malformed response: {searched_content}")
|
||||
return Command(
|
||||
update={
|
||||
"background_investigation_results": json.dumps(
|
||||
background_investigation_results, ensure_ascii=False
|
||||
)
|
||||
},
|
||||
goto="planner",
|
||||
)
|
||||
|
||||
|
||||
def planner_node(
|
||||
state: State, config: RunnableConfig
|
||||
) -> Command[Literal["human_feedback", "reporter"]]:
|
||||
"""Planner node that generate the full plan."""
|
||||
logger.info("Planner generating full plan")
|
||||
configurable = Configuration.from_runnable_config(config)
|
||||
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
|
||||
messages = apply_prompt_template("planner", state, configurable)
|
||||
|
||||
if (
|
||||
plan_iterations == 0
|
||||
and state.get("enable_background_investigation")
|
||||
and state.get("background_investigation_results")
|
||||
):
|
||||
messages += [
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"background investigation results of user query:\n"
|
||||
+ state["background_investigation_results"]
|
||||
+ "\n"
|
||||
),
|
||||
}
|
||||
]
|
||||
|
||||
if AGENT_LLM_MAP["planner"] == "basic":
|
||||
llm = get_llm_by_type(AGENT_LLM_MAP["planner"]).with_structured_output(
|
||||
Plan,
|
||||
@@ -56,7 +100,6 @@ def planner_node(
|
||||
)
|
||||
else:
|
||||
llm = get_llm_by_type(AGENT_LLM_MAP["planner"])
|
||||
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
|
||||
|
||||
# if the plan iterations is greater than the max plan iterations, return the reporter node
|
||||
if plan_iterations >= configurable.max_plan_iterations:
|
||||
@@ -134,7 +177,9 @@ def human_feedback_node(
|
||||
)
|
||||
|
||||
|
||||
def coordinator_node(state: State) -> Command[Literal["planner", "__end__"]]:
|
||||
def coordinator_node(
|
||||
state: State,
|
||||
) -> Command[Literal["planner", "background_investigator", "__end__"]]:
|
||||
"""Coordinator node that communicate with customers."""
|
||||
logger.info("Coordinator talking.")
|
||||
messages = apply_prompt_template("coordinator", state)
|
||||
@@ -150,6 +195,9 @@ def coordinator_node(state: State) -> Command[Literal["planner", "__end__"]]:
|
||||
|
||||
if len(response.tool_calls) > 0:
|
||||
goto = "planner"
|
||||
if state.get("enable_background_investigation"):
|
||||
# if the search_before_planning is True, add the web search tool to the planner agent
|
||||
goto = "background_investigator"
|
||||
try:
|
||||
for tool_call in response.tool_calls:
|
||||
if tool_call.get("name", "") != "handoff_to_planner":
|
||||
@@ -160,9 +208,7 @@ def coordinator_node(state: State) -> Command[Literal["planner", "__end__"]]:
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing tool calls: {e}")
|
||||
return Command(
|
||||
update={
|
||||
"locale": locale
|
||||
},
|
||||
update={"locale": locale},
|
||||
goto=goto,
|
||||
)
|
||||
|
||||
|
||||
@@ -19,3 +19,5 @@ class State(MessagesState):
|
||||
current_plan: Plan | str = None
|
||||
final_report: str = ""
|
||||
auto_accepted_plan: bool = False
|
||||
enable_background_investigation: bool = True
|
||||
background_investigation_results: str = None
|
||||
|
||||
@@ -63,6 +63,7 @@ async def chat_stream(request: ChatRequest):
|
||||
request.auto_accepted_plan,
|
||||
request.interrupt_feedback,
|
||||
request.mcp_settings,
|
||||
request.enable_background_investigation,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
@@ -76,6 +77,7 @@ async def _astream_workflow_generator(
|
||||
auto_accepted_plan: bool,
|
||||
interrupt_feedback: str,
|
||||
mcp_settings: dict,
|
||||
enable_background_investigation,
|
||||
):
|
||||
input_ = {
|
||||
"messages": messages,
|
||||
@@ -84,12 +86,13 @@ async def _astream_workflow_generator(
|
||||
"current_plan": None,
|
||||
"observations": [],
|
||||
"auto_accepted_plan": auto_accepted_plan,
|
||||
"enable_background_investigation": enable_background_investigation,
|
||||
}
|
||||
if not auto_accepted_plan and interrupt_feedback:
|
||||
resume_msg = f"[{interrupt_feedback}]"
|
||||
# add the last message to the resume message
|
||||
if messages:
|
||||
resume_msg += f" {messages[-1]["content"]}"
|
||||
resume_msg += f" {messages[-1]['content']}"
|
||||
input_ = Command(resume=resume_msg)
|
||||
async for agent, _, event_data in graph.astream(
|
||||
input_,
|
||||
|
||||
@@ -47,6 +47,9 @@ class ChatRequest(BaseModel):
|
||||
mcp_settings: Optional[dict] = Field(
|
||||
None, description="MCP settings for the chat request"
|
||||
)
|
||||
enable_background_investigation: Optional[bool] = Field(
|
||||
True, description="Whether to get background investigation before plan"
|
||||
)
|
||||
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
|
||||
@@ -28,6 +28,7 @@ async def run_agent_workflow_async(
|
||||
debug: bool = False,
|
||||
max_plan_iterations: int = 1,
|
||||
max_step_num: int = 3,
|
||||
enable_background_investigation: bool = True,
|
||||
):
|
||||
"""Run the agent workflow asynchronously with the given user input.
|
||||
|
||||
@@ -36,6 +37,7 @@ async def run_agent_workflow_async(
|
||||
debug: If True, enables debug level logging
|
||||
max_plan_iterations: Maximum number of plan iterations
|
||||
max_step_num: Maximum number of steps in a plan
|
||||
enable_background_investigation: If True, performs web search before planning to enhance context
|
||||
|
||||
Returns:
|
||||
The final state after the workflow completes
|
||||
@@ -51,6 +53,7 @@ async def run_agent_workflow_async(
|
||||
# Runtime Variables
|
||||
"messages": [{"role": "user", "content": user_input}],
|
||||
"auto_accepted_plan": True,
|
||||
"enable_background_investigation": enable_background_investigation,
|
||||
}
|
||||
config = {
|
||||
"configurable": {
|
||||
|
||||
Reference in New Issue
Block a user