# deer-flow/tests/integration/test_nodes.py
# (1383 lines, 44 KiB, Python — web-view page chrome removed)
from collections import namedtuple
import json
import pytest
from unittest.mock import patch, MagicMock
from src.graph.nodes import planner_node
from src.graph.nodes import human_feedback_node
from src.graph.nodes import coordinator_node
from src.graph.nodes import reporter_node
from src.graph.nodes import _execute_agent_step
from src.graph.nodes import _setup_and_execute_agent_step
from src.graph.nodes import researcher_node
# Mock get_llm_by_type here to avoid a ValueError raised while importing
# the graph nodes (the real implementation needs LLM configuration).
with patch("src.llms.llm.get_llm_by_type", return_value=MagicMock()):
    from langgraph.types import Command
    from src.graph.nodes import background_investigation_node
    from src.config import SearchEngine
    from langchain_core.messages import HumanMessage
# Mock data
# Canonical fake search results; the fixtures below inline the same payload.
# NOTE(review): this constant is not referenced by any visible test — confirm
# whether it is still needed.
MOCK_SEARCH_RESULTS = [
    {"title": "Test Title 1", "content": "Test Content 1"},
    {"title": "Test Title 2", "content": "Test Content 2"},
]
@pytest.fixture
def mock_state():
    """Graph state carrying one user query and no prior background results."""
    return {
        "messages": [HumanMessage(content="test query")],
        "research_topic": "test query",
        "background_investigation_results": None,
    }


@pytest.fixture
def mock_configurable():
    """Configuration stub exposing only max_search_results."""
    mock = MagicMock()
    mock.max_search_results = 7
    return mock


@pytest.fixture
def mock_config():
    # Return a MagicMock; swap in a plain dict here if a test needs one.
    return MagicMock()


@pytest.fixture
def patch_config_from_runnable_config(mock_configurable):
    """Patch Configuration.from_runnable_config to return the stub above."""
    with patch(
        "src.graph.nodes.Configuration.from_runnable_config",
        return_value=mock_configurable,
    ):
        yield


@pytest.fixture
def mock_tavily_search():
    """Patch LoggedTavilySearch so invoke() yields two canned results."""
    with patch("src.graph.nodes.LoggedTavilySearch") as mock:
        instance = mock.return_value
        instance.invoke.return_value = [
            {"title": "Test Title 1", "content": "Test Content 1"},
            {"title": "Test Title 2", "content": "Test Content 2"},
        ]
        yield mock


@pytest.fixture
def mock_web_search_tool():
    """Patch get_web_search_tool so invoke() yields two canned results."""
    with patch("src.graph.nodes.get_web_search_tool") as mock:
        instance = mock.return_value
        instance.invoke.return_value = [
            {"title": "Test Title 1", "content": "Test Content 1"},
            {"title": "Test Title 2", "content": "Test Content 2"},
        ]
        yield mock
@pytest.mark.parametrize("search_engine", [SearchEngine.TAVILY.value, "other"])
def test_background_investigation_node_tavily(
    mock_state,
    mock_tavily_search,
    mock_web_search_tool,
    search_engine,
    patch_config_from_runnable_config,
    mock_config,
):
    """Test background_investigation_node with Tavily search engine"""
    with patch("src.graph.nodes.SELECTED_SEARCH_ENGINE", search_engine):
        result = background_investigation_node(mock_state, mock_config)
        # Verify the result structure
        assert isinstance(result, dict)
        # Verify the update contains background_investigation_results
        assert "background_investigation_results" in result
        results = result["background_investigation_results"]
        if search_engine == SearchEngine.TAVILY.value:
            # Tavily results are rendered as a markdown digest of title/content.
            mock_tavily_search.return_value.invoke.assert_called_once_with("test query")
            assert (
                results
                == "## Test Title 1\n\nTest Content 1\n\n## Test Title 2\n\nTest Content 2"
            )
        else:
            # Any other engine returns the raw results serialized to JSON.
            mock_web_search_tool.return_value.invoke.assert_called_once_with(
                "test query"
            )
            assert len(json.loads(results)) == 2


def test_background_investigation_node_malformed_response(
    mock_state, mock_tavily_search, patch_config_from_runnable_config, mock_config
):
    """Test background_investigation_node with malformed Tavily response"""
    with patch("src.graph.nodes.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value):
        # Mock a malformed response (a plain string instead of a result list)
        mock_tavily_search.return_value.invoke.return_value = "invalid response"
        result = background_investigation_node(mock_state, mock_config)
        # Verify the result structure
        assert isinstance(result, dict)
        assert "background_investigation_results" in result
        # A malformed response is serialized as JSON null.
        results = result["background_investigation_results"]
        assert json.loads(results) is None
@pytest.fixture
def mock_plan():
    """Plan payload that already has enough context (planner should finish)."""
    return {
        "has_enough_context": True,
        "title": "Test Plan",
        "thought": "Test Thought",
        "steps": [],
        "locale": "en-US",
    }


@pytest.fixture
def mock_state_planner():
    """Planner-node state on its first iteration with background results present."""
    return {
        "messages": [HumanMessage(content="plan this")],
        "plan_iterations": 0,
        "enable_background_investigation": True,
        "background_investigation_results": "Background info",
    }


@pytest.fixture
def mock_configurable_planner():
    """Configuration stub for the planner node."""
    mock = MagicMock()
    mock.max_plan_iterations = 3
    mock.enable_deep_thinking = False
    return mock


@pytest.fixture
def patch_config_from_runnable_config_planner(mock_configurable_planner):
    """Patch Configuration.from_runnable_config with the planner stub."""
    with patch(
        "src.graph.nodes.Configuration.from_runnable_config",
        return_value=mock_configurable_planner,
    ):
        yield


@pytest.fixture
def patch_apply_prompt_template():
    """Patch prompt templating to return a fixed one-message prompt."""
    with patch(
        "src.graph.nodes.apply_prompt_template",
        return_value=[{"role": "user", "content": "plan this"}],
    ) as mock:
        yield mock


@pytest.fixture
def patch_repair_json_output():
    """Patch JSON repair to the identity function."""
    with patch("src.graph.nodes.repair_json_output", side_effect=lambda x: x) as mock:
        yield mock


@pytest.fixture
def patch_plan_model_validate():
    """Patch Plan.model_validate to the identity function."""
    with patch("src.graph.nodes.Plan.model_validate", side_effect=lambda x: x) as mock:
        yield mock


@pytest.fixture
def patch_ai_message():
    """Patch AIMessage with a lightweight namedtuple stand-in."""
    AIMessage = namedtuple("AIMessage", ["content", "name"])
    with patch(
        "src.graph.nodes.AIMessage",
        side_effect=lambda content, name: AIMessage(content, name),
    ) as mock:
        yield mock
def test_planner_node_basic_has_enough_context(
    mock_state_planner,
    patch_config_from_runnable_config_planner,
    patch_apply_prompt_template,
    patch_repair_json_output,
    patch_plan_model_validate,
    patch_ai_message,
    mock_plan,
):
    """Basic (structured-output) planner with enough context goes to reporter."""
    # AGENT_LLM_MAP["planner"] == "basic" and not thinking mode
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "basic"}),
        patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
    ):
        mock_llm = MagicMock()
        # Basic mode chains .with_structured_output(); return the same mock.
        mock_llm.with_structured_output.return_value = mock_llm
        mock_response = MagicMock()
        mock_response.model_dump_json.return_value = json.dumps(mock_plan)
        mock_llm.invoke.return_value = mock_response
        mock_get_llm.return_value = mock_llm
        result = planner_node(mock_state_planner, MagicMock())
        assert isinstance(result, Command)
        assert result.goto == "reporter"
        assert "current_plan" in result.update
        assert result.update["current_plan"]["has_enough_context"] is True
        assert result.update["messages"][0].name == "planner"


def test_planner_node_basic_not_enough_context(
    mock_state_planner,
    patch_config_from_runnable_config_planner,
    patch_apply_prompt_template,
    patch_repair_json_output,
    patch_plan_model_validate,
    patch_ai_message,
):
    """Basic planner without enough context hands the plan to human feedback."""
    # AGENT_LLM_MAP["planner"] == "basic" and not thinking mode
    plan = {
        "has_enough_context": False,
        "title": "Test Plan",
        "thought": "Test Thought",
        "steps": [],
        "locale": "en-US",
    }
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "basic"}),
        patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
    ):
        mock_llm = MagicMock()
        mock_llm.with_structured_output.return_value = mock_llm
        mock_response = MagicMock()
        mock_response.model_dump_json.return_value = json.dumps(plan)
        mock_llm.invoke.return_value = mock_response
        mock_get_llm.return_value = mock_llm
        result = planner_node(mock_state_planner, MagicMock())
        assert isinstance(result, Command)
        assert result.goto == "human_feedback"
        assert "current_plan" in result.update
        # The unaccepted plan stays a JSON string for human review.
        assert isinstance(result.update["current_plan"], str)
        assert result.update["messages"][0].name == "planner"
def test_planner_node_stream_mode_has_enough_context(
    mock_state_planner,
    patch_config_from_runnable_config_planner,
    patch_apply_prompt_template,
    patch_repair_json_output,
    patch_plan_model_validate,
    patch_ai_message,
    mock_plan,
):
    """Non-basic planner streams chunks; enough context goes to reporter."""
    # AGENT_LLM_MAP["planner"] != "basic"
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "other"}),
        patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
    ):
        mock_llm = MagicMock()
        # Simulate streaming: one chunk carrying the whole plan as JSON.
        chunk = MagicMock()
        chunk.content = json.dumps(mock_plan)
        mock_llm.stream.return_value = [chunk]
        mock_get_llm.return_value = mock_llm
        result = planner_node(mock_state_planner, MagicMock())
        assert isinstance(result, Command)
        assert result.goto == "reporter"
        assert "current_plan" in result.update
        assert result.update["current_plan"]["has_enough_context"] is True


def test_planner_node_stream_mode_not_enough_context(
    mock_state_planner,
    patch_config_from_runnable_config_planner,
    patch_apply_prompt_template,
    patch_repair_json_output,
    patch_plan_model_validate,
    patch_ai_message,
):
    """Non-basic planner without enough context goes to human feedback."""
    # AGENT_LLM_MAP["planner"] != "basic"
    plan = {
        "has_enough_context": False,
        "title": "Test Plan",
        "thought": "Test Thought",
        "steps": [],
        "locale": "en-US",
    }
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "other"}),
        patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
    ):
        mock_llm = MagicMock()
        chunk = MagicMock()
        chunk.content = json.dumps(plan)
        mock_llm.stream.return_value = [chunk]
        mock_get_llm.return_value = mock_llm
        result = planner_node(mock_state_planner, MagicMock())
        assert isinstance(result, Command)
        assert result.goto == "human_feedback"
        assert "current_plan" in result.update
        assert isinstance(result.update["current_plan"], str)
def test_planner_node_plan_iterations_exceeded(mock_state_planner):
    """Once plan_iterations >= max_plan_iterations, go straight to reporter."""
    state = dict(mock_state_planner)
    state["plan_iterations"] = 5  # exceeds the stub's max of 3
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "basic"}),
        patch("src.graph.nodes.get_llm_by_type", return_value=MagicMock()),
    ):
        result = planner_node(state, MagicMock())
        assert isinstance(result, Command)
        assert result.goto == "reporter"


def test_planner_node_json_decode_error_first_iteration(mock_state_planner):
    """Unparseable plan JSON on the first iteration ends the workflow."""
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "basic"}),
        patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
        patch(
            "src.graph.nodes.json.loads",
            side_effect=json.JSONDecodeError("err", "doc", 0),
        ),
    ):
        mock_llm = MagicMock()
        mock_llm.with_structured_output.return_value = mock_llm
        mock_response = MagicMock()
        mock_response.model_dump_json.return_value = '{"bad": "json"'
        mock_llm.invoke.return_value = mock_response
        mock_get_llm.return_value = mock_llm
        result = planner_node(mock_state_planner, MagicMock())
        assert isinstance(result, Command)
        assert result.goto == "__end__"


def test_planner_node_json_decode_error_second_iteration(mock_state_planner):
    """Unparseable plan JSON after the first iteration falls back to reporter."""
    state = dict(mock_state_planner)
    state["plan_iterations"] = 1
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "basic"}),
        patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
        patch(
            "src.graph.nodes.json.loads",
            side_effect=json.JSONDecodeError("err", "doc", 0),
        ),
    ):
        mock_llm = MagicMock()
        mock_llm.with_structured_output.return_value = mock_llm
        mock_response = MagicMock()
        mock_response.model_dump_json.return_value = '{"bad": "json"'
        mock_llm.invoke.return_value = mock_response
        mock_get_llm.return_value = mock_llm
        result = planner_node(state, MagicMock())
        assert isinstance(result, Command)
        assert result.goto == "reporter"
# Patch Plan.model_validate and repair_json_output globally for these tests
@pytest.fixture(autouse=True)
def patch_plan_and_repair(monkeypatch):
    # NOTE(review): autouse applies module-wide, so this identity patching also
    # affects the planner tests above — confirm that is intended.
    monkeypatch.setattr("src.graph.nodes.Plan.model_validate", lambda x: x)
    monkeypatch.setattr("src.graph.nodes.repair_json_output", lambda x: x)
    yield


@pytest.fixture
def mock_state_base():
    """human_feedback_node state: serialized plan lacking context, iteration 0."""
    return {
        "current_plan": json.dumps(
            {
                "has_enough_context": False,
                "title": "Test Plan",
                "thought": "Test Thought",
                "steps": [],
                "locale": "en-US",
            }
        ),
        "plan_iterations": 0,
    }
def test_human_feedback_node_auto_accepted(monkeypatch, mock_state_base):
    # auto_accepted_plan True: skip the interrupt and parse the plan directly
    state = dict(mock_state_base)
    state["auto_accepted_plan"] = True
    result = human_feedback_node(state)
    assert isinstance(result, Command)
    assert result.goto == "research_team"
    assert result.update["plan_iterations"] == 1
    assert result.update["current_plan"]["has_enough_context"] is False


def test_human_feedback_node_edit_plan(monkeypatch, mock_state_base):
    # interrupt returns [EDIT_PLAN]...: route back to the planner with feedback
    state = dict(mock_state_base)
    state["auto_accepted_plan"] = False
    with patch("src.graph.nodes.interrupt", return_value="[EDIT_PLAN] Please revise"):
        result = human_feedback_node(state)
        assert isinstance(result, Command)
        assert result.goto == "planner"
        assert result.update["messages"][0].name == "feedback"
        assert "[EDIT_PLAN]" in result.update["messages"][0].content


def test_human_feedback_node_accepted(monkeypatch, mock_state_base):
    # interrupt returns [ACCEPTED]...: proceed to parse the plan
    state = dict(mock_state_base)
    state["auto_accepted_plan"] = False
    with patch("src.graph.nodes.interrupt", return_value="[ACCEPTED] Looks good!"):
        result = human_feedback_node(state)
        assert isinstance(result, Command)
        assert result.goto == "research_team"
        assert result.update["plan_iterations"] == 1
        assert result.update["current_plan"]["has_enough_context"] is False


def test_human_feedback_node_invalid_interrupt(monkeypatch, mock_state_base):
    # interrupt returns anything else: the node raises TypeError
    state = dict(mock_state_base)
    state["auto_accepted_plan"] = False
    with patch("src.graph.nodes.interrupt", return_value="RANDOM_FEEDBACK"):
        with pytest.raises(TypeError):
            human_feedback_node(state)
def test_human_feedback_node_json_decode_error_first_iteration(
    monkeypatch, mock_state_base
):
    # json.loads raises JSONDecodeError with plan_iterations == 0: end workflow
    state = dict(mock_state_base)
    state["auto_accepted_plan"] = True
    state["plan_iterations"] = 0
    with patch(
        "src.graph.nodes.json.loads", side_effect=json.JSONDecodeError("err", "doc", 0)
    ):
        result = human_feedback_node(state)
        assert isinstance(result, Command)
        assert result.goto == "__end__"


def test_human_feedback_node_json_decode_error_second_iteration(
    monkeypatch, mock_state_base
):
    # json.loads raises JSONDecodeError with plan_iterations > 0: go to reporter
    state = dict(mock_state_base)
    state["auto_accepted_plan"] = True
    state["plan_iterations"] = 2
    with patch(
        "src.graph.nodes.json.loads", side_effect=json.JSONDecodeError("err", "doc", 0)
    ):
        result = human_feedback_node(state)
        assert isinstance(result, Command)
        assert result.goto == "reporter"


def test_human_feedback_node_not_enough_context(monkeypatch, mock_state_base):
    # Plan without enough context should route to research_team.
    # NOTE(review): mock_state_base already carries this exact plan, so this
    # duplicates test_human_feedback_node_auto_accepted — confirm it is wanted.
    plan = {
        "has_enough_context": False,
        "title": "Test Plan",
        "thought": "Test Thought",
        "steps": [],
        "locale": "en-US",
    }
    state = dict(mock_state_base)
    state["current_plan"] = json.dumps(plan)
    state["auto_accepted_plan"] = True
    result = human_feedback_node(state)
    assert isinstance(result, Command)
    assert result.goto == "research_team"
    assert result.update["plan_iterations"] == 1
    assert result.update["current_plan"]["has_enough_context"] is False
@pytest.fixture
def mock_state_coordinator():
    """Coordinator-node state: one user message, en-US locale."""
    return {
        "messages": [{"role": "user", "content": "test"}],
        "locale": "en-US",
    }


@pytest.fixture
def mock_configurable_coordinator():
    """Configuration stub exposing two named resources."""
    mock = MagicMock()
    mock.resources = ["resource1", "resource2"]
    return mock


@pytest.fixture
def patch_config_from_runnable_config_coordinator(mock_configurable_coordinator):
    """Patch Configuration.from_runnable_config with the coordinator stub."""
    with patch(
        "src.graph.nodes.Configuration.from_runnable_config",
        return_value=mock_configurable_coordinator,
    ):
        yield


@pytest.fixture
def patch_apply_prompt_template_coordinator():
    """Patch prompt templating to return a fixed one-message prompt."""
    with patch(
        "src.graph.nodes.apply_prompt_template",
        return_value=[{"role": "user", "content": "test"}],
    ) as mock:
        yield mock


@pytest.fixture
def patch_handoff_to_planner():
    """Replace the handoff_to_planner tool with a bare MagicMock."""
    with patch("src.graph.nodes.handoff_to_planner", MagicMock()):
        yield


@pytest.fixture
def patch_logger():
    """Patch the module logger so tests can assert on log calls."""
    with patch("src.graph.nodes.logger") as mock_logger:
        yield mock_logger
def make_mock_llm_response(tool_calls=None):
    """Build a fake LLM response whose tool_calls is the given list (or [])."""
    response = MagicMock()
    response.tool_calls = [] if not tool_calls else tool_calls
    return response
def test_coordinator_node_no_tool_calls(
    mock_state_coordinator,
    patch_config_from_runnable_config_coordinator,
    patch_apply_prompt_template_coordinator,
    patch_handoff_to_planner,
    patch_logger,
):
    # No tool calls: the coordinator ends the workflow.
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"coordinator": "basic"}),
        patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
    ):
        mock_llm = MagicMock()
        mock_llm.bind_tools.return_value = mock_llm
        mock_llm.invoke.return_value = make_mock_llm_response([])
        mock_get_llm.return_value = mock_llm
        result = coordinator_node(mock_state_coordinator, MagicMock())
        assert result.goto == "__end__"
        assert result.update["locale"] == "en-US"
        assert result.update["resources"] == ["resource1", "resource2"]


def test_coordinator_node_with_tool_calls_planner(
    mock_state_coordinator,
    patch_config_from_runnable_config_coordinator,
    patch_apply_prompt_template_coordinator,
    patch_handoff_to_planner,
    patch_logger,
):
    # tool_calls present (no background investigation flag): goto planner.
    tool_calls = [{"name": "handoff_to_planner", "args": {}}]
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"coordinator": "basic"}),
        patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
    ):
        mock_llm = MagicMock()
        mock_llm.bind_tools.return_value = mock_llm
        mock_llm.invoke.return_value = make_mock_llm_response(tool_calls)
        mock_get_llm.return_value = mock_llm
        result = coordinator_node(mock_state_coordinator, MagicMock())
        assert result.goto == "planner"
        assert result.update["locale"] == "en-US"
        assert result.update["resources"] == ["resource1", "resource2"]


def test_coordinator_node_with_tool_calls_background_investigator(
    mock_state_coordinator,
    patch_config_from_runnable_config_coordinator,
    patch_apply_prompt_template_coordinator,
    patch_handoff_to_planner,
    patch_logger,
):
    # enable_background_investigation True: goto background_investigator instead.
    state = dict(mock_state_coordinator)
    state["enable_background_investigation"] = True
    tool_calls = [{"name": "handoff_to_planner", "args": {}}]
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"coordinator": "basic"}),
        patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
    ):
        mock_llm = MagicMock()
        mock_llm.bind_tools.return_value = mock_llm
        mock_llm.invoke.return_value = make_mock_llm_response(tool_calls)
        mock_get_llm.return_value = mock_llm
        result = coordinator_node(state, MagicMock())
        assert result.goto == "background_investigator"
        assert result.update["locale"] == "en-US"
        assert result.update["resources"] == ["resource1", "resource2"]
def test_coordinator_node_with_tool_calls_locale_override(
    mock_state_coordinator,
    patch_config_from_runnable_config_coordinator,
    patch_apply_prompt_template_coordinator,
    patch_handoff_to_planner,
    patch_logger,
):
    """Tool-call args carrying locale/research_topic override the state defaults."""
    tool_calls = [
        {
            "name": "handoff_to_planner",
            "args": {"locale": "zh-CN", "research_topic": "test topic"},
        }
    ]
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"coordinator": "basic"}),
        patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
    ):
        mock_llm = MagicMock()
        mock_llm.bind_tools.return_value = mock_llm
        mock_llm.invoke.return_value = make_mock_llm_response(tool_calls)
        mock_get_llm.return_value = mock_llm
        result = coordinator_node(mock_state_coordinator, MagicMock())
        assert result.goto == "planner"
        # Locale and topic come from the tool-call args, not the state.
        assert result.update["locale"] == "zh-CN"
        assert result.update["research_topic"] == "test topic"
        # Fix: this assertion was accidentally duplicated in the original.
        assert result.update["resources"] == ["resource1", "resource2"]
def test_coordinator_node_tool_calls_exception_handling(
    mock_state_coordinator,
    patch_config_from_runnable_config_coordinator,
    patch_apply_prompt_template_coordinator,
    patch_handoff_to_planner,
    patch_logger,
):
    """A tool call whose args lookup raises must be logged, not propagated."""
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"coordinator": "basic"}),
        patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
    ):
        mock_llm = MagicMock()
        mock_llm.bind_tools.return_value = mock_llm

        # Simulate tool_call.get("args", {}) raising an exception.
        # NOTE(review): a previous comment said AttributeError, but this raises
        # a plain Exception("bad args").
        class BadToolCall(dict):
            def get(self, key, default=None):
                if key == "args":
                    raise Exception("bad args")
                return super().get(key, default)

        mock_llm.invoke.return_value = make_mock_llm_response(
            [BadToolCall({"name": "handoff_to_planner"})]
        )
        mock_get_llm.return_value = mock_llm
        # Should not raise; the node logs the error and continues.
        result = coordinator_node(mock_state_coordinator, MagicMock())
        assert result.goto == "planner"
        assert result.update["locale"] == "en-US"
        assert result.update["resources"] == ["resource1", "resource2"]
@pytest.fixture
def mock_state_reporter():
    # Simulate a plan object with title and thought attributes
    Plan = namedtuple("Plan", ["title", "thought"])
    return {
        "current_plan": Plan(title="Test Title", thought="Test Thought"),
        "locale": "en-US",
        "observations": [],
    }


@pytest.fixture
def mock_state_reporter_with_observations():
    """Reporter state that additionally carries two prior observations."""
    Plan = namedtuple("Plan", ["title", "thought"])
    return {
        "current_plan": Plan(title="Test Title", thought="Test Thought"),
        "locale": "en-US",
        "observations": ["Observation 1", "Observation 2"],
    }


@pytest.fixture
def mock_configurable_reporter():
    """Bare configuration stub; the reporter reads no specific fields here."""
    mock = MagicMock()
    return mock


@pytest.fixture
def patch_config_from_runnable_config_reporter(mock_configurable_reporter):
    """Patch Configuration.from_runnable_config with the reporter stub."""
    with patch(
        "src.graph.nodes.Configuration.from_runnable_config",
        return_value=mock_configurable_reporter,
    ):
        yield


@pytest.fixture
def patch_apply_prompt_template_reporter():
    """Patch prompt templating to return one opaque message per call."""
    with patch(
        "src.graph.nodes.apply_prompt_template",
        side_effect=lambda *args, **kwargs: [MagicMock()],
    ) as mock:
        yield mock


@pytest.fixture
def patch_human_message():
    """Replace HumanMessage with a MagicMock factory."""
    HumanMessage = MagicMock()
    with patch("src.graph.nodes.HumanMessage", HumanMessage):
        yield HumanMessage


@pytest.fixture
def patch_logger_reporter():
    """Patch the module logger for the reporter tests."""
    with patch("src.graph.nodes.logger") as mock_logger:
        yield mock_logger
def make_mock_llm_response_reporter(content):
    """Build a fake LLM response whose .content is the given text."""
    response = MagicMock()
    response.content = content
    return response
def test_reporter_node_basic(
    mock_state_reporter,
    patch_config_from_runnable_config_reporter,
    patch_apply_prompt_template_reporter,
    patch_human_message,
    patch_logger_reporter,
):
    """Reporter returns the LLM's content as final_report."""
    # Patch get_llm_by_type and AGENT_LLM_MAP
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"reporter": "basic"}),
        patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
    ):
        mock_llm = MagicMock()
        mock_llm.invoke.return_value = make_mock_llm_response_reporter(
            "Final Report Content"
        )
        mock_get_llm.return_value = mock_llm
        result = reporter_node(mock_state_reporter, MagicMock())
        assert isinstance(result, dict)
        assert "final_report" in result
        assert result["final_report"] == "Final Report Content"
        # Should call apply_prompt_template and invoke the LLM once each path.
        patch_apply_prompt_template_reporter.assert_called()
        mock_llm.invoke.assert_called()


def test_reporter_node_with_observations(
    mock_state_reporter_with_observations,
    patch_config_from_runnable_config_reporter,
    patch_apply_prompt_template_reporter,
    patch_human_message,
    patch_logger_reporter,
):
    """Observations in state do not change the final_report plumbing."""
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"reporter": "basic"}),
        patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
    ):
        mock_llm = MagicMock()
        mock_llm.invoke.return_value = make_mock_llm_response_reporter(
            "Report with Observations"
        )
        mock_get_llm.return_value = mock_llm
        result = reporter_node(mock_state_reporter_with_observations, MagicMock())
        assert isinstance(result, dict)
        assert "final_report" in result
        assert result["final_report"] == "Report with Observations"
        patch_apply_prompt_template_reporter.assert_called()
        mock_llm.invoke.assert_called()


def test_reporter_node_locale_default(
    patch_config_from_runnable_config_reporter,
    patch_apply_prompt_template_reporter,
    patch_human_message,
    patch_logger_reporter,
):
    # If locale is missing, the node should default to "en-US".
    Plan = namedtuple("Plan", ["title", "thought"])
    state = {
        "current_plan": Plan(title="Test Title", thought="Test Thought"),
        # "locale" omitted on purpose
        "observations": [],
    }
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"reporter": "basic"}),
        patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
    ):
        mock_llm = MagicMock()
        mock_llm.invoke.return_value = make_mock_llm_response_reporter(
            "Default Locale Report"
        )
        mock_get_llm.return_value = mock_llm
        result = reporter_node(state, MagicMock())
        assert isinstance(result, dict)
        assert "final_report" in result
        assert result["final_report"] == "Default Locale Report"
# Create the real Step class for the tests
class Step:
    """Minimal stand-in for the project's plan Step model.

    Attributes:
        title: Short step title.
        description: What the step should accomplish.
        execution_res: Result text once executed; None while still pending.
    """

    def __init__(self, title, description, execution_res=None):
        self.title = title
        self.description = description
        self.execution_res = execution_res

    def __repr__(self):
        # Improvement: a real repr makes pytest assertion diffs readable
        # (the default <Step object at 0x...> hides which step failed).
        return (
            f"{type(self).__name__}(title={self.title!r}, "
            f"description={self.description!r}, "
            f"execution_res={self.execution_res!r})"
        )
@pytest.fixture
def mock_step():
    """A single not-yet-executed step."""
    return Step(title="Step 1", description="Desc 1", execution_res=None)


@pytest.fixture
def mock_completed_step():
    """A step that has already been executed."""
    return Step(title="Step 0", description="Desc 0", execution_res="Done")


@pytest.fixture
def mock_state_with_steps(mock_step, mock_completed_step):
    # Simulate a plan with one completed and one unexecuted step
    Plan = MagicMock()
    Plan.steps = [mock_completed_step, mock_step]
    return {
        "current_plan": Plan,
        "observations": ["obs1"],
        "locale": "en-US",
        "resources": [],
    }


@pytest.fixture
def mock_state_no_unexecuted():
    # NOTE(review): this local namedtuple shadows the module-level Step class;
    # it is immutable, which suits a plan where every step is already done.
    Step = namedtuple("Step", ["title", "description", "execution_res"])
    Plan = MagicMock()
    Plan.steps = [
        Step(title="Step 1", description="Desc 1", execution_res="done"),
        Step(title="Step 2", description="Desc 2", execution_res="done"),
    ]
    return {
        "current_plan": Plan,
        "observations": [],
        "locale": "en-US",
        "resources": [],
    }


@pytest.fixture
def mock_agent():
    """Agent stub whose async ainvoke returns one message with fixed content."""
    agent = MagicMock()

    async def ainvoke(input, config):
        # Simulate agent returning a message list
        return {"messages": [MagicMock(content="result content")]}

    agent.ainvoke = ainvoke
    return agent
@pytest.mark.asyncio
async def test_execute_agent_step_basic(mock_state_with_steps, mock_agent):
    # Should execute the first unexecuted step and update its execution_res.
    # NOTE(review): MagicMock(content=..., name=...) treats name= specially
    # (it names the mock), which is fine for these content-only assertions.
    with patch(
        "src.graph.nodes.HumanMessage",
        side_effect=lambda content, name=None: MagicMock(content=content, name=name),
    ):
        result = await _execute_agent_step(
            mock_state_with_steps, mock_agent, "researcher"
        )
        assert isinstance(result, Command)
        assert result.goto == "research_team"
        assert "messages" in result.update
        assert "observations" in result.update
        # The new observation should be appended
        assert result.update["observations"][-1] == "result content"
        # The step's execution_res should be updated
        assert (
            mock_state_with_steps["current_plan"].steps[1].execution_res
            == "result content"
        )


@pytest.mark.asyncio
async def test_execute_agent_step_no_unexecuted_step(
    mock_state_no_unexecuted, mock_agent
):
    # Should return Command with goto="research_team" and log a warning.
    with patch("src.graph.nodes.logger") as mock_logger:
        result = await _execute_agent_step(
            mock_state_no_unexecuted, mock_agent, "researcher"
        )
        assert isinstance(result, Command)
        assert result.goto == "research_team"
        mock_logger.warning.assert_called_with("No unexecuted step found")


@pytest.mark.asyncio
async def test_execute_agent_step_with_resources_and_researcher(mock_step):
    # Should add resource info and citation reminder for the researcher agent.
    Resource = namedtuple("Resource", ["title", "description"])
    resources = [Resource(title="file1.txt", description="desc1")]
    Plan = MagicMock()
    Plan.steps = [mock_step]
    state = {
        "current_plan": Plan,
        "observations": [],
        "locale": "en-US",
        "resources": resources,
    }
    agent = MagicMock()

    async def ainvoke(input, config):
        # The assertions run inside the agent call: resource info and the
        # citation reminder must be present in the prompt messages.
        messages = input["messages"]
        assert any("local_search_tool" in m.content for m in messages)
        assert any("DO NOT include inline citations" in m.content for m in messages)
        return {"messages": [MagicMock(content="resource result")]}

    agent.ainvoke = ainvoke
    with patch(
        "src.graph.nodes.HumanMessage",
        side_effect=lambda content, name=None: MagicMock(content=content, name=name),
    ):
        result = await _execute_agent_step(state, agent, "researcher")
        assert isinstance(result, Command)
        assert result.goto == "research_team"
        assert result.update["observations"][-1] == "resource result"
@pytest.mark.asyncio
async def test_execute_agent_step_recursion_limit_env(
    monkeypatch, mock_state_with_steps, mock_agent
):
    # Should respect AGENT_RECURSION_LIMIT env variable if set and valid.
    monkeypatch.setenv("AGENT_RECURSION_LIMIT", "42")
    with (
        patch("src.graph.nodes.logger") as mock_logger,
        patch(
            "src.graph.nodes.HumanMessage",
            side_effect=lambda content, name=None: MagicMock(
                content=content, name=name
            ),
        ),
    ):
        result = await _execute_agent_step(mock_state_with_steps, mock_agent, "coder")
        assert isinstance(result, Command)
        mock_logger.info.assert_any_call("Recursion limit set to: 42")


@pytest.mark.asyncio
async def test_execute_agent_step_recursion_limit_env_invalid(
    monkeypatch, mock_state_with_steps, mock_agent
):
    # Should fall back to the default when the env variable is not an int.
    monkeypatch.setenv("AGENT_RECURSION_LIMIT", "notanint")
    with (
        patch("src.graph.nodes.logger") as mock_logger,
        patch(
            "src.graph.nodes.HumanMessage",
            side_effect=lambda content, name=None: MagicMock(
                content=content, name=name
            ),
        ),
    ):
        result = await _execute_agent_step(mock_state_with_steps, mock_agent, "coder")
        assert isinstance(result, Command)
        mock_logger.warning.assert_any_call(
            "Invalid AGENT_RECURSION_LIMIT value: 'notanint'. Using default value 25."
        )


@pytest.mark.asyncio
async def test_execute_agent_step_recursion_limit_env_negative(
    monkeypatch, mock_state_with_steps, mock_agent
):
    # Should fall back to the default when the env variable is negative or zero.
    monkeypatch.setenv("AGENT_RECURSION_LIMIT", "-5")
    with (
        patch("src.graph.nodes.logger") as mock_logger,
        patch(
            "src.graph.nodes.HumanMessage",
            side_effect=lambda content, name=None: MagicMock(
                content=content, name=name
            ),
        ),
    ):
        result = await _execute_agent_step(mock_state_with_steps, mock_agent, "coder")
        assert isinstance(result, Command)
        mock_logger.warning.assert_any_call(
            "AGENT_RECURSION_LIMIT value '-5' (parsed as -5) is not positive. Using default value 25."
        )
@pytest.fixture
def mock_configurable_with_mcp():
    """Configuration stub whose MCP settings enable toolA/toolB for researcher."""
    mock = MagicMock()
    mock.mcp_settings = {
        "servers": {
            "server1": {
                "enabled_tools": ["toolA", "toolB"],
                "add_to_agents": ["researcher"],
                "transport": "http",
                "command": "run",
                "args": {},
                "url": "http://localhost",
                "env": {},
                # Extra key, presumably ignored by the node — TODO confirm.
                "other": "ignore",
            }
        }
    }
    return mock


@pytest.fixture
def mock_configurable_without_mcp():
    """Configuration stub with MCP disabled entirely."""
    mock = MagicMock()
    mock.mcp_settings = None
    return mock


@pytest.fixture
def patch_config_from_runnable_config_with_mcp(mock_configurable_with_mcp):
    """Patch Configuration.from_runnable_config with the MCP-enabled stub."""
    with patch(
        "src.graph.nodes.Configuration.from_runnable_config",
        return_value=mock_configurable_with_mcp,
    ):
        yield


@pytest.fixture
def patch_config_from_runnable_config_without_mcp(mock_configurable_without_mcp):
    """Patch Configuration.from_runnable_config with the MCP-less stub."""
    with patch(
        "src.graph.nodes.Configuration.from_runnable_config",
        return_value=mock_configurable_without_mcp,
    ):
        yield


@pytest.fixture
def patch_create_agent():
    """Patch create_agent so tests can inspect its call arguments."""
    with patch("src.graph.nodes.create_agent") as mock:
        yield mock


@pytest.fixture
def patch_execute_agent_step():
    """Patch _execute_agent_step with an async stub returning 'EXECUTED'."""

    async def fake_execute_agent_step(state, agent, agent_type):
        return "EXECUTED"

    with patch(
        "src.graph.nodes._execute_agent_step", side_effect=fake_execute_agent_step
    ) as mock:
        yield mock


@pytest.fixture
def patch_multiserver_mcp_client():
    # Patch MultiServerMCPClient as an async context manager.

    class FakeTool:
        def __init__(self, name, description="desc"):
            self.name = name
            self.description = description

    class FakeClient:
        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            pass

        def get_tools(self):
            # toolC is outside server1's enabled_tools list.
            return [
                FakeTool("toolA", "descA"),
                FakeTool("toolB", "descB"),
                FakeTool("toolC", "descC"),
            ]

    with patch(
        "src.graph.nodes.MultiServerMCPClient", return_value=FakeClient()
    ) as mock:
        yield mock
@pytest.mark.asyncio
async def test_setup_and_execute_agent_step_with_mcp(
mock_state_with_steps,
mock_config,
patch_config_from_runnable_config_with_mcp,
patch_create_agent,
patch_execute_agent_step,
patch_multiserver_mcp_client,
):
# Should use MCP client, load tools, and call create_agent with correct tools
default_tools = [MagicMock(name="default_tool")]
agent_type = "researcher"
result = await _setup_and_execute_agent_step(
mock_state_with_steps,
mock_config,
agent_type,
default_tools,
)
# Should call create_agent with loaded_tools including toolA and toolB
args, kwargs = patch_create_agent.call_args
loaded_tools = args[2]
tool_names = [t.name for t in loaded_tools if hasattr(t, "name")]
assert "toolA" in tool_names
assert "toolB" in tool_names
# Should call _execute_agent_step
patch_execute_agent_step.assert_called_once()
assert result == "EXECUTED"
@pytest.mark.asyncio
async def test_setup_and_execute_agent_step_without_mcp(
    mock_state_with_steps,
    mock_config,
    patch_config_from_runnable_config_without_mcp,
    patch_create_agent,
    patch_execute_agent_step,
):
    """Without MCP settings the default tools are passed straight through."""
    fallback_tools = [MagicMock(name="default_tool")]
    result = await _setup_and_execute_agent_step(
        mock_state_with_steps,
        mock_config,
        "coder",
        fallback_tools,
    )
    # create_agent must be handed the unmodified default tool list.
    positional, _ = patch_create_agent.call_args
    assert positional[2] == fallback_tools
    patch_execute_agent_step.assert_called_once()
    assert result == "EXECUTED"
@pytest.mark.asyncio
async def test_setup_and_execute_agent_step_with_mcp_no_enabled_tools(
    mock_state_with_steps,
    mock_config,
    patch_create_agent,
    patch_execute_agent_step,
):
    """MCP servers scoped to other agents must not change this agent's tools."""
    configurable = MagicMock()
    configurable.mcp_settings = {
        "servers": {
            "server1": {
                "enabled_tools": ["toolA"],
                "add_to_agents": ["other_agent"],  # deliberately not "researcher"
                "transport": "http",
                "command": "run",
                "args": {},
                "url": "http://localhost",
                "env": {},
            }
        }
    }
    fallback_tools = [MagicMock(name="default_tool")]
    with patch(
        "src.graph.nodes.Configuration.from_runnable_config",
        return_value=configurable,
    ):
        result = await _setup_and_execute_agent_step(
            mock_state_with_steps,
            mock_config,
            "researcher",
            fallback_tools,
        )
    # No MCP server targets "researcher", so the defaults must be used as-is.
    positional, _ = patch_create_agent.call_args
    assert positional[2] == fallback_tools
    patch_execute_agent_step.assert_called_once()
    assert result == "EXECUTED"
@pytest.mark.asyncio
async def test_setup_and_execute_agent_step_with_mcp_tools_description_update(
    mock_state_with_steps,
    mock_config,
    patch_config_from_runnable_config_with_mcp,
    patch_create_agent,
    patch_execute_agent_step,
):
    """Loaded MCP tool descriptions get prefixed with 'Powered by' info."""

    class _StubTool:
        def __init__(self, name, description="desc"):
            self.name = name
            self.description = description

    class _StubClient:
        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            pass

        def get_tools(self):
            return [_StubTool("toolA", "descA")]

    with patch("src.graph.nodes.MultiServerMCPClient", return_value=_StubClient()):
        await _setup_and_execute_agent_step(
            mock_state_with_steps,
            mock_config,
            "researcher",
            [MagicMock(name="default_tool")],
        )
    # Pull out every loaded tool named "toolA" and verify its new description.
    positional, _ = patch_create_agent.call_args
    matched = [
        tool for tool in positional[2] if getattr(tool, "name", None) == "toolA"
    ]
    assert matched
    for tool in matched:
        assert tool.description.startswith("Powered by 'server1'.\n")
@pytest.fixture
def mock_state_with_resources():
    """State dict carrying two resources plus an unrelated key."""
    return {"resources": ["resource1", "resource2"], "other": "value"}
@pytest.fixture
def mock_state_without_resources():
    """State dict with no "resources" key at all."""
    return {"other": "value"}
@pytest.fixture
def patch_get_web_search_tool():
    """Patch get_web_search_tool so it hands back a canned MagicMock tool."""
    with patch("src.graph.nodes.get_web_search_tool") as mocked:
        mocked.return_value = MagicMock(name="web_search_tool")
        yield mocked
@pytest.fixture
def patch_crawl_tool():
    """Swap the module-level crawl_tool for a MagicMock during the test."""
    fake_crawl = MagicMock(name="crawl_tool")
    with patch("src.graph.nodes.crawl_tool", fake_crawl):
        yield
@pytest.fixture
def patch_get_retriever_tool():
    """Patch get_retriever_tool; each test sets .return_value as needed."""
    with patch("src.graph.nodes.get_retriever_tool") as mocked:
        yield mocked
@pytest.fixture
def patch_setup_and_execute_agent_step():
    """Stub _setup_and_execute_agent_step with an async fake result."""

    async def _fake(state, config, agent_type, tools):
        return "RESEARCHER_RESULT"

    patcher = patch(
        "src.graph.nodes._setup_and_execute_agent_step", side_effect=_fake
    )
    with patcher as mocked:
        yield mocked
@pytest.mark.asyncio
async def test_researcher_node_with_retriever_tool(
    mock_state_with_resources,
    mock_config,
    patch_config_from_runnable_config,
    patch_get_web_search_tool,
    patch_crawl_tool,
    patch_get_retriever_tool,
    patch_setup_and_execute_agent_step,
):
    """When a retriever tool exists it must lead the researcher's tool list."""
    fake_retriever = MagicMock(name="retriever_tool")
    patch_get_retriever_tool.return_value = fake_retriever
    outcome = await researcher_node(mock_state_with_resources, mock_config)
    # max_search_results comes from the patched Configuration fixture (7).
    patch_get_web_search_tool.assert_called_once_with(7)
    # The retriever lookup must receive the state's resources verbatim.
    patch_get_retriever_tool.assert_called_once_with(["resource1", "resource2"])
    positional, _ = patch_setup_and_execute_agent_step.call_args
    researcher_tools = positional[3]
    assert researcher_tools[0] == fake_retriever
    assert patch_get_web_search_tool.return_value in researcher_tools
    assert outcome == "RESEARCHER_RESULT"
@pytest.mark.asyncio
async def test_researcher_node_without_retriever_tool(
    mock_state_with_resources,
    mock_config,
    patch_config_from_runnable_config,
    patch_get_web_search_tool,
    patch_crawl_tool,
    patch_get_retriever_tool,
    patch_setup_and_execute_agent_step,
):
    """A None retriever tool must be left out of the researcher's tools."""
    patch_get_retriever_tool.return_value = None
    outcome = await researcher_node(mock_state_with_resources, mock_config)
    patch_get_web_search_tool.assert_called_once_with(7)
    patch_get_retriever_tool.assert_called_once_with(["resource1", "resource2"])
    positional, _ = patch_setup_and_execute_agent_step.call_args
    researcher_tools = positional[3]
    # No tool in the list may be the retriever.
    assert not any(
        getattr(tool, "name", None) == "retriever_tool" for tool in researcher_tools
    )
    assert patch_get_web_search_tool.return_value in researcher_tools
    assert outcome == "RESEARCHER_RESULT"
@pytest.mark.asyncio
async def test_researcher_node_without_resources(
    mock_state_without_resources,
    mock_config,
    patch_config_from_runnable_config,
    patch_get_web_search_tool,
    patch_crawl_tool,
    patch_get_retriever_tool,
    patch_setup_and_execute_agent_step,
):
    """With no resources in state, the retriever lookup gets an empty list."""
    patch_get_retriever_tool.return_value = None
    outcome = await researcher_node(mock_state_without_resources, mock_config)
    patch_get_web_search_tool.assert_called_once_with(7)
    patch_get_retriever_tool.assert_called_once_with([])
    positional, _ = patch_setup_and_execute_agent_step.call_args
    assert patch_get_web_search_tool.return_value in positional[3]
    assert outcome == "RESEARCHER_RESULT"