Files
deer-flow/tests/integration/test_nodes.py
zgjja 3b4e993531 feat: 1. replace black with ruff for formatting and sort import (#489)
2. use tavily from `langchain-tavily` rather than the older one from `langchain-community`

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2025-08-17 22:57:23 +08:00

1388 lines
44 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
from collections import namedtuple
from unittest.mock import MagicMock, patch
import pytest
from src.graph.nodes import (
_execute_agent_step,
_setup_and_execute_agent_step,
coordinator_node,
human_feedback_node,
planner_node,
reporter_node,
researcher_node,
)
# Patch get_llm_by_type here while importing, to avoid a ValueError
# (presumably raised when src modules build an LLM at import time — confirm).
# NOTE(review): src.graph.nodes is already imported unpatched at the top of
# this file, so this patch may no longer affect that module — verify whether
# it is still needed.
with patch("src.llms.llm.get_llm_by_type", return_value=MagicMock()):
    from langchain_core.messages import HumanMessage
    from langgraph.types import Command
    from src.config import SearchEngine
    from src.graph.nodes import background_investigation_node
# Canned search hits: two title/content records numbered 1 and 2.
MOCK_SEARCH_RESULTS = [
    {"title": f"Test Title {idx}", "content": f"Test Content {idx}"}
    for idx in (1, 2)
]
@pytest.fixture
def mock_state():
    """Minimal graph state: one user query and no background results yet."""
    query = "test query"
    state = {"messages": [HumanMessage(content=query)]}
    state["research_topic"] = query
    state["background_investigation_results"] = None
    return state
@pytest.fixture
def mock_configurable():
    """Configuration stand-in that caps search results at 7."""
    return MagicMock(max_search_results=7)
@pytest.fixture
def mock_config():
    """Opaque placeholder passed to nodes as their config argument.

    Return a MagicMock or a plain dict here, depending on what a test needs.
    """
    config = MagicMock()
    return config
@pytest.fixture
def patch_config_from_runnable_config(mock_configurable):
    """Make Configuration.from_runnable_config return ``mock_configurable``."""
    target = "src.graph.nodes.Configuration.from_runnable_config"
    with patch(target, return_value=mock_configurable):
        yield
@pytest.fixture
def mock_tavily_search():
    """Patch LoggedTavilySearch so its instances return the canned results.

    The payload is built from MOCK_SEARCH_RESULTS instead of a duplicated
    literal. Each test gets fresh dict copies, so code under test that mutates
    a record cannot leak the change into other tests.

    Yields:
        The patched ``LoggedTavilySearch`` class mock. Its ``return_value`` is
        the instance whose ``invoke`` returns the canned list.
    """
    with patch("src.graph.nodes.LoggedTavilySearch") as mock:
        mock.return_value.invoke.return_value = [
            dict(record) for record in MOCK_SEARCH_RESULTS
        ]
        yield mock
@pytest.fixture
def mock_web_search_tool():
    """Patch get_web_search_tool so the returned tool yields canned results.

    The payload is built from MOCK_SEARCH_RESULTS instead of a duplicated
    literal. Fresh dict copies keep tests isolated from each other.

    Yields:
        The patched ``get_web_search_tool`` mock. Its ``return_value.invoke``
        returns the canned list.
    """
    with patch("src.graph.nodes.get_web_search_tool") as mock:
        mock.return_value.invoke.return_value = [
            dict(record) for record in MOCK_SEARCH_RESULTS
        ]
        yield mock
@pytest.mark.parametrize("search_engine", [SearchEngine.TAVILY.value, "other"])
def test_background_investigation_node_tavily(
    mock_state,
    mock_tavily_search,
    mock_web_search_tool,
    search_engine,
    patch_config_from_runnable_config,
    mock_config,
):
    """Tavily hits come back as markdown; other engines come back JSON-encoded."""
    is_tavily = search_engine == SearchEngine.TAVILY.value
    with patch("src.graph.nodes.SELECTED_SEARCH_ENGINE", search_engine):
        output = background_investigation_node(mock_state, mock_config)

    # The node returns a plain dict update carrying the investigation text.
    assert isinstance(output, dict)
    assert "background_investigation_results" in output
    investigation = output["background_investigation_results"]

    if not is_tavily:
        mock_web_search_tool.return_value.invoke.assert_called_once_with(
            "test query"
        )
        assert len(json.loads(investigation)) == 2
        return

    mock_tavily_search.return_value.invoke.assert_called_once_with("test query")
    expected_markdown = "\n\n".join(
        f"## Test Title {i}\n\nTest Content {i}" for i in (1, 2)
    )
    assert investigation == expected_markdown
def test_background_investigation_node_malformed_response(
    mock_state, mock_tavily_search, patch_config_from_runnable_config, mock_config
):
    """A non-list Tavily payload yields a JSON ``null`` investigation result."""
    tavily_instance = mock_tavily_search.return_value
    with patch("src.graph.nodes.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value):
        # Tavily hands back a bare string instead of a list of hits.
        tavily_instance.invoke.return_value = "invalid response"
        output = background_investigation_node(mock_state, mock_config)

    assert isinstance(output, dict)
    assert "background_investigation_results" in output
    assert json.loads(output["background_investigation_results"]) is None
@pytest.fixture
def mock_plan():
    """Plan payload that already has enough context and no steps."""
    return dict(
        has_enough_context=True,
        title="Test Plan",
        thought="Test Thought",
        steps=[],
        locale="en-US",
    )
@pytest.fixture
def mock_state_planner():
    """Planner state on iteration 0 with background results available."""
    return dict(
        messages=[HumanMessage(content="plan this")],
        plan_iterations=0,
        enable_background_investigation=True,
        background_investigation_results="Background info",
    )
@pytest.fixture
def mock_configurable_planner():
    """Planner configuration: at most 3 plan iterations, deep thinking off."""
    return MagicMock(max_plan_iterations=3, enable_deep_thinking=False)
@pytest.fixture
def patch_config_from_runnable_config_planner(mock_configurable_planner):
    """Serve ``mock_configurable_planner`` from Configuration.from_runnable_config."""
    target = "src.graph.nodes.Configuration.from_runnable_config"
    with patch(target, return_value=mock_configurable_planner):
        yield
@pytest.fixture
def patch_apply_prompt_template():
    """Stub apply_prompt_template to return one 'plan this' user message."""
    prompt = [{"role": "user", "content": "plan this"}]
    with patch("src.graph.nodes.apply_prompt_template", return_value=prompt) as stub:
        yield stub
@pytest.fixture
def patch_repair_json_output():
    """Turn repair_json_output into an identity function (the input is returned as-is)."""

    def _identity(x):
        return x

    with patch("src.graph.nodes.repair_json_output", side_effect=_identity) as stub:
        yield stub
@pytest.fixture
def patch_plan_model_validate():
    """Turn Plan.model_validate into an identity function, so plans stay dicts."""

    def _identity(x):
        return x

    with patch("src.graph.nodes.Plan.model_validate", side_effect=_identity) as stub:
        yield stub
@pytest.fixture
def patch_ai_message():
    """Replace AIMessage with a lightweight (content, name) namedtuple."""
    message_type = namedtuple("AIMessage", ["content", "name"])

    def _build(content, name):
        return message_type(content, name)

    with patch("src.graph.nodes.AIMessage", side_effect=_build) as stub:
        yield stub
def test_planner_node_basic_has_enough_context(
    mock_state_planner,
    patch_config_from_runnable_config_planner,
    patch_apply_prompt_template,
    patch_repair_json_output,
    patch_plan_model_validate,
    patch_ai_message,
    mock_plan,
):
    """A basic (non-thinking) planner with enough context routes to the reporter."""
    llm = MagicMock()
    # Structured output is served by the same mock, so invoke() is shared.
    llm.with_structured_output.return_value = llm
    llm.invoke.return_value = MagicMock(
        **{"model_dump_json.return_value": json.dumps(mock_plan)}
    )
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "basic"}),
        patch("src.graph.nodes.get_llm_by_type", return_value=llm),
    ):
        outcome = planner_node(mock_state_planner, MagicMock())

    assert isinstance(outcome, Command)
    assert outcome.goto == "reporter"
    update = outcome.update
    assert "current_plan" in update
    assert update["current_plan"]["has_enough_context"] is True
    assert update["messages"][0].name == "planner"
def test_planner_node_basic_not_enough_context(
    mock_state_planner,
    patch_config_from_runnable_config_planner,
    patch_apply_prompt_template,
    patch_repair_json_output,
    patch_plan_model_validate,
    patch_ai_message,
):
    """A basic planner lacking context routes to human feedback."""
    plan = dict(
        has_enough_context=False,
        title="Test Plan",
        thought="Test Thought",
        steps=[],
        locale="en-US",
    )
    llm = MagicMock()
    llm.with_structured_output.return_value = llm
    llm.invoke.return_value = MagicMock(
        **{"model_dump_json.return_value": json.dumps(plan)}
    )
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "basic"}),
        patch("src.graph.nodes.get_llm_by_type", return_value=llm),
    ):
        outcome = planner_node(mock_state_planner, MagicMock())

    assert isinstance(outcome, Command)
    assert outcome.goto == "human_feedback"
    update = outcome.update
    assert "current_plan" in update
    # The plan is kept as a string rather than a parsed object.
    assert isinstance(update["current_plan"], str)
    assert update["messages"][0].name == "planner"
def test_planner_node_stream_mode_has_enough_context(
    mock_state_planner,
    patch_config_from_runnable_config_planner,
    patch_apply_prompt_template,
    patch_repair_json_output,
    patch_plan_model_validate,
    patch_ai_message,
    mock_plan,
):
    """A streaming (non-basic) planner with enough context routes to the reporter."""
    llm = MagicMock()
    # Non-"basic" planners stream; deliver the whole plan as one chunk.
    llm.stream.return_value = [MagicMock(content=json.dumps(mock_plan))]
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "other"}),
        patch("src.graph.nodes.get_llm_by_type", return_value=llm),
    ):
        outcome = planner_node(mock_state_planner, MagicMock())

    assert isinstance(outcome, Command)
    assert outcome.goto == "reporter"
    assert "current_plan" in outcome.update
    assert outcome.update["current_plan"]["has_enough_context"] is True
def test_planner_node_stream_mode_not_enough_context(
    mock_state_planner,
    patch_config_from_runnable_config_planner,
    patch_apply_prompt_template,
    patch_repair_json_output,
    patch_plan_model_validate,
    patch_ai_message,
):
    """A streaming planner lacking context routes to human feedback."""
    plan = dict(
        has_enough_context=False,
        title="Test Plan",
        thought="Test Thought",
        steps=[],
        locale="en-US",
    )
    llm = MagicMock()
    llm.stream.return_value = [MagicMock(content=json.dumps(plan))]
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "other"}),
        patch("src.graph.nodes.get_llm_by_type", return_value=llm),
    ):
        outcome = planner_node(mock_state_planner, MagicMock())

    assert isinstance(outcome, Command)
    assert outcome.goto == "human_feedback"
    assert "current_plan" in outcome.update
    assert isinstance(outcome.update["current_plan"], str)
def test_planner_node_plan_iterations_exceeded(mock_state_planner):
    """With plan_iterations at 5 the planner hands off to the reporter.

    Configuration is not patched here, so the limit compared against is
    presumably the default max_plan_iterations — confirm.
    """
    state = {**mock_state_planner, "plan_iterations": 5}
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "basic"}),
        patch("src.graph.nodes.get_llm_by_type", return_value=MagicMock()),
    ):
        outcome = planner_node(state, MagicMock())

    assert isinstance(outcome, Command)
    assert outcome.goto == "reporter"
def test_planner_node_json_decode_error_first_iteration(mock_state_planner):
    """An unparsable plan on the first iteration ends the graph."""
    decode_error = json.JSONDecodeError("err", "doc", 0)
    llm = MagicMock()
    llm.with_structured_output.return_value = llm
    llm.invoke.return_value = MagicMock(
        **{"model_dump_json.return_value": '{"bad": "json"'}
    )
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "basic"}),
        patch("src.graph.nodes.get_llm_by_type", return_value=llm),
        patch("src.graph.nodes.json.loads", side_effect=decode_error),
    ):
        outcome = planner_node(mock_state_planner, MagicMock())

    assert isinstance(outcome, Command)
    assert outcome.goto == "__end__"
def test_planner_node_json_decode_error_second_iteration(mock_state_planner):
    """An unparsable plan after the first iteration falls back to the reporter."""
    state = {**mock_state_planner, "plan_iterations": 1}
    decode_error = json.JSONDecodeError("err", "doc", 0)
    llm = MagicMock()
    llm.with_structured_output.return_value = llm
    llm.invoke.return_value = MagicMock(
        **{"model_dump_json.return_value": '{"bad": "json"'}
    )
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "basic"}),
        patch("src.graph.nodes.get_llm_by_type", return_value=llm),
        patch("src.graph.nodes.json.loads", side_effect=decode_error),
    ):
        outcome = planner_node(state, MagicMock())

    assert isinstance(outcome, Command)
    assert outcome.goto == "reporter"
# Identity patches for Plan.model_validate and repair_json_output.
# NOTE: this fixture is autouse at module level, so it applies to EVERY test
# in this module (the planner and background-investigation tests above
# included), not only to the human_feedback_node tests that follow.
@pytest.fixture(autouse=True)
def patch_plan_and_repair(monkeypatch):
    """Make Plan.model_validate and repair_json_output return their input.

    ``monkeypatch`` undoes both patches at teardown, so the fixture does not
    need to yield.
    """
    monkeypatch.setattr("src.graph.nodes.Plan.model_validate", lambda x: x)
    monkeypatch.setattr("src.graph.nodes.repair_json_output", lambda x: x)
@pytest.fixture
def mock_state_base():
    """State whose current_plan is a JSON string that lacks enough context."""
    plan = {
        "has_enough_context": False,
        "title": "Test Plan",
        "thought": "Test Thought",
        "steps": [],
        "locale": "en-US",
    }
    return {"current_plan": json.dumps(plan), "plan_iterations": 0}
def test_human_feedback_node_auto_accepted(monkeypatch, mock_state_base):
    """auto_accepted_plan skips the interrupt and parses the plan directly."""
    state = {**mock_state_base, "auto_accepted_plan": True}
    outcome = human_feedback_node(state)

    assert isinstance(outcome, Command)
    assert outcome.goto == "research_team"
    update = outcome.update
    assert update["plan_iterations"] == 1
    assert update["current_plan"]["has_enough_context"] is False
def test_human_feedback_node_edit_plan(monkeypatch, mock_state_base):
    """An [EDIT_PLAN] reply sends the feedback back to the planner."""
    state = {**mock_state_base, "auto_accepted_plan": False}
    reply = "[EDIT_PLAN] Please revise"
    with patch("src.graph.nodes.interrupt", return_value=reply):
        outcome = human_feedback_node(state)

    assert isinstance(outcome, Command)
    assert outcome.goto == "planner"
    feedback = outcome.update["messages"][0]
    assert feedback.name == "feedback"
    assert "[EDIT_PLAN]" in feedback.content
def test_human_feedback_node_accepted(monkeypatch, mock_state_base):
    """An [ACCEPTED] reply lets the node parse the plan and start research."""
    state = {**mock_state_base, "auto_accepted_plan": False}
    reply = "[ACCEPTED] Looks good!"
    with patch("src.graph.nodes.interrupt", return_value=reply):
        outcome = human_feedback_node(state)

    assert isinstance(outcome, Command)
    assert outcome.goto == "research_team"
    update = outcome.update
    assert update["plan_iterations"] == 1
    assert update["current_plan"]["has_enough_context"] is False
def test_human_feedback_node_invalid_interrupt(monkeypatch, mock_state_base):
    """An unrecognised interrupt reply raises TypeError."""
    state = {**mock_state_base, "auto_accepted_plan": False}
    with (
        patch("src.graph.nodes.interrupt", return_value="RANDOM_FEEDBACK"),
        pytest.raises(TypeError),
    ):
        human_feedback_node(state)
def test_human_feedback_node_json_decode_error_first_iteration(
    monkeypatch, mock_state_base
):
    """If json.loads fails on iteration 0, the graph ends."""
    state = {**mock_state_base, "auto_accepted_plan": True, "plan_iterations": 0}
    decode_error = json.JSONDecodeError("err", "doc", 0)
    with patch("src.graph.nodes.json.loads", side_effect=decode_error):
        outcome = human_feedback_node(state)

    assert isinstance(outcome, Command)
    assert outcome.goto == "__end__"
def test_human_feedback_node_json_decode_error_second_iteration(
    monkeypatch, mock_state_base
):
    """If json.loads fails after earlier iterations, the node goes to the reporter."""
    state = {**mock_state_base, "auto_accepted_plan": True, "plan_iterations": 2}
    decode_error = json.JSONDecodeError("err", "doc", 0)
    with patch("src.graph.nodes.json.loads", side_effect=decode_error):
        outcome = human_feedback_node(state)

    assert isinstance(outcome, Command)
    assert outcome.goto == "reporter"
def test_human_feedback_node_not_enough_context(monkeypatch, mock_state_base):
    """A plan without enough context is routed to the research team."""
    plan = dict(
        has_enough_context=False,
        title="Test Plan",
        thought="Test Thought",
        steps=[],
        locale="en-US",
    )
    state = {
        **mock_state_base,
        "current_plan": json.dumps(plan),
        "auto_accepted_plan": True,
    }
    outcome = human_feedback_node(state)

    assert isinstance(outcome, Command)
    assert outcome.goto == "research_team"
    update = outcome.update
    assert update["plan_iterations"] == 1
    assert update["current_plan"]["has_enough_context"] is False
@pytest.fixture
def mock_state_coordinator():
    """Coordinator state: one user message and the en-US locale."""
    user_message = {"role": "user", "content": "test"}
    return {"messages": [user_message], "locale": "en-US"}
@pytest.fixture
def mock_configurable_coordinator():
    """Coordinator configuration exposing two resources."""
    return MagicMock(resources=["resource1", "resource2"])
@pytest.fixture
def patch_config_from_runnable_config_coordinator(mock_configurable_coordinator):
    """Serve ``mock_configurable_coordinator`` from from_runnable_config."""
    target = "src.graph.nodes.Configuration.from_runnable_config"
    with patch(target, return_value=mock_configurable_coordinator):
        yield
@pytest.fixture
def patch_apply_prompt_template_coordinator():
    """Stub apply_prompt_template with a single 'test' user message."""
    prompt = [{"role": "user", "content": "test"}]
    with patch("src.graph.nodes.apply_prompt_template", return_value=prompt) as stub:
        yield stub
@pytest.fixture
def patch_handoff_to_planner():
    """Swap the handoff_to_planner tool for an inert MagicMock."""
    stand_in = MagicMock()
    with patch("src.graph.nodes.handoff_to_planner", stand_in):
        yield
@pytest.fixture
def patch_logger():
    """Capture src.graph.nodes.logger calls on a mock."""
    with patch("src.graph.nodes.logger") as logger_mock:
        yield logger_mock
def make_mock_llm_response(tool_calls=None):
    """Build a fake LLM reply whose ``tool_calls`` falls back to an empty list."""
    return MagicMock(tool_calls=tool_calls if tool_calls else [])
def test_coordinator_node_no_tool_calls(
    mock_state_coordinator,
    patch_config_from_runnable_config_coordinator,
    patch_apply_prompt_template_coordinator,
    patch_handoff_to_planner,
    patch_logger,
):
    """Without any tool calls the coordinator ends the run."""
    llm = MagicMock()
    llm.bind_tools.return_value = llm
    llm.invoke.return_value = make_mock_llm_response([])
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"coordinator": "basic"}),
        patch("src.graph.nodes.get_llm_by_type", return_value=llm),
    ):
        outcome = coordinator_node(mock_state_coordinator, MagicMock())

    update = outcome.update
    assert outcome.goto == "__end__"
    assert update["locale"] == "en-US"
    assert update["resources"] == ["resource1", "resource2"]
def test_coordinator_node_with_tool_calls_planner(
    mock_state_coordinator,
    patch_config_from_runnable_config_coordinator,
    patch_apply_prompt_template_coordinator,
    patch_handoff_to_planner,
    patch_logger,
):
    """A handoff_to_planner tool call routes to the planner."""
    handoff = {"name": "handoff_to_planner", "args": {}}
    llm = MagicMock()
    llm.bind_tools.return_value = llm
    llm.invoke.return_value = make_mock_llm_response([handoff])
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"coordinator": "basic"}),
        patch("src.graph.nodes.get_llm_by_type", return_value=llm),
    ):
        outcome = coordinator_node(mock_state_coordinator, MagicMock())

    update = outcome.update
    assert outcome.goto == "planner"
    assert update["locale"] == "en-US"
    assert update["resources"] == ["resource1", "resource2"]
def test_coordinator_node_with_tool_calls_background_investigator(
    mock_state_coordinator,
    patch_config_from_runnable_config_coordinator,
    patch_apply_prompt_template_coordinator,
    patch_handoff_to_planner,
    patch_logger,
):
    """With background investigation enabled, a handoff goes to the investigator."""
    state = {**mock_state_coordinator, "enable_background_investigation": True}
    handoff = {"name": "handoff_to_planner", "args": {}}
    llm = MagicMock()
    llm.bind_tools.return_value = llm
    llm.invoke.return_value = make_mock_llm_response([handoff])
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"coordinator": "basic"}),
        patch("src.graph.nodes.get_llm_by_type", return_value=llm),
    ):
        outcome = coordinator_node(state, MagicMock())

    update = outcome.update
    assert outcome.goto == "background_investigator"
    assert update["locale"] == "en-US"
    assert update["resources"] == ["resource1", "resource2"]
def test_coordinator_node_with_tool_calls_locale_override(
    mock_state_coordinator,
    patch_config_from_runnable_config_coordinator,
    patch_apply_prompt_template_coordinator,
    patch_handoff_to_planner,
    patch_logger,
):
    """``locale``/``research_topic`` in handoff args override the state values."""
    tool_calls = [
        {
            "name": "handoff_to_planner",
            "args": {"locale": "zh-CN", "research_topic": "test topic"},
        }
    ]
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"coordinator": "basic"}),
        patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
    ):
        mock_llm = MagicMock()
        mock_llm.bind_tools.return_value = mock_llm
        mock_llm.invoke.return_value = make_mock_llm_response(tool_calls)
        mock_get_llm.return_value = mock_llm
        result = coordinator_node(mock_state_coordinator, MagicMock())

    assert result.goto == "planner"
    assert result.update["locale"] == "zh-CN"
    assert result.update["research_topic"] == "test topic"
    # resources still come from the patched configuration. (The original
    # repeated this assertion twice; once is enough.)
    assert result.update["resources"] == ["resource1", "resource2"]
def test_coordinator_node_tool_calls_exception_handling(
    mock_state_coordinator,
    patch_config_from_runnable_config_coordinator,
    patch_apply_prompt_template_coordinator,
    patch_handoff_to_planner,
    patch_logger,
):
    """A tool call whose args cannot be read must not crash the coordinator.

    The node should still hand off to the planner and keep the state's locale.
    """
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"coordinator": "basic"}),
        patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
    ):
        mock_llm = MagicMock()
        mock_llm.bind_tools.return_value = mock_llm

        # Simulate tool_call.get("args", ...) raising a generic Exception.
        # (The old comment said AttributeError, but a plain Exception is raised.)
        class BadToolCall(dict):
            def get(self, key, default=None):
                if key == "args":
                    raise Exception("bad args")
                return super().get(key, default)

        mock_llm.invoke.return_value = make_mock_llm_response(
            [BadToolCall({"name": "handoff_to_planner"})]
        )
        mock_get_llm.return_value = mock_llm
        # Expected to be handled inside the node rather than propagate.
        result = coordinator_node(mock_state_coordinator, MagicMock())

    assert result.goto == "planner"
    assert result.update["locale"] == "en-US"
    assert result.update["resources"] == ["resource1", "resource2"]
@pytest.fixture
def mock_state_reporter():
    """Reporter state: a title/thought plan stand-in and no observations."""
    plan_type = namedtuple("Plan", ["title", "thought"])
    state = {"current_plan": plan_type(title="Test Title", thought="Test Thought")}
    state.update(locale="en-US", observations=[])
    return state
@pytest.fixture
def mock_state_reporter_with_observations():
    """Reporter state whose plan stand-in comes with two observations."""
    plan_type = namedtuple("Plan", ["title", "thought"])
    state = {"current_plan": plan_type(title="Test Title", thought="Test Thought")}
    state.update(locale="en-US", observations=["Observation 1", "Observation 2"])
    return state
@pytest.fixture
def mock_configurable_reporter():
    """Bare configuration stand-in for the reporter."""
    return MagicMock()
@pytest.fixture
def patch_config_from_runnable_config_reporter(mock_configurable_reporter):
    """Serve ``mock_configurable_reporter`` from from_runnable_config."""
    target = "src.graph.nodes.Configuration.from_runnable_config"
    with patch(target, return_value=mock_configurable_reporter):
        yield
@pytest.fixture
def patch_apply_prompt_template_reporter():
    """Stub apply_prompt_template so every call returns a new one-mock list."""

    def _fresh_prompt(*_args, **_kwargs):
        return [MagicMock()]

    with patch("src.graph.nodes.apply_prompt_template", side_effect=_fresh_prompt) as stub:
        yield stub
@pytest.fixture
def patch_human_message():
    """Replace HumanMessage in src.graph.nodes with a MagicMock and yield it."""
    # Distinct local name, so the module-level HumanMessage import is not shadowed.
    human_message_mock = MagicMock()
    with patch("src.graph.nodes.HumanMessage", human_message_mock):
        yield human_message_mock
@pytest.fixture
def patch_logger_reporter():
    """Capture src.graph.nodes.logger calls for reporter tests."""
    with patch("src.graph.nodes.logger") as logger_mock:
        yield logger_mock
def make_mock_llm_response_reporter(content):
    """Build a fake LLM reply that carries ``content``."""
    return MagicMock(content=content)
def test_reporter_node_basic(
    mock_state_reporter,
    patch_config_from_runnable_config_reporter,
    patch_apply_prompt_template_reporter,
    patch_human_message,
    patch_logger_reporter,
):
    """reporter_node exposes the LLM's content as ``final_report``."""
    llm = MagicMock()
    llm.invoke.return_value = make_mock_llm_response_reporter("Final Report Content")
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"reporter": "basic"}),
        patch("src.graph.nodes.get_llm_by_type", return_value=llm),
    ):
        report = reporter_node(mock_state_reporter, MagicMock())

    assert isinstance(report, dict)
    assert "final_report" in report
    assert report["final_report"] == "Final Report Content"
    # The prompt template was applied and the LLM was invoked (arguments unchecked).
    patch_apply_prompt_template_reporter.assert_called()
    llm.invoke.assert_called()
def test_reporter_node_with_observations(
    mock_state_reporter_with_observations,
    patch_config_from_runnable_config_reporter,
    patch_apply_prompt_template_reporter,
    patch_human_message,
    patch_logger_reporter,
):
    """A state with observations still yields the LLM's report."""
    llm = MagicMock()
    llm.invoke.return_value = make_mock_llm_response_reporter("Report with Observations")
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"reporter": "basic"}),
        patch("src.graph.nodes.get_llm_by_type", return_value=llm),
    ):
        report = reporter_node(mock_state_reporter_with_observations, MagicMock())

    assert isinstance(report, dict)
    assert "final_report" in report
    assert report["final_report"] == "Report with Observations"
    # The prompt template was applied and the LLM was invoked (arguments unchecked).
    patch_apply_prompt_template_reporter.assert_called()
    llm.invoke.assert_called()
def test_reporter_node_locale_default(
    patch_config_from_runnable_config_reporter,
    patch_apply_prompt_template_reporter,
    patch_human_message,
    patch_logger_reporter,
):
    """A state without ``locale`` still produces a report.

    The intent is that locale falls back to "en-US", but this test only checks
    the returned report, not the locale actually used.
    """
    plan_type = namedtuple("Plan", ["title", "thought"])
    state = {
        "current_plan": plan_type(title="Test Title", thought="Test Thought"),
        "observations": [],
    }
    llm = MagicMock()
    llm.invoke.return_value = make_mock_llm_response_reporter("Default Locale Report")
    with (
        patch("src.graph.nodes.AGENT_LLM_MAP", {"reporter": "basic"}),
        patch("src.graph.nodes.get_llm_by_type", return_value=llm),
    ):
        report = reporter_node(state, MagicMock())

    assert isinstance(report, dict)
    assert "final_report" in report
    assert report["final_report"] == "Default Locale Report"
# Lightweight test double for a plan step. It is NOT the project's real Step
# model; it only carries the attributes these tests read and write.
class Step:
    """Mutable stand-in for a plan step.

    Args:
        title: Step title.
        description: Step description.
        execution_res: Result of running the step. ``None`` means the step has
            not been executed. It is a plain attribute so the code under test
            can assign the result to it.
    """

    def __init__(self, title, description, execution_res=None):
        self.title = title
        self.description = description
        self.execution_res = execution_res

    def __repr__(self):
        # Readable repr, so assertion failures show which step was involved.
        return (
            f"{type(self).__name__}(title={self.title!r}, "
            f"description={self.description!r}, "
            f"execution_res={self.execution_res!r})"
        )
@pytest.fixture
def mock_step():
    """A step that has not been executed yet."""
    pending = Step("Step 1", "Desc 1", None)
    return pending
@pytest.fixture
def mock_completed_step():
    """A step whose execution already produced "Done"."""
    finished = Step("Step 0", "Desc 0", "Done")
    return finished
@pytest.fixture
def mock_state_with_steps(mock_step, mock_completed_step):
    """State whose plan has one finished step followed by one pending step."""
    plan = MagicMock()
    plan.steps = [mock_completed_step, mock_step]
    return dict(
        current_plan=plan,
        observations=["obs1"],
        locale="en-US",
        resources=[],
    )
@pytest.fixture
def mock_state_no_unexecuted():
    """State whose plan steps have all been executed."""
    # Local name chosen so it does not shadow the module-level Step class.
    done_step = namedtuple("Step", ["title", "description", "execution_res"])
    plan = MagicMock()
    plan.steps = [
        done_step(title=f"Step {n}", description=f"Desc {n}", execution_res="done")
        for n in (1, 2)
    ]
    return dict(
        current_plan=plan,
        observations=[],
        locale="en-US",
        resources=[],
    )
@pytest.fixture
def mock_agent():
    """Agent double whose async ``ainvoke`` replies with one message."""

    # Parameter names kept as-is in case the caller passes them by keyword.
    async def fake_ainvoke(input, config):
        return {"messages": [MagicMock(content="result content")]}

    agent = MagicMock()
    agent.ainvoke = fake_ainvoke
    return agent
@pytest.mark.asyncio
async def test_execute_agent_step_basic(mock_state_with_steps, mock_agent):
    """The first pending step runs, and its result is recorded."""

    def fake_human_message(content, name=None):
        return MagicMock(content=content, name=name)

    with patch("src.graph.nodes.HumanMessage", side_effect=fake_human_message):
        outcome = await _execute_agent_step(
            mock_state_with_steps, mock_agent, "researcher"
        )

    assert isinstance(outcome, Command)
    assert outcome.goto == "research_team"
    update = outcome.update
    assert "messages" in update
    assert "observations" in update
    # The agent's reply is appended to the observations...
    assert update["observations"][-1] == "result content"
    # ...and written back onto the previously pending step (index 1).
    pending_step = mock_state_with_steps["current_plan"].steps[1]
    assert pending_step.execution_res == "result content"
@pytest.mark.asyncio
async def test_execute_agent_step_no_unexecuted_step(
    mock_state_no_unexecuted, mock_agent
):
    """With nothing left to run, the node warns and returns to the research team."""
    with patch("src.graph.nodes.logger") as logger_mock:
        outcome = await _execute_agent_step(
            mock_state_no_unexecuted, mock_agent, "researcher"
        )

    assert isinstance(outcome, Command)
    assert outcome.goto == "research_team"
    logger_mock.warning.assert_called_with("No unexecuted step found")
@pytest.mark.asyncio
async def test_execute_agent_step_with_resources_and_researcher(mock_step):
    """A researcher's input includes a local-resource hint and a citation reminder."""
    resource_type = namedtuple("Resource", ["title", "description"])
    plan = MagicMock()
    plan.steps = [mock_step]
    state = {
        "current_plan": plan,
        "observations": [],
        "locale": "en-US",
        "resources": [resource_type(title="file1.txt", description="desc1")],
    }

    async def checking_ainvoke(input, config):
        # Inspect the messages the node built before answering.
        contents = [message.content for message in input["messages"]]
        assert any("local_search_tool" in text for text in contents)
        assert any("DO NOT include inline citations" in text for text in contents)
        return {"messages": [MagicMock(content="resource result")]}

    agent = MagicMock()
    agent.ainvoke = checking_ainvoke

    def fake_human_message(content, name=None):
        return MagicMock(content=content, name=name)

    with patch("src.graph.nodes.HumanMessage", side_effect=fake_human_message):
        outcome = await _execute_agent_step(state, agent, "researcher")

    assert isinstance(outcome, Command)
    assert outcome.goto == "research_team"
    assert outcome.update["observations"][-1] == "resource result"
@pytest.mark.asyncio
async def test_execute_agent_step_recursion_limit_env(
    monkeypatch, mock_state_with_steps, mock_agent
):
    """A valid AGENT_RECURSION_LIMIT is picked up, as the info log shows."""
    monkeypatch.setenv("AGENT_RECURSION_LIMIT", "42")

    def fake_human_message(content, name=None):
        return MagicMock(content=content, name=name)

    with (
        patch("src.graph.nodes.logger") as logger_mock,
        patch("src.graph.nodes.HumanMessage", side_effect=fake_human_message),
    ):
        outcome = await _execute_agent_step(mock_state_with_steps, mock_agent, "coder")

    assert isinstance(outcome, Command)
    logger_mock.info.assert_any_call("Recursion limit set to: 42")
@pytest.mark.asyncio
async def test_execute_agent_step_recursion_limit_env_invalid(
    monkeypatch, mock_state_with_steps, mock_agent
):
    """A non-integer AGENT_RECURSION_LIMIT logs a warning and falls back to 25."""
    monkeypatch.setenv("AGENT_RECURSION_LIMIT", "notanint")

    def fake_human_message(content, name=None):
        return MagicMock(content=content, name=name)

    with (
        patch("src.graph.nodes.logger") as logger_mock,
        patch("src.graph.nodes.HumanMessage", side_effect=fake_human_message),
    ):
        outcome = await _execute_agent_step(mock_state_with_steps, mock_agent, "coder")

    assert isinstance(outcome, Command)
    expected_warning = (
        "Invalid AGENT_RECURSION_LIMIT value: 'notanint'. Using default value 25."
    )
    logger_mock.warning.assert_any_call(expected_warning)
@pytest.mark.asyncio
async def test_execute_agent_step_recursion_limit_env_negative(
    monkeypatch, mock_state_with_steps, mock_agent
):
    """A non-positive AGENT_RECURSION_LIMIT logs a warning and falls back to 25."""
    monkeypatch.setenv("AGENT_RECURSION_LIMIT", "-5")

    def fake_human_message(content, name=None):
        return MagicMock(content=content, name=name)

    with (
        patch("src.graph.nodes.logger") as logger_mock,
        patch("src.graph.nodes.HumanMessage", side_effect=fake_human_message),
    ):
        outcome = await _execute_agent_step(mock_state_with_steps, mock_agent, "coder")

    assert isinstance(outcome, Command)
    expected_warning = (
        "AGENT_RECURSION_LIMIT value '-5' (parsed as -5) is not positive. "
        "Using default value 25."
    )
    logger_mock.warning.assert_any_call(expected_warning)
@pytest.fixture
def mock_configurable_with_mcp():
    """Configuration with one MCP server that offers toolA/toolB to researchers."""
    server = {
        "enabled_tools": ["toolA", "toolB"],
        "add_to_agents": ["researcher"],
        "transport": "http",
        "command": "run",
        "args": {},
        "url": "http://localhost",
        "env": {},
        "other": "ignore",
    }
    return MagicMock(mcp_settings={"servers": {"server1": server}})
@pytest.fixture
def mock_configurable_without_mcp():
    """Configuration with no MCP settings at all."""
    return MagicMock(mcp_settings=None)
@pytest.fixture
def patch_config_from_runnable_config_with_mcp(mock_configurable_with_mcp):
    """Serve the MCP-enabled configuration from from_runnable_config."""
    target = "src.graph.nodes.Configuration.from_runnable_config"
    with patch(target, return_value=mock_configurable_with_mcp):
        yield
@pytest.fixture
def patch_config_from_runnable_config_without_mcp(mock_configurable_without_mcp):
    """Serve the MCP-free configuration from from_runnable_config."""
    target = "src.graph.nodes.Configuration.from_runnable_config"
    with patch(target, return_value=mock_configurable_without_mcp):
        yield
@pytest.fixture
def patch_create_agent():
    """Patch create_agent and yield the mock, so tests can inspect its call args."""
    with patch("src.graph.nodes.create_agent") as create_agent_mock:
        yield create_agent_mock
@pytest.fixture
def patch_execute_agent_step():
    """Short-circuit _execute_agent_step so it always returns "EXECUTED"."""

    # Parameter names mirror the real function in case it is called by keyword.
    async def fake_execute_agent_step(state, agent, agent_type):
        return "EXECUTED"

    target = "src.graph.nodes._execute_agent_step"
    with patch(target, side_effect=fake_execute_agent_step) as stub:
        yield stub
@pytest.fixture
def patch_multiserver_mcp_client():
    """Patch MultiServerMCPClient with an async context manager serving three tools."""

    class FakeTool:
        def __init__(self, name, description="desc"):
            self.name = name
            self.description = description

    class FakeClient:
        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return None

        async def get_tools(self):
            # toolA, toolB and toolC, with matching descA/descB/descC.
            return [FakeTool(f"tool{letter}", f"desc{letter}") for letter in "ABC"]

    client = FakeClient()
    with patch("src.graph.nodes.MultiServerMCPClient", return_value=client) as stub:
        yield stub
@pytest.mark.asyncio
async def test_setup_and_execute_agent_step_with_mcp(
    mock_state_with_steps,
    mock_config,
    patch_config_from_runnable_config_with_mcp,
    patch_create_agent,
    patch_execute_agent_step,
    patch_multiserver_mcp_client,
):
    """With MCP configured, create_agent receives the enabled MCP tools."""
    # NOTE: MagicMock's ``name`` kwarg only names the mock's repr. It does not
    # set a string ``.name`` attribute.
    default_tools = [MagicMock(name="default_tool")]
    outcome = await _setup_and_execute_agent_step(
        mock_state_with_steps, mock_config, "researcher", default_tools
    )

    # create_agent gets the tool list as its third positional argument.
    positional, _ = patch_create_agent.call_args
    loaded_names = [tool.name for tool in positional[2] if hasattr(tool, "name")]
    assert "toolA" in loaded_names
    assert "toolB" in loaded_names
    patch_execute_agent_step.assert_called_once()
    assert outcome == "EXECUTED"
@pytest.mark.asyncio
async def test_setup_and_execute_agent_step_without_mcp(
    mock_state_with_steps,
    mock_config,
    patch_config_from_runnable_config_without_mcp,
    patch_create_agent,
    patch_execute_agent_step,
):
    """Without MCP settings the agent is created with exactly the default tools.

    The MCP client class is patched so the test can verify it is never
    constructed. The patch also keeps a real client from being built if a
    regression tried to.
    """
    default_tools = [MagicMock(name="default_tool")]
    agent_type = "coder"
    with patch("src.graph.nodes.MultiServerMCPClient") as mcp_client_cls:
        result = await _setup_and_execute_agent_step(
            mock_state_with_steps,
            mock_config,
            agent_type,
            default_tools,
        )
    mcp_client_cls.assert_not_called()
    # create_agent receives the tool list as its third positional argument.
    args, _ = patch_create_agent.call_args
    assert args[2] == default_tools
    patch_execute_agent_step.assert_called_once()
    assert result == "EXECUTED"
@pytest.mark.asyncio
async def test_setup_and_execute_agent_step_with_mcp_no_enabled_tools(
    mock_state_with_steps,
    mock_config,
    patch_create_agent,
    patch_execute_agent_step,
):
    """MCP servers not assigned to the current agent are ignored.

    ``server1`` enables ``toolA`` but only for ``other_agent``. A ``researcher``
    step must therefore fall back to the default tools and never construct an
    MCP client.
    """
    mcp_settings = {
        "servers": {
            "server1": {
                "enabled_tools": ["toolA"],
                "add_to_agents": ["other_agent"],
                "transport": "http",
                "command": "run",
                "args": {},
                "url": "http://localhost",
                "env": {},
            }
        }
    }
    configurable = MagicMock()
    configurable.mcp_settings = mcp_settings
    default_tools = [MagicMock(name="default_tool")]
    agent_type = "researcher"
    with patch(
        "src.graph.nodes.Configuration.from_runnable_config",
        return_value=configurable,
    ):
        with patch("src.graph.nodes.MultiServerMCPClient") as mcp_client_cls:
            result = await _setup_and_execute_agent_step(
                mock_state_with_steps,
                mock_config,
                agent_type,
                default_tools,
            )
    mcp_client_cls.assert_not_called()
    # create_agent receives the tool list as its third positional argument.
    args, _ = patch_create_agent.call_args
    assert args[2] == default_tools
    patch_execute_agent_step.assert_called_once()
    assert result == "EXECUTED"
@pytest.mark.asyncio
async def test_setup_and_execute_agent_step_with_mcp_tools_description_update(
    mock_state_with_steps,
    mock_config,
    patch_config_from_runnable_config_with_mcp,
    patch_create_agent,
    patch_execute_agent_step,
):
    """Loaded MCP tools get a ``Powered by '<server>'.`` line prepended.

    The tool's original description must survive after that prefix line.
    """
    default_tools = [MagicMock(name="default_tool")]
    agent_type = "researcher"

    # Local fake MCP client exposing a single tool with a known description.
    class FakeTool:
        def __init__(self, name, description="desc"):
            self.name = name
            self.description = description

    class FakeClient:
        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            pass

        async def get_tools(self):
            return [FakeTool("toolA", "descA")]

    with patch("src.graph.nodes.MultiServerMCPClient", return_value=FakeClient()):
        await _setup_and_execute_agent_step(
            mock_state_with_steps,
            mock_config,
            agent_type,
            default_tools,
        )

    # create_agent receives the tool list as its third positional argument.
    args, _ = patch_create_agent.call_args
    loaded_tools = args[2]
    tool_a = next(
        (t for t in loaded_tools if getattr(t, "name", None) == "toolA"), None
    )
    assert tool_a is not None, "toolA was not passed to create_agent"
    # Check the whole string, not just the prefix, so a dropped original
    # description is caught too.
    assert tool_a.description == "Powered by 'server1'.\ndescA"
@pytest.fixture
def mock_state_with_resources():
    """State dict carrying two resources plus an unrelated ``other`` entry."""
    return dict(resources=["resource1", "resource2"], other="value")
@pytest.fixture
def mock_state_without_resources():
    """State dict with no ``resources`` key at all."""
    return dict(other="value")
@pytest.fixture
def patch_get_web_search_tool():
    """Patch ``get_web_search_tool`` to return a named ``MagicMock`` tool.

    Yields the patched factory; the produced tool is its ``return_value``.
    """
    web_search_tool = MagicMock(name="web_search_tool")
    with patch(
        "src.graph.nodes.get_web_search_tool", return_value=web_search_tool
    ) as factory:
        yield factory
@pytest.fixture
def patch_crawl_tool():
    """Swap ``crawl_tool`` in ``src.graph.nodes`` for a named ``MagicMock``."""
    stand_in = MagicMock(name="crawl_tool")
    with patch("src.graph.nodes.crawl_tool", stand_in):
        yield
@pytest.fixture
def patch_get_retriever_tool():
    """Patch ``get_retriever_tool``; each test sets its ``return_value``."""
    target = "src.graph.nodes.get_retriever_tool"
    with patch(target) as retriever_factory:
        yield retriever_factory
@pytest.fixture
def patch_setup_and_execute_agent_step():
    """Stub ``_setup_and_execute_agent_step`` with a coroutine returning a marker.

    Yields the mock so tests can inspect the tools the researcher passed in.
    """

    async def _stub(state, config, agent_type, tools):
        return "RESEARCHER_RESULT"

    patcher = patch(
        "src.graph.nodes._setup_and_execute_agent_step",
        side_effect=_stub,
    )
    mocked = patcher.start()
    try:
        yield mocked
    finally:
        patcher.stop()
@pytest.mark.asyncio
async def test_researcher_node_with_retriever_tool(
    mock_state_with_resources,
    mock_config,
    patch_config_from_runnable_config,
    patch_get_web_search_tool,
    patch_crawl_tool,
    patch_get_retriever_tool,
    patch_setup_and_execute_agent_step,
):
    """A retriever tool built from the state's resources is placed first."""
    retriever = MagicMock(name="retriever_tool")
    patch_get_retriever_tool.return_value = retriever

    outcome = await researcher_node(mock_state_with_resources, mock_config)

    # 7 is the max_search_results value on the patched configuration.
    patch_get_web_search_tool.assert_called_once_with(7)
    patch_get_retriever_tool.assert_called_once_with(["resource1", "resource2"])
    # The tool list is the fourth positional argument of the setup helper.
    passed_tools = patch_setup_and_execute_agent_step.call_args[0][3]
    assert passed_tools[0] == retriever
    assert patch_get_web_search_tool.return_value in passed_tools
    assert outcome == "RESEARCHER_RESULT"
@pytest.mark.asyncio
async def test_researcher_node_without_retriever_tool(
    mock_state_with_resources,
    mock_config,
    patch_config_from_runnable_config,
    patch_get_web_search_tool,
    patch_crawl_tool,
    patch_get_retriever_tool,
    patch_setup_and_execute_agent_step,
):
    """When no retriever tool is available, nothing is added in its place."""
    patch_get_retriever_tool.return_value = None
    result = await researcher_node(mock_state_with_resources, mock_config)
    patch_get_web_search_tool.assert_called_once_with(7)
    patch_get_retriever_tool.assert_called_once_with(["resource1", "resource2"])
    # The tool list is the fourth positional argument of the setup helper.
    args, _ = patch_setup_and_execute_agent_step.call_args
    tools = args[3]
    # Check by value: get_retriever_tool returned None, and that None must not
    # end up in the tool list. A name-based check cannot work here because
    # MagicMock(name=...) does not give a mock a string ``.name`` attribute.
    assert None not in tools
    assert patch_get_web_search_tool.return_value in tools
    assert result == "RESEARCHER_RESULT"
@pytest.mark.asyncio
async def test_researcher_node_without_resources(
    mock_state_without_resources,
    mock_config,
    patch_config_from_runnable_config,
    patch_get_web_search_tool,
    patch_crawl_tool,
    patch_get_retriever_tool,
    patch_setup_and_execute_agent_step,
):
    """A state without ``resources`` asks the retriever factory for an empty list."""
    patch_get_retriever_tool.return_value = None

    outcome = await researcher_node(mock_state_without_resources, mock_config)

    patch_get_web_search_tool.assert_called_once_with(7)
    patch_get_retriever_tool.assert_called_once_with([])
    # The tool list is the fourth positional argument of the setup helper.
    passed_tools = patch_setup_and_execute_agent_step.call_args[0][3]
    assert patch_get_web_search_tool.return_value in passed_tools
    assert outcome == "RESEARCHER_RESULT"