feat(nodes): add background investigation node

Change-Id: I96e08e22fc7c52647edbf9be4f385a8fae9b449a
This commit is contained in:
Zhao Longjie
2025-04-27 20:15:42 +08:00
parent ada5e34eeb
commit 899438eca0
7 changed files with 90 additions and 8 deletions

27
main.py
View File

@@ -13,7 +13,13 @@ from src.workflow import run_agent_workflow_async
from src.config.questions import BUILT_IN_QUESTIONS, BUILT_IN_QUESTIONS_ZH_CN
def ask(question, debug=False, max_plan_iterations=1, max_step_num=3):
def ask(
question,
debug=False,
max_plan_iterations=1,
max_step_num=3,
enable_background_investigation=True,
):
"""Run the agent workflow with the given question.
Args:
@@ -21,6 +27,7 @@ def ask(question, debug=False, max_plan_iterations=1, max_step_num=3):
debug: If True, enables debug level logging
max_plan_iterations: Maximum number of plan iterations
max_step_num: Maximum number of steps in a plan
enable_background_investigation: If True, performs web search before planning to enhance context
"""
asyncio.run(
run_agent_workflow_async(
@@ -28,14 +35,21 @@ def ask(question, debug=False, max_plan_iterations=1, max_step_num=3):
debug=debug,
max_plan_iterations=max_plan_iterations,
max_step_num=max_step_num,
enable_background_investigation=enable_background_investigation,
)
)
def main(debug=False, max_plan_iterations=1, max_step_num=3):
def main(
debug=False,
max_plan_iterations=1,
max_step_num=3,
enable_background_investigation=True,
):
"""Interactive mode with built-in questions.
Args:
enable_background_investigation: If True, performs web search before planning to enhance context
debug: If True, enables debug level logging
max_plan_iterations: Maximum number of plan iterations
max_step_num: Maximum number of steps in a plan
@@ -77,6 +91,7 @@ def main(debug=False, max_plan_iterations=1, max_step_num=3):
debug=debug,
max_plan_iterations=max_plan_iterations,
max_step_num=max_step_num,
enable_background_investigation=enable_background_investigation,
)
@@ -102,6 +117,12 @@ if __name__ == "__main__":
help="Maximum number of steps in a plan (default: 3)",
)
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
parser.add_argument(
"--no-background-investigation",
action="store_false",
dest="enable_background_investigation",
help="Disable background investigation before planning",
)
args = parser.parse_args()
@@ -111,6 +132,7 @@ if __name__ == "__main__":
debug=args.debug,
max_plan_iterations=args.max_plan_iterations,
max_step_num=args.max_step_num,
enable_background_investigation=args.enable_background_investigation,
)
else:
# Parse user input from command line arguments or user input
@@ -125,4 +147,5 @@ if __name__ == "__main__":
debug=args.debug,
max_plan_iterations=args.max_plan_iterations,
max_step_num=args.max_step_num,
enable_background_investigation=args.enable_background_investigation,
)

View File

@@ -13,6 +13,7 @@ from .nodes import (
researcher_node,
coder_node,
human_feedback_node,
background_investigation_node,
)
@@ -21,6 +22,7 @@ def _build_base_graph():
builder = StateGraph(State)
builder.add_edge(START, "coordinator")
builder.add_node("coordinator", coordinator_node)
builder.add_node("background_investigator", background_investigation_node)
builder.add_node("planner", planner_node)
builder.add_node("reporter", reporter_node)
builder.add_node("research_team", research_team_node)

View File

@@ -13,6 +13,7 @@ from langchain_mcp_adapters.client import MultiServerMCPClient
from src.agents.agents import coder_agent, research_agent, create_agent
from src.tools.search import LoggedTavilySearch
from src.tools import (
crawl_tool,
web_search_tool,
@@ -27,6 +28,7 @@ from src.prompts.template import apply_prompt_template
from src.utils.json_utils import repair_json_output
from .types import State
from ..config import SEARCH_MAX_RESULTS
logger = logging.getLogger(__name__)
@@ -42,13 +44,55 @@ def handoff_to_planner(
return
def background_investigation_node(state: State) -> Command[Literal["planner"]]:
    """Run a web search on the user's latest message before planning.

    Queries Tavily with the content of the most recent message and stores a
    JSON-serialized list of ``{"title", "content"}`` entries in
    ``background_investigation_results`` so the planner can use them as
    extra context.

    Args:
        state: Graph state; reads ``state["messages"][-1].content`` as the
            search query.

    Returns:
        Command that updates ``background_investigation_results`` and
        routes to the ``planner`` node.
    """
    logger.info("background investigation node is running.")
    searched_content = LoggedTavilySearch(max_results=SEARCH_MAX_RESULTS).invoke(
        {"query": state["messages"][-1].content}
    )
    # Default to an empty list (not None): a malformed search response would
    # otherwise serialize to the string "null", which is truthy and would be
    # injected verbatim into the planner prompt downstream.
    background_investigation_results = []
    if isinstance(searched_content, list):
        background_investigation_results = [
            # .get guards against providers omitting a field in a result entry
            {"title": elem.get("title", ""), "content": elem.get("content", "")}
            for elem in searched_content
        ]
    else:
        logger.error(f"Tavily search returned malformed response: {searched_content}")
    return Command(
        update={
            "background_investigation_results": json.dumps(
                background_investigation_results, ensure_ascii=False
            )
        },
        goto="planner",
    )
def planner_node(
state: State, config: RunnableConfig
) -> Command[Literal["human_feedback", "reporter"]]:
"""Planner node that generate the full plan."""
logger.info("Planner generating full plan")
configurable = Configuration.from_runnable_config(config)
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
messages = apply_prompt_template("planner", state, configurable)
if (
plan_iterations == 0
and state.get("enable_background_investigation")
and state.get("background_investigation_results")
):
messages += [
{
"role": "user",
"content": (
"background investigation results of user query:\n"
+ state["background_investigation_results"]
+ "\n"
),
}
]
if AGENT_LLM_MAP["planner"] == "basic":
llm = get_llm_by_type(AGENT_LLM_MAP["planner"]).with_structured_output(
Plan,
@@ -56,7 +100,6 @@ def planner_node(
)
else:
llm = get_llm_by_type(AGENT_LLM_MAP["planner"])
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0
# if the plan iterations is greater than the max plan iterations, return the reporter node
if plan_iterations >= configurable.max_plan_iterations:
@@ -134,7 +177,9 @@ def human_feedback_node(
)
def coordinator_node(state: State) -> Command[Literal["planner", "__end__"]]:
def coordinator_node(
state: State,
) -> Command[Literal["planner", "background_investigator", "__end__"]]:
"""Coordinator node that communicate with customers."""
logger.info("Coordinator talking.")
messages = apply_prompt_template("coordinator", state)
@@ -150,6 +195,9 @@ def coordinator_node(state: State) -> Command[Literal["planner", "__end__"]]:
if len(response.tool_calls) > 0:
goto = "planner"
if state.get("enable_background_investigation"):
# when background investigation is enabled, route to the
# background_investigator node before handing off to the planner
goto = "background_investigator"
try:
for tool_call in response.tool_calls:
if tool_call.get("name", "") != "handoff_to_planner":
@@ -160,9 +208,7 @@ def coordinator_node(state: State) -> Command[Literal["planner", "__end__"]]:
except Exception as e:
logger.error(f"Error processing tool calls: {e}")
return Command(
update={
"locale": locale
},
update={"locale": locale},
goto=goto,
)

View File

@@ -19,3 +19,5 @@ class State(MessagesState):
current_plan: Plan | str = None
final_report: str = ""
auto_accepted_plan: bool = False
enable_background_investigation: bool = True
background_investigation_results: str = None

View File

@@ -63,6 +63,7 @@ async def chat_stream(request: ChatRequest):
request.auto_accepted_plan,
request.interrupt_feedback,
request.mcp_settings,
request.enable_background_investigation,
),
media_type="text/event-stream",
)
@@ -76,6 +77,7 @@ async def _astream_workflow_generator(
auto_accepted_plan: bool,
interrupt_feedback: str,
mcp_settings: dict,
enable_background_investigation,
):
input_ = {
"messages": messages,
@@ -84,12 +86,13 @@ async def _astream_workflow_generator(
"current_plan": None,
"observations": [],
"auto_accepted_plan": auto_accepted_plan,
"enable_background_investigation": enable_background_investigation,
}
if not auto_accepted_plan and interrupt_feedback:
resume_msg = f"[{interrupt_feedback}]"
# add the last message to the resume message
if messages:
resume_msg += f" {messages[-1]["content"]}"
resume_msg += f" {messages[-1]['content']}"
input_ = Command(resume=resume_msg)
async for agent, _, event_data in graph.astream(
input_,

View File

@@ -47,6 +47,9 @@ class ChatRequest(BaseModel):
mcp_settings: Optional[dict] = Field(
None, description="MCP settings for the chat request"
)
enable_background_investigation: Optional[bool] = Field(
True, description="Whether to get background investigation before plan"
)
class TTSRequest(BaseModel):

View File

@@ -28,6 +28,7 @@ async def run_agent_workflow_async(
debug: bool = False,
max_plan_iterations: int = 1,
max_step_num: int = 3,
enable_background_investigation: bool = True,
):
"""Run the agent workflow asynchronously with the given user input.
@@ -36,6 +37,7 @@ async def run_agent_workflow_async(
debug: If True, enables debug level logging
max_plan_iterations: Maximum number of plan iterations
max_step_num: Maximum number of steps in a plan
enable_background_investigation: If True, performs web search before planning to enhance context
Returns:
The final state after the workflow completes
@@ -51,6 +53,7 @@ async def run_agent_workflow_async(
# Runtime Variables
"messages": [{"role": "user", "content": user_input}],
"auto_accepted_plan": True,
"enable_background_investigation": enable_background_investigation,
}
config = {
"configurable": {