deer-flow/tests/integration/test_nodes.py
jimmyuconn1982 2510cc61de feat: Add intelligent clarification feature in coordinator step for research queries (#613)
* fix: support local models by making thought field optional in Plan model

- Make thought field optional in Plan model to fix Pydantic validation errors with local models
- Add Ollama configuration example to conf.yaml.example
- Update documentation to include local model support
- Improve planner prompt with better JSON format requirements

Fixes local model integration issues where models like qwen3:14b would fail
due to missing thought field in JSON output.

* feat: Add intelligent clarification feature for research queries

- Add multi-turn clarification process to refine vague research questions
- Implement three-dimension clarification standard (Tech/App, Focus, Scope)
- Add clarification state management in coordinator node
- Update coordinator prompt with detailed clarification guidelines
- Add UI settings to enable/disable clarification feature (disabled by default)
- Update workflow to handle clarification rounds recursively
- Add comprehensive test coverage for clarification functionality
- Update documentation with clarification feature usage guide

Key components:
- src/graph/nodes.py: Core clarification logic and state management
- src/prompts/coordinator.md: Detailed clarification guidelines
- src/workflow.py: Recursive clarification handling
- web/: UI settings integration
- tests/: Comprehensive test coverage
- docs/: Updated configuration guide

* fix: Improve clarification conversation continuity

- Add comprehensive conversation history to clarification context
- Include previous exchanges summary in system messages
- Add explicit guidelines for continuing rounds in coordinator prompt
- Prevent LLM from starting new topics during clarification
- Ensure topic continuity across clarification rounds

Fixes issue where LLM would restart clarification instead of building upon previous exchanges.

* fix: Add conversation history to clarification context

* fix: resolve clarification feature issues in planner message, prompt, and tests

- Optimize coordinator.md prompt template for better clarification flow
- Simplify final message sent to planner after clarification
- Fix API key assertion issues in test_search.py

* fix: Add configurable max_clarification_rounds and comprehensive tests

- Add max_clarification_rounds parameter for external configuration
- Add comprehensive test cases for clarification feature in test_app.py
- Fix issues found during interactive mode testing where:
  - Recursive call failed due to missing initial_state parameter
  - Clarification exited prematurely at max rounds
  - Incorrect logging of max rounds reached

* Move clarification tests to test_nodes.py and add max_clarification_rounds to zh.json
2025-10-14 13:35:57 +08:00

import json
from collections import namedtuple
from unittest.mock import MagicMock, patch
import pytest
from src.graph.nodes import (
_execute_agent_step,
_setup_and_execute_agent_step,
coordinator_node,
human_feedback_node,
planner_node,
reporter_node,
researcher_node,
)
# Mock out get_llm_by_type here to avoid a ValueError during import
with patch("src.llms.llm.get_llm_by_type", return_value=MagicMock()):
from langchain_core.messages import HumanMessage
from langgraph.types import Command
from src.config import SearchEngine
from src.graph.nodes import background_investigation_node
# Mock data
MOCK_SEARCH_RESULTS = [
{"title": "Test Title 1", "content": "Test Content 1"},
{"title": "Test Title 2", "content": "Test Content 2"},
]
@pytest.fixture
def mock_state():
return {
"messages": [HumanMessage(content="test query")],
"research_topic": "test query",
"background_investigation_results": None,
}
@pytest.fixture
def mock_configurable():
mock = MagicMock()
mock.max_search_results = 7
return mock
@pytest.fixture
def mock_config():
# Return a MagicMock or a dict here, depending on what the test needs
return MagicMock()
@pytest.fixture
def patch_config_from_runnable_config(mock_configurable):
with patch(
"src.graph.nodes.Configuration.from_runnable_config",
return_value=mock_configurable,
):
yield
@pytest.fixture
def mock_tavily_search():
with patch("src.graph.nodes.LoggedTavilySearch") as mock:
instance = mock.return_value
instance.invoke.return_value = [
{"title": "Test Title 1", "content": "Test Content 1"},
{"title": "Test Title 2", "content": "Test Content 2"},
]
yield mock
@pytest.fixture
def mock_web_search_tool():
with patch("src.graph.nodes.get_web_search_tool") as mock:
instance = mock.return_value
instance.invoke.return_value = [
{"title": "Test Title 1", "content": "Test Content 1"},
{"title": "Test Title 2", "content": "Test Content 2"},
]
yield mock
@pytest.mark.parametrize("search_engine", [SearchEngine.TAVILY.value, "other"])
def test_background_investigation_node_tavily(
mock_state,
mock_tavily_search,
mock_web_search_tool,
search_engine,
patch_config_from_runnable_config,
mock_config,
):
"""Test background_investigation_node with Tavily search engine"""
with patch("src.graph.nodes.SELECTED_SEARCH_ENGINE", search_engine):
result = background_investigation_node(mock_state, mock_config)
# Verify the result structure
assert isinstance(result, dict)
# Verify the update contains background_investigation_results
assert "background_investigation_results" in result
# Parse and verify the JSON content
results = result["background_investigation_results"]
if search_engine == SearchEngine.TAVILY.value:
mock_tavily_search.return_value.invoke.assert_called_once_with("test query")
assert (
results
== "## Test Title 1\n\nTest Content 1\n\n## Test Title 2\n\nTest Content 2"
)
else:
mock_web_search_tool.return_value.invoke.assert_called_once_with(
"test query"
)
assert len(json.loads(results)) == 2
def test_background_investigation_node_malformed_response(
mock_state, mock_tavily_search, patch_config_from_runnable_config, mock_config
):
"""Test background_investigation_node with malformed Tavily response"""
with patch("src.graph.nodes.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value):
# Mock a malformed response
mock_tavily_search.return_value.invoke.return_value = "invalid response"
result = background_investigation_node(mock_state, mock_config)
# Verify the result structure
assert isinstance(result, dict)
# Verify the update contains background_investigation_results
assert "background_investigation_results" in result
# Parse and verify the JSON content
results = result["background_investigation_results"]
assert json.loads(results) is None
@pytest.fixture
def mock_plan():
return {
"has_enough_context": True,
"title": "Test Plan",
"thought": "Test Thought",
"steps": [],
"locale": "en-US",
}
@pytest.fixture
def mock_state_planner():
return {
"messages": [HumanMessage(content="plan this")],
"plan_iterations": 0,
"enable_background_investigation": True,
"background_investigation_results": "Background info",
}
@pytest.fixture
def mock_configurable_planner():
mock = MagicMock()
mock.max_plan_iterations = 3
mock.enable_deep_thinking = False
return mock
@pytest.fixture
def patch_config_from_runnable_config_planner(mock_configurable_planner):
with patch(
"src.graph.nodes.Configuration.from_runnable_config",
return_value=mock_configurable_planner,
):
yield
@pytest.fixture
def patch_apply_prompt_template():
with patch(
"src.graph.nodes.apply_prompt_template",
return_value=[{"role": "user", "content": "plan this"}],
) as mock:
yield mock
@pytest.fixture
def patch_repair_json_output():
with patch("src.graph.nodes.repair_json_output", side_effect=lambda x: x) as mock:
yield mock
@pytest.fixture
def patch_plan_model_validate():
with patch("src.graph.nodes.Plan.model_validate", side_effect=lambda x: x) as mock:
yield mock
@pytest.fixture
def patch_ai_message():
AIMessage = namedtuple("AIMessage", ["content", "name"])
with patch(
"src.graph.nodes.AIMessage",
side_effect=lambda content, name: AIMessage(content, name),
) as mock:
yield mock
def test_planner_node_basic_has_enough_context(
mock_state_planner,
patch_config_from_runnable_config_planner,
patch_apply_prompt_template,
patch_repair_json_output,
patch_plan_model_validate,
patch_ai_message,
mock_plan,
):
# AGENT_LLM_MAP["planner"] == "basic" and not thinking mode
with (
patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "basic"}),
patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
):
mock_llm = MagicMock()
mock_llm.with_structured_output.return_value = mock_llm
mock_response = MagicMock()
mock_response.model_dump_json.return_value = json.dumps(mock_plan)
mock_llm.invoke.return_value = mock_response
mock_get_llm.return_value = mock_llm
result = planner_node(mock_state_planner, MagicMock())
assert isinstance(result, Command)
assert result.goto == "reporter"
assert "current_plan" in result.update
assert result.update["current_plan"]["has_enough_context"] is True
assert result.update["messages"][0].name == "planner"
def test_planner_node_basic_not_enough_context(
mock_state_planner,
patch_config_from_runnable_config_planner,
patch_apply_prompt_template,
patch_repair_json_output,
patch_plan_model_validate,
patch_ai_message,
):
# AGENT_LLM_MAP["planner"] == "basic" and not thinking mode
plan = {
"has_enough_context": False,
"title": "Test Plan",
"thought": "Test Thought",
"steps": [],
"locale": "en-US",
}
with (
patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "basic"}),
patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
):
mock_llm = MagicMock()
mock_llm.with_structured_output.return_value = mock_llm
mock_response = MagicMock()
mock_response.model_dump_json.return_value = json.dumps(plan)
mock_llm.invoke.return_value = mock_response
mock_get_llm.return_value = mock_llm
result = planner_node(mock_state_planner, MagicMock())
assert isinstance(result, Command)
assert result.goto == "human_feedback"
assert "current_plan" in result.update
assert isinstance(result.update["current_plan"], str)
assert result.update["messages"][0].name == "planner"
def test_planner_node_stream_mode_has_enough_context(
mock_state_planner,
patch_config_from_runnable_config_planner,
patch_apply_prompt_template,
patch_repair_json_output,
patch_plan_model_validate,
patch_ai_message,
mock_plan,
):
# AGENT_LLM_MAP["planner"] != "basic"
with (
patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "other"}),
patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
):
mock_llm = MagicMock()
# Simulate streaming chunks
chunk = MagicMock()
chunk.content = json.dumps(mock_plan)
mock_llm.stream.return_value = [chunk]
mock_get_llm.return_value = mock_llm
result = planner_node(mock_state_planner, MagicMock())
assert isinstance(result, Command)
assert result.goto == "reporter"
assert "current_plan" in result.update
assert result.update["current_plan"]["has_enough_context"] is True
def test_planner_node_stream_mode_not_enough_context(
mock_state_planner,
patch_config_from_runnable_config_planner,
patch_apply_prompt_template,
patch_repair_json_output,
patch_plan_model_validate,
patch_ai_message,
):
# AGENT_LLM_MAP["planner"] != "basic"
plan = {
"has_enough_context": False,
"title": "Test Plan",
"thought": "Test Thought",
"steps": [],
"locale": "en-US",
}
with (
patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "other"}),
patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
):
mock_llm = MagicMock()
chunk = MagicMock()
chunk.content = json.dumps(plan)
mock_llm.stream.return_value = [chunk]
mock_get_llm.return_value = mock_llm
result = planner_node(mock_state_planner, MagicMock())
assert isinstance(result, Command)
assert result.goto == "human_feedback"
assert "current_plan" in result.update
assert isinstance(result.update["current_plan"], str)
def test_planner_node_plan_iterations_exceeded(mock_state_planner):
# plan_iterations >= max_plan_iterations
state = dict(mock_state_planner)
state["plan_iterations"] = 5
with (
patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "basic"}),
patch("src.graph.nodes.get_llm_by_type", return_value=MagicMock()),
):
result = planner_node(state, MagicMock())
assert isinstance(result, Command)
assert result.goto == "reporter"
def test_planner_node_json_decode_error_first_iteration(mock_state_planner):
# Simulate JSONDecodeError on first iteration
with (
patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "basic"}),
patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
patch(
"src.graph.nodes.json.loads",
side_effect=json.JSONDecodeError("err", "doc", 0),
),
):
mock_llm = MagicMock()
mock_llm.with_structured_output.return_value = mock_llm
mock_response = MagicMock()
mock_response.model_dump_json.return_value = '{"bad": "json"'
mock_llm.invoke.return_value = mock_response
mock_get_llm.return_value = mock_llm
result = planner_node(mock_state_planner, MagicMock())
assert isinstance(result, Command)
assert result.goto == "__end__"
def test_planner_node_json_decode_error_second_iteration(mock_state_planner):
# Simulate JSONDecodeError on second iteration
state = dict(mock_state_planner)
state["plan_iterations"] = 1
with (
patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "basic"}),
patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
patch(
"src.graph.nodes.json.loads",
side_effect=json.JSONDecodeError("err", "doc", 0),
),
):
mock_llm = MagicMock()
mock_llm.with_structured_output.return_value = mock_llm
mock_response = MagicMock()
mock_response.model_dump_json.return_value = '{"bad": "json"'
mock_llm.invoke.return_value = mock_response
mock_get_llm.return_value = mock_llm
result = planner_node(state, MagicMock())
assert isinstance(result, Command)
assert result.goto == "reporter"
# Patch Plan.model_validate and repair_json_output globally for these tests
@pytest.fixture(autouse=True)
def patch_plan_and_repair(monkeypatch):
monkeypatch.setattr("src.graph.nodes.Plan.model_validate", lambda x: x)
monkeypatch.setattr("src.graph.nodes.repair_json_output", lambda x: x)
yield
@pytest.fixture
def mock_state_base():
return {
"current_plan": json.dumps(
{
"has_enough_context": False,
"title": "Test Plan",
"thought": "Test Thought",
"steps": [],
"locale": "en-US",
}
),
"plan_iterations": 0,
}
def test_human_feedback_node_auto_accepted(monkeypatch, mock_state_base):
# auto_accepted_plan True, should skip interrupt and parse plan
state = dict(mock_state_base)
state["auto_accepted_plan"] = True
result = human_feedback_node(state)
assert isinstance(result, Command)
assert result.goto == "research_team"
assert result.update["plan_iterations"] == 1
assert result.update["current_plan"]["has_enough_context"] is False
def test_human_feedback_node_edit_plan(monkeypatch, mock_state_base):
# interrupt returns [EDIT_PLAN]..., should return Command to planner
state = dict(mock_state_base)
state["auto_accepted_plan"] = False
with patch("src.graph.nodes.interrupt", return_value="[EDIT_PLAN] Please revise"):
result = human_feedback_node(state)
assert isinstance(result, Command)
assert result.goto == "planner"
assert result.update["messages"][0].name == "feedback"
assert "[EDIT_PLAN]" in result.update["messages"][0].content
def test_human_feedback_node_accepted(monkeypatch, mock_state_base):
# interrupt returns [ACCEPTED]..., should proceed to parse plan
state = dict(mock_state_base)
state["auto_accepted_plan"] = False
with patch("src.graph.nodes.interrupt", return_value="[ACCEPTED] Looks good!"):
result = human_feedback_node(state)
assert isinstance(result, Command)
assert result.goto == "research_team"
assert result.update["plan_iterations"] == 1
assert result.update["current_plan"]["has_enough_context"] is False
def test_human_feedback_node_invalid_interrupt(monkeypatch, mock_state_base):
# interrupt returns something else, should raise TypeError
state = dict(mock_state_base)
state["auto_accepted_plan"] = False
with patch("src.graph.nodes.interrupt", return_value="RANDOM_FEEDBACK"):
with pytest.raises(TypeError):
human_feedback_node(state)
def test_human_feedback_node_json_decode_error_first_iteration(
monkeypatch, mock_state_base
):
# repair_json_output returns bad json, json.loads raises JSONDecodeError, plan_iterations=0
state = dict(mock_state_base)
state["auto_accepted_plan"] = True
state["plan_iterations"] = 0
with patch(
"src.graph.nodes.json.loads", side_effect=json.JSONDecodeError("err", "doc", 0)
):
result = human_feedback_node(state)
assert isinstance(result, Command)
assert result.goto == "__end__"
def test_human_feedback_node_json_decode_error_second_iteration(
monkeypatch, mock_state_base
):
# repair_json_output returns bad json, json.loads raises JSONDecodeError, plan_iterations>0
state = dict(mock_state_base)
state["auto_accepted_plan"] = True
state["plan_iterations"] = 2
with patch(
"src.graph.nodes.json.loads", side_effect=json.JSONDecodeError("err", "doc", 0)
):
result = human_feedback_node(state)
assert isinstance(result, Command)
assert result.goto == "reporter"
def test_human_feedback_node_not_enough_context(monkeypatch, mock_state_base):
# Plan does not have enough context, should goto research_team
plan = {
"has_enough_context": False,
"title": "Test Plan",
"thought": "Test Thought",
"steps": [],
"locale": "en-US",
}
state = dict(mock_state_base)
state["current_plan"] = json.dumps(plan)
state["auto_accepted_plan"] = True
result = human_feedback_node(state)
assert isinstance(result, Command)
assert result.goto == "research_team"
assert result.update["plan_iterations"] == 1
assert result.update["current_plan"]["has_enough_context"] is False
@pytest.fixture
def mock_state_coordinator():
return {
"messages": [{"role": "user", "content": "test"}],
"locale": "en-US",
"enable_clarification": False,
}
@pytest.fixture
def mock_configurable_coordinator():
mock = MagicMock()
mock.resources = ["resource1", "resource2"]
return mock
@pytest.fixture
def patch_config_from_runnable_config_coordinator(mock_configurable_coordinator):
with patch(
"src.graph.nodes.Configuration.from_runnable_config",
return_value=mock_configurable_coordinator,
):
yield
@pytest.fixture
def patch_apply_prompt_template_coordinator():
with patch(
"src.graph.nodes.apply_prompt_template",
return_value=[{"role": "user", "content": "test"}],
) as mock:
yield mock
@pytest.fixture
def patch_handoff_to_planner():
with patch("src.graph.nodes.handoff_to_planner", MagicMock()):
yield
@pytest.fixture
def patch_logger():
with patch("src.graph.nodes.logger") as mock_logger:
yield mock_logger
def make_mock_llm_response(tool_calls=None):
resp = MagicMock()
resp.tool_calls = tool_calls or []
return resp
def test_coordinator_node_no_tool_calls(
mock_state_coordinator,
patch_config_from_runnable_config_coordinator,
patch_apply_prompt_template_coordinator,
patch_handoff_to_planner,
patch_logger,
):
# No tool calls, should goto __end__
with (
patch("src.graph.nodes.AGENT_LLM_MAP", {"coordinator": "basic"}),
patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
):
mock_llm = MagicMock()
mock_llm.bind_tools.return_value = mock_llm
mock_llm.invoke.return_value = make_mock_llm_response([])
mock_get_llm.return_value = mock_llm
result = coordinator_node(mock_state_coordinator, MagicMock())
assert result.goto == "__end__"
assert result.update["locale"] == "en-US"
assert result.update["resources"] == ["resource1", "resource2"]
def test_coordinator_node_with_tool_calls_planner(
mock_state_coordinator,
patch_config_from_runnable_config_coordinator,
patch_apply_prompt_template_coordinator,
patch_handoff_to_planner,
patch_logger,
):
# tool_calls present, should goto planner
tool_calls = [{"name": "handoff_to_planner", "args": {}}]
with (
patch("src.graph.nodes.AGENT_LLM_MAP", {"coordinator": "basic"}),
patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
):
mock_llm = MagicMock()
mock_llm.bind_tools.return_value = mock_llm
mock_llm.invoke.return_value = make_mock_llm_response(tool_calls)
mock_get_llm.return_value = mock_llm
result = coordinator_node(mock_state_coordinator, MagicMock())
assert result.goto == "planner"
assert result.update["locale"] == "en-US"
assert result.update["resources"] == ["resource1", "resource2"]
def test_coordinator_node_with_tool_calls_background_investigator(
mock_state_coordinator,
patch_config_from_runnable_config_coordinator,
patch_apply_prompt_template_coordinator,
patch_handoff_to_planner,
patch_logger,
):
# enable_background_investigation True, should goto background_investigator
state = dict(mock_state_coordinator)
state["enable_background_investigation"] = True
tool_calls = [{"name": "handoff_to_planner", "args": {}}]
with (
patch("src.graph.nodes.AGENT_LLM_MAP", {"coordinator": "basic"}),
patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
):
mock_llm = MagicMock()
mock_llm.bind_tools.return_value = mock_llm
mock_llm.invoke.return_value = make_mock_llm_response(tool_calls)
mock_get_llm.return_value = mock_llm
result = coordinator_node(state, MagicMock())
assert result.goto == "background_investigator"
assert result.update["locale"] == "en-US"
assert result.update["resources"] == ["resource1", "resource2"]
def test_coordinator_node_with_tool_calls_locale_override(
mock_state_coordinator,
patch_config_from_runnable_config_coordinator,
patch_apply_prompt_template_coordinator,
patch_handoff_to_planner,
patch_logger,
):
# tool_calls with locale in args should override locale
tool_calls = [
{
"name": "handoff_to_planner",
"args": {"locale": "zh-CN", "research_topic": "test topic"},
}
]
with (
patch("src.graph.nodes.AGENT_LLM_MAP", {"coordinator": "basic"}),
patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
):
mock_llm = MagicMock()
mock_llm.bind_tools.return_value = mock_llm
mock_llm.invoke.return_value = make_mock_llm_response(tool_calls)
mock_get_llm.return_value = mock_llm
result = coordinator_node(mock_state_coordinator, MagicMock())
assert result.goto == "planner"
assert result.update["locale"] == "zh-CN"
assert result.update["research_topic"] == "test topic"
assert result.update["resources"] == ["resource1", "resource2"]
assert result.update["resources"] == ["resource1", "resource2"]
def test_coordinator_node_tool_calls_exception_handling(
mock_state_coordinator,
patch_config_from_runnable_config_coordinator,
patch_apply_prompt_template_coordinator,
patch_handoff_to_planner,
patch_logger,
):
with (
patch("src.graph.nodes.AGENT_LLM_MAP", {"coordinator": "basic"}),
patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
):
mock_llm = MagicMock()
mock_llm.bind_tools.return_value = mock_llm
# Simulate tool_call.get("args", {}) raising AttributeError
class BadToolCall(dict):
def get(self, key, default=None):
if key == "args":
raise Exception("bad args")
return super().get(key, default)
mock_llm.invoke.return_value = make_mock_llm_response(
[BadToolCall({"name": "handoff_to_planner"})]
)
mock_get_llm.return_value = mock_llm
# Should not raise, just log error and continue
result = coordinator_node(mock_state_coordinator, MagicMock())
assert result.goto == "planner"
assert result.update["locale"] == "en-US"
assert result.update["resources"] == ["resource1", "resource2"]
@pytest.fixture
def mock_state_reporter():
# Simulate a plan object with title and thought attributes
Plan = namedtuple("Plan", ["title", "thought"])
return {
"current_plan": Plan(title="Test Title", thought="Test Thought"),
"locale": "en-US",
"observations": [],
}
@pytest.fixture
def mock_state_reporter_with_observations():
Plan = namedtuple("Plan", ["title", "thought"])
return {
"current_plan": Plan(title="Test Title", thought="Test Thought"),
"locale": "en-US",
"observations": ["Observation 1", "Observation 2"],
}
@pytest.fixture
def mock_configurable_reporter():
mock = MagicMock()
return mock
@pytest.fixture
def patch_config_from_runnable_config_reporter(mock_configurable_reporter):
with patch(
"src.graph.nodes.Configuration.from_runnable_config",
return_value=mock_configurable_reporter,
):
yield
@pytest.fixture
def patch_apply_prompt_template_reporter():
with patch(
"src.graph.nodes.apply_prompt_template",
side_effect=lambda *args, **kwargs: [MagicMock()],
) as mock:
yield mock
@pytest.fixture
def patch_human_message():
HumanMessage = MagicMock()
with patch("src.graph.nodes.HumanMessage", HumanMessage):
yield HumanMessage
@pytest.fixture
def patch_logger_reporter():
with patch("src.graph.nodes.logger") as mock_logger:
yield mock_logger
def make_mock_llm_response_reporter(content):
resp = MagicMock()
resp.content = content
return resp
def test_reporter_node_basic(
mock_state_reporter,
patch_config_from_runnable_config_reporter,
patch_apply_prompt_template_reporter,
patch_human_message,
patch_logger_reporter,
):
# Patch get_llm_by_type and AGENT_LLM_MAP
with (
patch("src.graph.nodes.AGENT_LLM_MAP", {"reporter": "basic"}),
patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
):
mock_llm = MagicMock()
mock_llm.invoke.return_value = make_mock_llm_response_reporter(
"Final Report Content"
)
mock_get_llm.return_value = mock_llm
result = reporter_node(mock_state_reporter, MagicMock())
assert isinstance(result, dict)
assert "final_report" in result
assert result["final_report"] == "Final Report Content"
# Should call apply_prompt_template with correct arguments
patch_apply_prompt_template_reporter.assert_called()
# Should call invoke on the LLM
mock_llm.invoke.assert_called()
def test_reporter_node_with_observations(
mock_state_reporter_with_observations,
patch_config_from_runnable_config_reporter,
patch_apply_prompt_template_reporter,
patch_human_message,
patch_logger_reporter,
):
with (
patch("src.graph.nodes.AGENT_LLM_MAP", {"reporter": "basic"}),
patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
):
mock_llm = MagicMock()
mock_llm.invoke.return_value = make_mock_llm_response_reporter(
"Report with Observations"
)
mock_get_llm.return_value = mock_llm
result = reporter_node(mock_state_reporter_with_observations, MagicMock())
assert isinstance(result, dict)
assert "final_report" in result
assert result["final_report"] == "Report with Observations"
# Should call apply_prompt_template with correct arguments
patch_apply_prompt_template_reporter.assert_called()
# Should call invoke on the LLM
mock_llm.invoke.assert_called()
def test_reporter_node_locale_default(
patch_config_from_runnable_config_reporter,
patch_apply_prompt_template_reporter,
patch_human_message,
patch_logger_reporter,
):
# If locale is missing, should default to "en-US"
Plan = namedtuple("Plan", ["title", "thought"])
state = {
"current_plan": Plan(title="Test Title", thought="Test Thought"),
# "locale" omitted
"observations": [],
}
with (
patch("src.graph.nodes.AGENT_LLM_MAP", {"reporter": "basic"}),
patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
):
mock_llm = MagicMock()
mock_llm.invoke.return_value = make_mock_llm_response_reporter(
"Default Locale Report"
)
mock_get_llm.return_value = mock_llm
result = reporter_node(state, MagicMock())
assert isinstance(result, dict)
assert "final_report" in result
assert result["final_report"] == "Default Locale Report"
# A minimal stand-in for the Step class, used by these tests
class Step:
def __init__(self, title, description, execution_res=None):
self.title = title
self.description = description
self.execution_res = execution_res
@pytest.fixture
def mock_step():
return Step(title="Step 1", description="Desc 1", execution_res=None)
@pytest.fixture
def mock_completed_step():
return Step(title="Step 0", description="Desc 0", execution_res="Done")
@pytest.fixture
def mock_state_with_steps(mock_step, mock_completed_step):
# Simulate a plan with one completed and one unexecuted step
Plan = MagicMock()
Plan.steps = [mock_completed_step, mock_step]
return {
"current_plan": Plan,
"observations": ["obs1"],
"locale": "en-US",
"resources": [],
}
@pytest.fixture
def mock_state_no_unexecuted():
Step = namedtuple("Step", ["title", "description", "execution_res"])
Plan = MagicMock()
Plan.steps = [
Step(title="Step 1", description="Desc 1", execution_res="done"),
Step(title="Step 2", description="Desc 2", execution_res="done"),
]
return {
"current_plan": Plan,
"observations": [],
"locale": "en-US",
"resources": [],
}
@pytest.fixture
def mock_agent():
agent = MagicMock()
async def ainvoke(input, config):
# Simulate agent returning a message list
return {"messages": [MagicMock(content="result content")]}
agent.ainvoke = ainvoke
return agent
@pytest.mark.asyncio
async def test_execute_agent_step_basic(mock_state_with_steps, mock_agent):
# Should execute the first unexecuted step and update execution_res
with patch(
"src.graph.nodes.HumanMessage",
side_effect=lambda content, name=None: MagicMock(content=content, name=name),
):
result = await _execute_agent_step(
mock_state_with_steps, mock_agent, "researcher"
)
assert isinstance(result, Command)
assert result.goto == "research_team"
assert "messages" in result.update
assert "observations" in result.update
# The new observation should be appended
assert result.update["observations"][-1] == "result content"
# The step's execution_res should be updated
assert (
mock_state_with_steps["current_plan"].steps[1].execution_res
== "result content"
)
@pytest.mark.asyncio
async def test_execute_agent_step_no_unexecuted_step(
mock_state_no_unexecuted, mock_agent
):
# Should return Command with goto="research_team" and not fail
with patch("src.graph.nodes.logger") as mock_logger:
result = await _execute_agent_step(
mock_state_no_unexecuted, mock_agent, "researcher"
)
assert isinstance(result, Command)
assert result.goto == "research_team"
mock_logger.warning.assert_called_with("No unexecuted step found")
@pytest.mark.asyncio
async def test_execute_agent_step_with_resources_and_researcher(mock_step):
# Should add resource info and citation reminder for researcher
Resource = namedtuple("Resource", ["title", "description"])
resources = [Resource(title="file1.txt", description="desc1")]
Plan = MagicMock()
Plan.steps = [mock_step]
state = {
"current_plan": Plan,
"observations": [],
"locale": "en-US",
"resources": resources,
}
agent = MagicMock()
async def ainvoke(input, config):
# Check that resource info and citation reminder are present
messages = input["messages"]
assert any("local_search_tool" in m.content for m in messages)
assert any("DO NOT include inline citations" in m.content for m in messages)
return {"messages": [MagicMock(content="resource result")]}
agent.ainvoke = ainvoke
with patch(
"src.graph.nodes.HumanMessage",
side_effect=lambda content, name=None: MagicMock(content=content, name=name),
):
result = await _execute_agent_step(state, agent, "researcher")
assert isinstance(result, Command)
assert result.goto == "research_team"
assert result.update["observations"][-1] == "resource result"
@pytest.mark.asyncio
async def test_execute_agent_step_recursion_limit_env(
monkeypatch, mock_state_with_steps, mock_agent
):
# Should respect AGENT_RECURSION_LIMIT env variable if set and valid
monkeypatch.setenv("AGENT_RECURSION_LIMIT", "42")
with (
patch("src.graph.nodes.logger") as mock_logger,
patch(
"src.graph.nodes.HumanMessage",
side_effect=lambda content, name=None: MagicMock(
content=content, name=name
),
),
):
result = await _execute_agent_step(mock_state_with_steps, mock_agent, "coder")
assert isinstance(result, Command)
mock_logger.info.assert_any_call("Recursion limit set to: 42")
@pytest.mark.asyncio
async def test_execute_agent_step_recursion_limit_env_invalid(
monkeypatch, mock_state_with_steps, mock_agent
):
# Should fall back to the default if the env variable is invalid
monkeypatch.setenv("AGENT_RECURSION_LIMIT", "notanint")
with (
patch("src.graph.nodes.logger") as mock_logger,
patch(
"src.graph.nodes.HumanMessage",
side_effect=lambda content, name=None: MagicMock(
content=content, name=name
),
),
):
result = await _execute_agent_step(mock_state_with_steps, mock_agent, "coder")
assert isinstance(result, Command)
mock_logger.warning.assert_any_call(
"Invalid AGENT_RECURSION_LIMIT value: 'notanint'. Using default value 25."
)
@pytest.mark.asyncio
async def test_execute_agent_step_recursion_limit_env_negative(
monkeypatch, mock_state_with_steps, mock_agent
):
# Should fall back to the default if the env variable is negative or zero
monkeypatch.setenv("AGENT_RECURSION_LIMIT", "-5")
with (
patch("src.graph.nodes.logger") as mock_logger,
patch(
"src.graph.nodes.HumanMessage",
side_effect=lambda content, name=None: MagicMock(
content=content, name=name
),
),
):
result = await _execute_agent_step(mock_state_with_steps, mock_agent, "coder")
assert isinstance(result, Command)
mock_logger.warning.assert_any_call(
"AGENT_RECURSION_LIMIT value '-5' (parsed as -5) is not positive. Using default value 25."
)
@pytest.fixture
def mock_configurable_with_mcp():
mock = MagicMock()
mock.mcp_settings = {
"servers": {
"server1": {
"enabled_tools": ["toolA", "toolB"],
"add_to_agents": ["researcher"],
"transport": "http",
"command": "run",
"args": {},
"url": "http://localhost",
"env": {},
"other": "ignore",
}
}
}
return mock
@pytest.fixture
def mock_configurable_without_mcp():
mock = MagicMock()
mock.mcp_settings = None
return mock
@pytest.fixture
def patch_config_from_runnable_config_with_mcp(mock_configurable_with_mcp):
with patch(
"src.graph.nodes.Configuration.from_runnable_config",
return_value=mock_configurable_with_mcp,
):
yield
@pytest.fixture
def patch_config_from_runnable_config_without_mcp(mock_configurable_without_mcp):
with patch(
"src.graph.nodes.Configuration.from_runnable_config",
return_value=mock_configurable_without_mcp,
):
yield
@pytest.fixture
def patch_create_agent():
with patch("src.graph.nodes.create_agent") as mock:
yield mock
@pytest.fixture
def patch_execute_agent_step():
async def fake_execute_agent_step(state, agent, agent_type):
return "EXECUTED"
with patch(
"src.graph.nodes._execute_agent_step", side_effect=fake_execute_agent_step
) as mock:
yield mock
@pytest.fixture
def patch_multiserver_mcp_client():
# Patch MultiServerMCPClient as an async context manager
class FakeTool:
def __init__(self, name, description="desc"):
self.name = name
self.description = description
class FakeClient:
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
pass
async def get_tools(self):
return [
FakeTool("toolA", "descA"),
FakeTool("toolB", "descB"),
FakeTool("toolC", "descC"),
]
with patch(
"src.graph.nodes.MultiServerMCPClient", return_value=FakeClient()
) as mock:
yield mock
@pytest.mark.asyncio
async def test_setup_and_execute_agent_step_with_mcp(
mock_state_with_steps,
mock_config,
patch_config_from_runnable_config_with_mcp,
patch_create_agent,
patch_execute_agent_step,
patch_multiserver_mcp_client,
):
# Should use MCP client, load tools, and call create_agent with correct tools
default_tools = [MagicMock(name="default_tool")]
agent_type = "researcher"
result = await _setup_and_execute_agent_step(
mock_state_with_steps,
mock_config,
agent_type,
default_tools,
)
# Should call create_agent with loaded_tools including toolA and toolB
args, kwargs = patch_create_agent.call_args
loaded_tools = args[2]
tool_names = [t.name for t in loaded_tools if hasattr(t, "name")]
assert "toolA" in tool_names
assert "toolB" in tool_names
# Should call _execute_agent_step
patch_execute_agent_step.assert_called_once()
assert result == "EXECUTED"
@pytest.mark.asyncio
async def test_setup_and_execute_agent_step_without_mcp(
mock_state_with_steps,
mock_config,
patch_config_from_runnable_config_without_mcp,
patch_create_agent,
patch_execute_agent_step,
):
# Should use default tools and not use MCP client
default_tools = [MagicMock(name="default_tool")]
agent_type = "coder"
result = await _setup_and_execute_agent_step(
mock_state_with_steps,
mock_config,
agent_type,
default_tools,
)
# Should call create_agent with default_tools
args, kwargs = patch_create_agent.call_args
assert args[2] == default_tools
patch_execute_agent_step.assert_called_once()
assert result == "EXECUTED"
@pytest.mark.asyncio
async def test_setup_and_execute_agent_step_with_mcp_no_enabled_tools(
mock_state_with_steps,
mock_config,
patch_create_agent,
patch_execute_agent_step,
):
# If mcp_settings is present but has no enabled_tools for agent_type, fall back to default_tools
mcp_settings = {
"servers": {
"server1": {
"enabled_tools": ["toolA"],
"add_to_agents": ["other_agent"],
"transport": "http",
"command": "run",
"args": {},
"url": "http://localhost",
"env": {},
}
}
}
configurable = MagicMock()
configurable.mcp_settings = mcp_settings
with patch(
"src.graph.nodes.Configuration.from_runnable_config",
return_value=configurable,
):
default_tools = [MagicMock(name="default_tool")]
agent_type = "researcher"
result = await _setup_and_execute_agent_step(
mock_state_with_steps,
mock_config,
agent_type,
default_tools,
)
args, kwargs = patch_create_agent.call_args
assert args[2] == default_tools
patch_execute_agent_step.assert_called_once()
assert result == "EXECUTED"
@pytest.mark.asyncio
async def test_setup_and_execute_agent_step_with_mcp_tools_description_update(
mock_state_with_steps,
mock_config,
patch_config_from_runnable_config_with_mcp,
patch_create_agent,
patch_execute_agent_step,
):
# Should update tool.description with Powered by info
default_tools = [MagicMock(name="default_tool")]
agent_type = "researcher"
# Patch MultiServerMCPClient to check description update
class FakeTool:
def __init__(self, name, description="desc"):
self.name = name
self.description = description
class FakeClient:
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
pass
async def get_tools(self):
return [FakeTool("toolA", "descA")]
with patch("src.graph.nodes.MultiServerMCPClient", return_value=FakeClient()):
await _setup_and_execute_agent_step(
mock_state_with_steps,
mock_config,
agent_type,
default_tools,
)
# The tool description should be updated
args, kwargs = patch_create_agent.call_args
loaded_tools = args[2]
found = False
for t in loaded_tools:
if hasattr(t, "name") and t.name == "toolA":
assert t.description.startswith("Powered by 'server1'.\n")
found = True
assert found
@pytest.fixture
def mock_state_with_resources():
return {"resources": ["resource1", "resource2"], "other": "value"}
@pytest.fixture
def mock_state_without_resources():
return {"other": "value"}
@pytest.fixture
def patch_get_web_search_tool():
with patch("src.graph.nodes.get_web_search_tool") as mock:
mock_tool = MagicMock(name="web_search_tool")
mock.return_value = mock_tool
yield mock
@pytest.fixture
def patch_crawl_tool():
with patch("src.graph.nodes.crawl_tool", MagicMock(name="crawl_tool")):
yield
@pytest.fixture
def patch_get_retriever_tool():
with patch("src.graph.nodes.get_retriever_tool") as mock:
yield mock
@pytest.fixture
def patch_setup_and_execute_agent_step():
async def fake_setup_and_execute_agent_step(state, config, agent_type, tools):
return "RESEARCHER_RESULT"
with patch(
"src.graph.nodes._setup_and_execute_agent_step",
side_effect=fake_setup_and_execute_agent_step,
) as mock:
yield mock
@pytest.mark.asyncio
async def test_researcher_node_with_retriever_tool(
mock_state_with_resources,
mock_config,
patch_config_from_runnable_config,
patch_get_web_search_tool,
patch_crawl_tool,
patch_get_retriever_tool,
patch_setup_and_execute_agent_step,
):
# Simulate retriever_tool is returned
retriever_tool = MagicMock(name="retriever_tool")
patch_get_retriever_tool.return_value = retriever_tool
result = await researcher_node(mock_state_with_resources, mock_config)
# Should call get_web_search_tool with correct max_search_results
patch_get_web_search_tool.assert_called_once_with(7)
# Should call get_retriever_tool with resources
patch_get_retriever_tool.assert_called_once_with(["resource1", "resource2"])
# Should call _setup_and_execute_agent_step with retriever_tool first
args, kwargs = patch_setup_and_execute_agent_step.call_args
tools = args[3]
assert tools[0] == retriever_tool
assert patch_get_web_search_tool.return_value in tools
assert result == "RESEARCHER_RESULT"
@pytest.mark.asyncio
async def test_researcher_node_without_retriever_tool(
mock_state_with_resources,
mock_config,
patch_config_from_runnable_config,
patch_get_web_search_tool,
patch_crawl_tool,
patch_get_retriever_tool,
patch_setup_and_execute_agent_step,
):
# Simulate retriever_tool is None
patch_get_retriever_tool.return_value = None
result = await researcher_node(mock_state_with_resources, mock_config)
patch_get_web_search_tool.assert_called_once_with(7)
patch_get_retriever_tool.assert_called_once_with(["resource1", "resource2"])
args, kwargs = patch_setup_and_execute_agent_step.call_args
tools = args[3]
# Should not include retriever_tool
assert all(getattr(t, "name", None) != "retriever_tool" for t in tools)
assert patch_get_web_search_tool.return_value in tools
assert result == "RESEARCHER_RESULT"
@pytest.mark.asyncio
async def test_researcher_node_without_resources(
mock_state_without_resources,
mock_config,
patch_config_from_runnable_config,
patch_get_web_search_tool,
patch_crawl_tool,
patch_get_retriever_tool,
patch_setup_and_execute_agent_step,
):
patch_get_retriever_tool.return_value = None
result = await researcher_node(mock_state_without_resources, mock_config)
patch_get_web_search_tool.assert_called_once_with(7)
patch_get_retriever_tool.assert_called_once_with([])
args, kwargs = patch_setup_and_execute_agent_step.call_args
tools = args[3]
assert patch_get_web_search_tool.return_value in tools
assert result == "RESEARCHER_RESULT"
# ============================================================================
# Clarification Feature Tests
# ============================================================================
@pytest.mark.asyncio
async def test_clarification_workflow_integration():
"""Test the complete clarification workflow integration."""
import inspect
from src.workflow import run_agent_workflow_async
# Verify that the function accepts clarification parameters
sig = inspect.signature(run_agent_workflow_async)
assert "max_clarification_rounds" in sig.parameters
assert "enable_clarification" in sig.parameters
assert "initial_state" in sig.parameters
def test_clarification_parameters_combinations():
"""Test various combinations of clarification parameters."""
from src.graph.nodes import needs_clarification
test_cases = [
# (enable_clarification, clarification_rounds, max_rounds, is_complete, expected)
(True, 0, 3, False, False), # No rounds started
(True, 1, 3, False, True), # In progress
(True, 2, 3, False, True), # In progress
(True, 3, 3, False, True), # At max - still waiting for last answer
(True, 4, 3, False, False), # Exceeded max
(True, 1, 3, True, False), # Completed
(False, 1, 3, False, False), # Disabled
]
for enable, rounds, max_rounds, complete, expected in test_cases:
state = {
"enable_clarification": enable,
"clarification_rounds": rounds,
"max_clarification_rounds": max_rounds,
"is_clarification_complete": complete,
}
result = needs_clarification(state)
assert result == expected, f"Failed for case: {state}"
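# A minimal sketch of the gating logic the cases above pin down, assuming the
# state keys used in this test; the real needs_clarification in
# src/graph/nodes.py may differ in detail.
def _needs_clarification_sketch(state: dict) -> bool:
    if not state.get("enable_clarification", False):
        return False
    if state.get("is_clarification_complete", False):
        return False
    rounds = state.get("clarification_rounds", 0)
    max_rounds = state.get("max_clarification_rounds", 3)
    # Clarification is pending only once a round has started, and until the
    # counter moves past the configured maximum (the round at max is still
    # waiting for its answer).
    return 1 <= rounds <= max_rounds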
def test_handoff_tools():
"""Test that handoff tools are properly defined."""
from src.graph.nodes import handoff_after_clarification, handoff_to_planner
# Test handoff_to_planner tool - use invoke() method
result = handoff_to_planner.invoke(
{"research_topic": "renewable energy", "locale": "en-US"}
)
assert result is None # Tool should return None (no-op)
# Test handoff_after_clarification tool - use invoke() method
result = handoff_after_clarification.invoke({"locale": "en-US"})
assert result is None # Tool should return None (no-op)
@patch("src.graph.nodes.get_llm_by_type")
def test_coordinator_tools_with_clarification_enabled(mock_get_llm):
"""Test that coordinator binds correct tools when clarification is enabled."""
# Mock LLM response
mock_llm = MagicMock()
mock_response = MagicMock()
mock_response.content = "Let me clarify..."
mock_response.tool_calls = []
mock_llm.bind_tools.return_value.invoke.return_value = mock_response
mock_get_llm.return_value = mock_llm
# State with clarification enabled (in progress)
state = {
"messages": [{"role": "user", "content": "Tell me about something"}],
"enable_clarification": True,
"clarification_rounds": 2,
"max_clarification_rounds": 3,
"is_clarification_complete": False,
"clarification_history": ["response 1", "response 2"],
"locale": "en-US",
"research_topic": "",
}
# Mock config
config = {"configurable": {"resources": []}}
# Call coordinator_node
coordinator_node(state, config)
# Verify that LLM was called with bind_tools
assert mock_llm.bind_tools.called
bound_tools = mock_llm.bind_tools.call_args[0][0]
# Should bind 2 tools when clarification is enabled
assert len(bound_tools) == 2
tool_names = [tool.name for tool in bound_tools]
assert "handoff_to_planner" in tool_names
assert "handoff_after_clarification" in tool_names
@patch("src.graph.nodes.get_llm_by_type")
def test_coordinator_tools_with_clarification_disabled(mock_get_llm):
"""Test that coordinator binds only one tool when clarification is disabled."""
# Mock LLM response with tool call
mock_llm = MagicMock()
mock_response = MagicMock()
mock_response.content = ""
mock_response.tool_calls = [
{
"name": "handoff_to_planner",
"args": {"research_topic": "test", "locale": "en-US"},
}
]
mock_llm.bind_tools.return_value.invoke.return_value = mock_response
mock_get_llm.return_value = mock_llm
# State with clarification disabled
state = {
"messages": [{"role": "user", "content": "Tell me about something"}],
"enable_clarification": False,
"locale": "en-US",
"research_topic": "",
}
# Mock config
config = {"configurable": {"resources": []}}
# Call coordinator_node
coordinator_node(state, config)
# Verify that LLM was called with bind_tools
assert mock_llm.bind_tools.called
bound_tools = mock_llm.bind_tools.call_args[0][0]
# Should bind only 1 tool when clarification is disabled
assert len(bound_tools) == 1
assert bound_tools[0].name == "handoff_to_planner"
@patch("src.graph.nodes.get_llm_by_type")
def test_coordinator_empty_llm_response_corner_case(mock_get_llm):
"""
Corner-case test: the LLM returns an empty response while clarification is enabled.
This exercises error handling when the LLM returns neither content nor tool calls
in the initial state (clarification_rounds=0). The system should gracefully
handle this by going to __end__ instead of crashing.
Note: this is NOT a typical clarification workflow test; it checks
fault tolerance when the LLM misbehaves.
"""
# Mock LLM response - empty response (failure scenario)
mock_llm = MagicMock()
mock_response = MagicMock()
mock_response.content = ""
mock_response.tool_calls = []
mock_llm.bind_tools.return_value.invoke.return_value = mock_response
mock_get_llm.return_value = mock_llm
# State with clarification enabled but initial round
state = {
"messages": [{"role": "user", "content": "test"}],
"enable_clarification": True,
# clarification_rounds: 0 (default, not started)
"locale": "en-US",
"research_topic": "",
}
# Mock config
config = {"configurable": {"resources": []}}
# Call coordinator_node - should not crash
result = coordinator_node(state, config)
# Should gracefully handle empty response by going to __end__
assert result.goto == "__end__"
assert result.update["locale"] == "en-US"