diff --git a/src/graph/nodes.py b/src/graph/nodes.py index 5fa574e..8370e21 100644 --- a/src/graph/nodes.py +++ b/src/graph/nodes.py @@ -190,7 +190,7 @@ def human_feedback_node( goto = "reporter" except json.JSONDecodeError: logger.warning("Planner response is not a valid JSON") - if plan_iterations > 0: + if plan_iterations > 1: # the plan_iterations is increased before this check return Command(goto="reporter") else: return Command(goto="__end__") diff --git a/tests/integration/test_nodes.py b/tests/integration/test_nodes.py index 55c8e48..529f1b8 100644 --- a/tests/integration/test_nodes.py +++ b/tests/integration/test_nodes.py @@ -1,6 +1,16 @@ +from collections import namedtuple import json import pytest -from unittest.mock import patch, MagicMock +import asyncio +import types +from unittest.mock import patch, MagicMock, AsyncMock +from src.graph.nodes import planner_node +from src.graph.nodes import human_feedback_node +from src.graph.nodes import coordinator_node +from src.graph.nodes import reporter_node +from src.graph.nodes import _execute_agent_step +from src.graph.nodes import _setup_and_execute_agent_step +from src.graph.nodes import researcher_node # 在这里 mock 掉 get_llm_by_type,避免 ValueError with patch("src.llms.llm.get_llm_by_type", return_value=MagicMock()): @@ -9,6 +19,7 @@ with patch("src.llms.llm.get_llm_by_type", return_value=MagicMock()): from src.config import SearchEngine from langchain_core.messages import HumanMessage + # Mock data MOCK_SEARCH_RESULTS = [ {"title": "Test Title 1", "content": "Test Content 1"}, @@ -123,3 +134,1291 @@ def test_background_investigation_node_malformed_response( # Parse and verify the JSON content results = result["background_investigation_results"] assert json.loads(results) is None + + +@pytest.fixture +def mock_plan(): + return { + "has_enough_context": True, + "title": "Test Plan", + "thought": "Test Thought", + "steps": [], + "locale": "en-US", + } + + +@pytest.fixture +def mock_state_planner(): + return { + "messages": [HumanMessage(content="plan this")], + "plan_iterations": 0, + "enable_background_investigation": True, + "background_investigation_results": "Background info", + } + + +@pytest.fixture +def mock_configurable_planner(): + mock = MagicMock() + mock.max_plan_iterations = 3 + return mock + + +@pytest.fixture +def patch_config_from_runnable_config_planner(mock_configurable_planner): + with patch( + "src.graph.nodes.Configuration.from_runnable_config", + return_value=mock_configurable_planner, + ): + yield + + +@pytest.fixture +def patch_apply_prompt_template(): + with patch( + "src.graph.nodes.apply_prompt_template", + return_value=[{"role": "user", "content": "plan this"}], + ) as mock: + yield mock + + +@pytest.fixture +def patch_repair_json_output(): + with patch("src.graph.nodes.repair_json_output", side_effect=lambda x: x) as mock: + yield mock + + +@pytest.fixture +def patch_plan_model_validate(): + with patch("src.graph.nodes.Plan.model_validate", side_effect=lambda x: x) as mock: + yield mock + + +@pytest.fixture +def patch_ai_message(): + AIMessage = namedtuple("AIMessage", ["content", "name"]) + with patch( + "src.graph.nodes.AIMessage", + side_effect=lambda content, name: AIMessage(content, name), + ) as mock: + yield mock + + +def test_planner_node_basic_has_enough_context( + mock_state_planner, + patch_config_from_runnable_config_planner, + patch_apply_prompt_template, + patch_repair_json_output, + patch_plan_model_validate, + patch_ai_message, + mock_plan, +): + # AGENT_LLM_MAP["planner"] == "basic" + with ( + patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "basic"}), + patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, + ): + mock_llm = MagicMock() + mock_llm.with_structured_output.return_value = mock_llm + mock_response = MagicMock() + mock_response.model_dump_json.return_value = json.dumps(mock_plan) + mock_llm.invoke.return_value = mock_response + mock_get_llm.return_value = mock_llm + + result = planner_node(mock_state_planner, MagicMock()) + assert isinstance(result, Command) + assert result.goto == "reporter" + assert "current_plan" in result.update + assert result.update["current_plan"]["has_enough_context"] is True + assert result.update["messages"][0].name == "planner" + + +def test_planner_node_basic_not_enough_context( + mock_state_planner, + patch_config_from_runnable_config_planner, + patch_apply_prompt_template, + patch_repair_json_output, + patch_plan_model_validate, + patch_ai_message, +): + # AGENT_LLM_MAP["planner"] == "basic" + plan = { + "has_enough_context": False, + "title": "Test Plan", + "thought": "Test Thought", + "steps": [], + "locale": "en-US", + } + with ( + patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "basic"}), + patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, + ): + mock_llm = MagicMock() + mock_llm.with_structured_output.return_value = mock_llm + mock_response = MagicMock() + mock_response.model_dump_json.return_value = json.dumps(plan) + mock_llm.invoke.return_value = mock_response + mock_get_llm.return_value = mock_llm + + result = planner_node(mock_state_planner, MagicMock()) + assert isinstance(result, Command) + assert result.goto == "human_feedback" + assert "current_plan" in result.update + assert isinstance(result.update["current_plan"], str) + assert result.update["messages"][0].name == "planner" + + +def test_planner_node_stream_mode_has_enough_context( + mock_state_planner, + patch_config_from_runnable_config_planner, + patch_apply_prompt_template, + patch_repair_json_output, + patch_plan_model_validate, + patch_ai_message, + mock_plan, +): + # AGENT_LLM_MAP["planner"] != "basic" + with ( + patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "other"}), + patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, + ): + mock_llm = MagicMock() + # Simulate streaming chunks + chunk = MagicMock() + chunk.content = json.dumps(mock_plan) + mock_llm.stream.return_value = [chunk] + mock_get_llm.return_value = mock_llm + + result = planner_node(mock_state_planner, MagicMock()) + assert isinstance(result, Command) + assert result.goto == "reporter" + assert "current_plan" in result.update + assert result.update["current_plan"]["has_enough_context"] is True + + +def test_planner_node_stream_mode_not_enough_context( + mock_state_planner, + patch_config_from_runnable_config_planner, + patch_apply_prompt_template, + patch_repair_json_output, + patch_plan_model_validate, + patch_ai_message, +): + # AGENT_LLM_MAP["planner"] != "basic" + plan = { + "has_enough_context": False, + "title": "Test Plan", + "thought": "Test Thought", + "steps": [], + "locale": "en-US", + } + with ( + patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "other"}), + patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, + ): + mock_llm = MagicMock() + chunk = MagicMock() + chunk.content = json.dumps(plan) + mock_llm.stream.return_value = [chunk] + mock_get_llm.return_value = mock_llm + + result = planner_node(mock_state_planner, MagicMock()) + assert isinstance(result, Command) + assert result.goto == "human_feedback" + assert "current_plan" in result.update + assert isinstance(result.update["current_plan"], str) + + +def test_planner_node_plan_iterations_exceeded(mock_state_planner): + # plan_iterations >= max_plan_iterations + state = dict(mock_state_planner) + state["plan_iterations"] = 5 + with ( + patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "basic"}), + patch("src.graph.nodes.get_llm_by_type", return_value=MagicMock()), + ): + result = planner_node(state, MagicMock()) + assert isinstance(result, Command) + assert result.goto == "reporter" + + +def test_planner_node_json_decode_error_first_iteration(mock_state_planner): + # Simulate JSONDecodeError on first iteration + with ( + patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "basic"}), + patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, + patch( + "src.graph.nodes.json.loads", + side_effect=json.JSONDecodeError("err", "doc", 0), + ), + ): + mock_llm = MagicMock() + mock_llm.with_structured_output.return_value = mock_llm + mock_response = MagicMock() + mock_response.model_dump_json.return_value = '{"bad": "json"' + mock_llm.invoke.return_value = mock_response + mock_get_llm.return_value = mock_llm + + result = planner_node(mock_state_planner, MagicMock()) + assert isinstance(result, Command) + assert result.goto == "__end__" + + +def test_planner_node_json_decode_error_second_iteration(mock_state_planner): + # Simulate JSONDecodeError on second iteration + state = dict(mock_state_planner) + state["plan_iterations"] = 1 + with ( + patch("src.graph.nodes.AGENT_LLM_MAP", {"planner": "basic"}), + patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, + patch( + "src.graph.nodes.json.loads", + side_effect=json.JSONDecodeError("err", "doc", 0), + ), + ): + mock_llm = MagicMock() + mock_llm.with_structured_output.return_value = mock_llm + mock_response = MagicMock() + mock_response.model_dump_json.return_value = '{"bad": "json"' + mock_llm.invoke.return_value = mock_response + mock_get_llm.return_value = mock_llm + + result = planner_node(state, MagicMock()) + assert isinstance(result, Command) + assert result.goto == "reporter" + + +# Patch Plan.model_validate and repair_json_output globally for these tests +@pytest.fixture(autouse=True) +def patch_plan_and_repair(monkeypatch): + monkeypatch.setattr("src.graph.nodes.Plan.model_validate", lambda x: x) + monkeypatch.setattr("src.graph.nodes.repair_json_output", lambda x: x) + yield + + +@pytest.fixture +def mock_state_base(): + return { + "current_plan": json.dumps( + { + "has_enough_context": True, + "title": "Test Plan", + "thought": "Test Thought", + "steps": [], + "locale": "en-US", + } + ), + "plan_iterations": 0, + } + + +def test_human_feedback_node_auto_accepted(monkeypatch, mock_state_base): + # auto_accepted_plan True, should skip interrupt and parse plan + state = dict(mock_state_base) + state["auto_accepted_plan"] = True + result = human_feedback_node(state) + assert isinstance(result, Command) + assert result.goto == "reporter" + assert result.update["plan_iterations"] == 1 + assert result.update["current_plan"]["has_enough_context"] is True + + +def test_human_feedback_node_edit_plan(monkeypatch, mock_state_base): + # interrupt returns [EDIT_PLAN]..., should return Command to planner + state = dict(mock_state_base) + state["auto_accepted_plan"] = False + with patch("src.graph.nodes.interrupt", return_value="[EDIT_PLAN] Please revise"): + result = human_feedback_node(state) + assert isinstance(result, Command) + assert result.goto == "planner" + assert result.update["messages"][0].name == "feedback" + assert "[EDIT_PLAN]" in result.update["messages"][0].content + + +def test_human_feedback_node_accepted(monkeypatch, mock_state_base): + # interrupt returns [ACCEPTED]..., should proceed to parse plan + state = dict(mock_state_base) + state["auto_accepted_plan"] = False + with patch("src.graph.nodes.interrupt", return_value="[ACCEPTED] Looks good!"): + result = human_feedback_node(state) + assert isinstance(result, Command) + assert result.goto == "reporter" + assert result.update["plan_iterations"] == 1 + assert result.update["current_plan"]["has_enough_context"] is True + + +def test_human_feedback_node_invalid_interrupt(monkeypatch, mock_state_base): + # interrupt returns something else, should raise TypeError + state = dict(mock_state_base) + state["auto_accepted_plan"] = False + with patch("src.graph.nodes.interrupt", return_value="RANDOM_FEEDBACK"): + with pytest.raises(TypeError): + human_feedback_node(state) + + +def test_human_feedback_node_json_decode_error_first_iteration( + monkeypatch, mock_state_base +): + # repair_json_output returns bad json, json.loads raises JSONDecodeError, plan_iterations=0 + state = dict(mock_state_base) + state["auto_accepted_plan"] = True + state["plan_iterations"] = 0 + with patch( + "src.graph.nodes.json.loads", side_effect=json.JSONDecodeError("err", "doc", 0) + ): + result = human_feedback_node(state) + assert isinstance(result, Command) + assert result.goto == "__end__" + + +def test_human_feedback_node_json_decode_error_second_iteration( + monkeypatch, mock_state_base +): + # repair_json_output returns bad json, json.loads raises JSONDecodeError, plan_iterations>0 + state = dict(mock_state_base) + state["auto_accepted_plan"] = True + state["plan_iterations"] = 2 + with patch( + "src.graph.nodes.json.loads", side_effect=json.JSONDecodeError("err", "doc", 0) + ): + result = human_feedback_node(state) + assert isinstance(result, Command) + assert result.goto == "reporter" + + +def test_human_feedback_node_not_enough_context(monkeypatch, mock_state_base): + # Plan does not have enough context, should goto research_team + plan = { + "has_enough_context": False, + "title": "Test Plan", + "thought": "Test Thought", + "steps": [], + "locale": "en-US", + } + state = dict(mock_state_base) + state["current_plan"] = json.dumps(plan) + state["auto_accepted_plan"] = True + result = human_feedback_node(state) + assert isinstance(result, Command) + assert result.goto == "research_team" + assert result.update["plan_iterations"] == 1 + assert result.update["current_plan"]["has_enough_context"] is False + + +@pytest.fixture +def mock_state_coordinator(): + return { + "messages": [{"role": "user", "content": "test"}], + "locale": "en-US", + } + + +@pytest.fixture +def mock_configurable_coordinator(): + mock = MagicMock() + mock.resources = ["resource1", "resource2"] + return mock + + +@pytest.fixture +def patch_config_from_runnable_config_coordinator(mock_configurable_coordinator): + with patch( + "src.graph.nodes.Configuration.from_runnable_config", + return_value=mock_configurable_coordinator, + ): + yield + + +@pytest.fixture +def patch_apply_prompt_template_coordinator(): + with patch( + "src.graph.nodes.apply_prompt_template", + return_value=[{"role": "user", "content": "test"}], + ) as mock: + yield mock + + +@pytest.fixture +def patch_handoff_to_planner(): + with patch("src.graph.nodes.handoff_to_planner", MagicMock()): + yield + + +@pytest.fixture +def patch_logger(): + with patch("src.graph.nodes.logger") as mock_logger: + yield mock_logger + + +def make_mock_llm_response(tool_calls=None): + resp = MagicMock() + resp.tool_calls = tool_calls or [] + return resp + + +def test_coordinator_node_no_tool_calls( + mock_state_coordinator, + patch_config_from_runnable_config_coordinator, + patch_apply_prompt_template_coordinator, + patch_handoff_to_planner, + patch_logger, +): + # No tool calls, should goto __end__ + with ( + patch("src.graph.nodes.AGENT_LLM_MAP", {"coordinator": "basic"}), + patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, + ): + mock_llm = MagicMock() + mock_llm.bind_tools.return_value = mock_llm + mock_llm.invoke.return_value = make_mock_llm_response([]) + mock_get_llm.return_value = mock_llm + + result = coordinator_node(mock_state_coordinator, MagicMock()) + assert result.goto == "__end__" + assert result.update["locale"] == "en-US" + assert result.update["resources"] == ["resource1", "resource2"] + + +def test_coordinator_node_with_tool_calls_planner( + mock_state_coordinator, + patch_config_from_runnable_config_coordinator, + patch_apply_prompt_template_coordinator, + patch_handoff_to_planner, + patch_logger, +): + # tool_calls present, should goto planner + tool_calls = [{"name": "handoff_to_planner", "args": {}}] + with ( + patch("src.graph.nodes.AGENT_LLM_MAP", {"coordinator": "basic"}), + patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, + ): + mock_llm = MagicMock() + mock_llm.bind_tools.return_value = mock_llm + mock_llm.invoke.return_value = make_mock_llm_response(tool_calls) + mock_get_llm.return_value = mock_llm + + result = coordinator_node(mock_state_coordinator, MagicMock()) + assert result.goto == "planner" + assert result.update["locale"] == "en-US" + assert result.update["resources"] == ["resource1", "resource2"] + + +def test_coordinator_node_with_tool_calls_background_investigator( + mock_state_coordinator, + patch_config_from_runnable_config_coordinator, + patch_apply_prompt_template_coordinator, + patch_handoff_to_planner, + patch_logger, +): + # enable_background_investigation True, should goto background_investigator + state = dict(mock_state_coordinator) + state["enable_background_investigation"] = True + tool_calls = [{"name": "handoff_to_planner", "args": {}}] + with ( + patch("src.graph.nodes.AGENT_LLM_MAP", {"coordinator": "basic"}), + patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, + ): + mock_llm = MagicMock() + mock_llm.bind_tools.return_value = mock_llm + mock_llm.invoke.return_value = make_mock_llm_response(tool_calls) + mock_get_llm.return_value = mock_llm + + result = coordinator_node(state, MagicMock()) + assert result.goto == "background_investigator" + assert result.update["locale"] == "en-US" + assert result.update["resources"] == ["resource1", "resource2"] + + +def test_coordinator_node_with_tool_calls_locale_override( + mock_state_coordinator, + patch_config_from_runnable_config_coordinator, + patch_apply_prompt_template_coordinator, + patch_handoff_to_planner, + patch_logger, +): + # tool_calls with locale in args should override locale + tool_calls = [ + { + "name": "handoff_to_planner", + "args": {"locale": "zh-CN", "research_topic": "test topic"}, + } + ] + with ( + patch("src.graph.nodes.AGENT_LLM_MAP", {"coordinator": "basic"}), + patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, + ): + mock_llm = MagicMock() + mock_llm.bind_tools.return_value = mock_llm + mock_llm.invoke.return_value = make_mock_llm_response(tool_calls) + mock_get_llm.return_value = mock_llm + + result = coordinator_node(mock_state_coordinator, MagicMock()) + assert result.goto == "planner" + assert result.update["locale"] == "zh-CN" + assert result.update["research_topic"] == "test topic" + assert result.update["resources"] == ["resource1", "resource2"] + assert result.update["resources"] == ["resource1", "resource2"] + + +def test_coordinator_node_tool_calls_exception_handling( + mock_state_coordinator, + patch_config_from_runnable_config_coordinator, + patch_apply_prompt_template_coordinator, + patch_handoff_to_planner, + patch_logger, +): + # tool_calls raises exception in processing + tool_calls = [{"name": "handoff_to_planner", "args": None}] + with ( + patch("src.graph.nodes.AGENT_LLM_MAP", {"coordinator": "basic"}), + patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, + ): + mock_llm = MagicMock() + mock_llm.bind_tools.return_value = mock_llm + + # Simulate tool_call.get("args", {}) raising AttributeError + class BadToolCall(dict): + def get(self, key, default=None): + if key == "args": + raise Exception("bad args") + return super().get(key, default) + + mock_llm.invoke.return_value = make_mock_llm_response( + [BadToolCall({"name": "handoff_to_planner"})] + ) + mock_get_llm.return_value = mock_llm + + # Should not raise, just log error and continue + result = coordinator_node(mock_state_coordinator, MagicMock()) + assert result.goto == "planner" + assert result.update["locale"] == "en-US" + assert result.update["resources"] == ["resource1", "resource2"] + + +@pytest.fixture +def mock_state_reporter(): + # Simulate a plan object with title and thought attributes + Plan = namedtuple("Plan", ["title", "thought"]) + return { + "current_plan": Plan(title="Test Title", thought="Test Thought"), + "locale": "en-US", + "observations": [], + } + + +@pytest.fixture +def mock_state_reporter_with_observations(): + Plan = namedtuple("Plan", ["title", "thought"]) + return { + "current_plan": Plan(title="Test Title", thought="Test Thought"), + "locale": "en-US", + "observations": ["Observation 1", "Observation 2"], + } + + +@pytest.fixture +def mock_configurable_reporter(): + mock = MagicMock() + return mock + + +@pytest.fixture +def patch_config_from_runnable_config_reporter(mock_configurable_reporter): + with patch( + "src.graph.nodes.Configuration.from_runnable_config", + return_value=mock_configurable_reporter, + ): + yield + + +@pytest.fixture +def patch_apply_prompt_template_reporter(): + with patch( + "src.graph.nodes.apply_prompt_template", + side_effect=lambda *args, **kwargs: [MagicMock()], + ) as mock: + yield mock + + +@pytest.fixture +def patch_human_message(): + HumanMessage = MagicMock() + with patch("src.graph.nodes.HumanMessage", HumanMessage): + yield HumanMessage + + +@pytest.fixture +def patch_logger_reporter(): + with patch("src.graph.nodes.logger") as mock_logger: + yield mock_logger + + +def make_mock_llm_response_reporter(content): + resp = MagicMock() + resp.content = content + return resp + + +def test_reporter_node_basic( + mock_state_reporter, + patch_config_from_runnable_config_reporter, + patch_apply_prompt_template_reporter, + patch_human_message, + patch_logger_reporter, +): + # Patch get_llm_by_type and AGENT_LLM_MAP + with ( + patch("src.graph.nodes.AGENT_LLM_MAP", {"reporter": "basic"}), + patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, + ): + mock_llm = MagicMock() + mock_llm.invoke.return_value = make_mock_llm_response_reporter( + "Final Report Content" + ) + mock_get_llm.return_value = mock_llm + + result = reporter_node(mock_state_reporter, MagicMock()) + assert isinstance(result, dict) + assert "final_report" in result + assert result["final_report"] == "Final Report Content" + # Should call apply_prompt_template with correct arguments + patch_apply_prompt_template_reporter.assert_called() + # Should call invoke on the LLM + mock_llm.invoke.assert_called() + + +def test_reporter_node_with_observations( + mock_state_reporter_with_observations, + patch_config_from_runnable_config_reporter, + patch_apply_prompt_template_reporter, + patch_human_message, + patch_logger_reporter, +): + with ( + patch("src.graph.nodes.AGENT_LLM_MAP", {"reporter": "basic"}), + patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, + ): + mock_llm = MagicMock() + mock_llm.invoke.return_value = make_mock_llm_response_reporter( + "Report with Observations" + ) + mock_get_llm.return_value = mock_llm + + result = reporter_node(mock_state_reporter_with_observations, MagicMock()) + assert isinstance(result, dict) + assert "final_report" in result + assert result["final_report"] == "Report with Observations" + # Should call apply_prompt_template with correct arguments + patch_apply_prompt_template_reporter.assert_called() + # Should call invoke on the LLM + mock_llm.invoke.assert_called() + + +def test_reporter_node_locale_default( + patch_config_from_runnable_config_reporter, + patch_apply_prompt_template_reporter, + patch_human_message, + patch_logger_reporter, +): + # If locale is missing, should default to "en-US" + Plan = namedtuple("Plan", ["title", "thought"]) + state = { + "current_plan": Plan(title="Test Title", thought="Test Thought"), + # "locale" omitted + "observations": [], + } + with ( + patch("src.graph.nodes.AGENT_LLM_MAP", {"reporter": "basic"}), + patch("src.graph.nodes.get_llm_by_type") as mock_get_llm, + ): + mock_llm = MagicMock() + mock_llm.invoke.return_value = make_mock_llm_response_reporter( + "Default Locale Report" + ) + mock_get_llm.return_value = mock_llm + + result = reporter_node(state, MagicMock()) + assert isinstance(result, dict) + assert "final_report" in result + assert result["final_report"] == "Default Locale Report" + + +# Create the real Step class for the tests +class Step: + def __init__(self, title, description, execution_res=None): + self.title = title + self.description = description + self.execution_res = execution_res + + +@pytest.fixture +def mock_step(): + return Step(title="Step 1", description="Desc 1", execution_res=None) + + +@pytest.fixture +def mock_completed_step(): + return Step(title="Step 0", description="Desc 0", execution_res="Done") + + +@pytest.fixture +def mock_state_with_steps(mock_step, mock_completed_step): + # Simulate a plan with one completed and one unexecuted step + Plan = MagicMock() + Plan.steps = [mock_completed_step, mock_step] + return { + "current_plan": Plan, + "observations": ["obs1"], + "locale": "en-US", + "resources": [], + } + + +@pytest.fixture +def mock_state_no_unexecuted(): + Step = namedtuple("Step", ["title", "description", "execution_res"]) + Plan = MagicMock() + Plan.steps = [ + Step(title="Step 1", description="Desc 1", execution_res="done"), + Step(title="Step 2", description="Desc 2", execution_res="done"), + ] + return { + "current_plan": Plan, + "observations": [], + "locale": "en-US", + "resources": [], + } + + +@pytest.fixture +def mock_agent(): + agent = MagicMock() + + async def ainvoke(input, config): + # Simulate agent returning a message list + return {"messages": [MagicMock(content="result content")]} + + agent.ainvoke = ainvoke + return agent + + +@pytest.mark.asyncio +async def test_execute_agent_step_basic(mock_state_with_steps, mock_agent): + # Should execute the first unexecuted step and update execution_res + with patch( + "src.graph.nodes.HumanMessage", + side_effect=lambda content, name=None: MagicMock(content=content, name=name), + ): + result = await _execute_agent_step( + mock_state_with_steps, mock_agent, "researcher" + ) + assert isinstance(result, Command) + assert result.goto == "research_team" + assert "messages" in result.update + assert "observations" in result.update + # The new observation should be appended + assert result.update["observations"][-1] == "result content" + # The step's execution_res should be updated + assert ( + mock_state_with_steps["current_plan"].steps[1].execution_res + == "result content" + ) + + +@pytest.mark.asyncio +async def test_execute_agent_step_no_unexecuted_step( + mock_state_no_unexecuted, mock_agent +): + # Should return Command with goto="research_team" and not fail + with patch("src.graph.nodes.logger") as mock_logger: + result = await _execute_agent_step( + mock_state_no_unexecuted, mock_agent, "researcher" + ) + assert isinstance(result, Command) + assert result.goto == "research_team" + mock_logger.warning.assert_called_with("No unexecuted step found") + + +@pytest.mark.asyncio +async def test_execute_agent_step_with_resources_and_researcher(mock_step): + # Should add resource info and citation reminder for researcher + Resource = namedtuple("Resource", ["title", "description"]) + resources = [Resource(title="file1.txt", description="desc1")] + Plan = MagicMock() + Plan.steps = [mock_step] + state = { + "current_plan": Plan, + "observations": [], + "locale": "en-US", + "resources": resources, + } + agent = MagicMock() + + async def ainvoke(input, config): + # Check that resource info and citation reminder are present + messages = input["messages"] + assert any("local_search_tool" in m.content for m in messages) + assert any("DO NOT include inline citations" in m.content for m in messages) + return {"messages": [MagicMock(content="resource result")]} + + agent.ainvoke = ainvoke + with patch( + "src.graph.nodes.HumanMessage", + side_effect=lambda content, name=None: MagicMock(content=content, name=name), + ): + result = await _execute_agent_step(state, agent, "researcher") + assert isinstance(result, Command) + assert result.goto == "research_team" + assert result.update["observations"][-1] == "resource result" + + +@pytest.mark.asyncio +async def test_execute_agent_step_recursion_limit_env( + monkeypatch, mock_state_with_steps, mock_agent +): + # Should respect AGENT_RECURSION_LIMIT env variable if set and valid + monkeypatch.setenv("AGENT_RECURSION_LIMIT", "42") + with ( + patch("src.graph.nodes.logger") as mock_logger, + patch( + "src.graph.nodes.HumanMessage", + side_effect=lambda content, name=None: MagicMock( + content=content, name=name + ), + ), + ): + result = await _execute_agent_step(mock_state_with_steps, mock_agent, "coder") + assert isinstance(result, Command) + mock_logger.info.assert_any_call("Recursion limit set to: 42") + + +@pytest.mark.asyncio +async def test_execute_agent_step_recursion_limit_env_invalid( + monkeypatch, mock_state_with_steps, mock_agent +): + # Should fallback to default if env variable is invalid + monkeypatch.setenv("AGENT_RECURSION_LIMIT", "notanint") + with ( + patch("src.graph.nodes.logger") as mock_logger, + patch( + "src.graph.nodes.HumanMessage", + side_effect=lambda content, name=None: MagicMock( + content=content, name=name + ), + ), + ): + result = await _execute_agent_step(mock_state_with_steps, mock_agent, "coder") + assert isinstance(result, Command) + mock_logger.warning.assert_any_call( + "Invalid AGENT_RECURSION_LIMIT value: 'notanint'. Using default value 25." + ) + + +@pytest.mark.asyncio +async def test_execute_agent_step_recursion_limit_env_negative( + monkeypatch, mock_state_with_steps, mock_agent +): + # Should fallback to default if env variable is negative or zero + monkeypatch.setenv("AGENT_RECURSION_LIMIT", "-5") + with ( + patch("src.graph.nodes.logger") as mock_logger, + patch( + "src.graph.nodes.HumanMessage", + side_effect=lambda content, name=None: MagicMock( + content=content, name=name + ), + ), + ): + result = await _execute_agent_step(mock_state_with_steps, mock_agent, "coder") + assert isinstance(result, Command) + mock_logger.warning.assert_any_call( + "AGENT_RECURSION_LIMIT value '-5' (parsed as -5) is not positive. Using default value 25." + ) + + +@pytest.fixture +def mock_state_with_steps(mock_step, mock_completed_step): + # Simulate a plan with one completed and one unexecuted step + Plan = MagicMock() + Plan.steps = [mock_completed_step, mock_step] + return { + "current_plan": Plan, + "observations": ["obs1"], + "locale": "en-US", + "resources": [], + } + + +@pytest.fixture +def mock_config(): + return MagicMock() + + +@pytest.fixture +def mock_configurable_with_mcp(): + mock = MagicMock() + mock.mcp_settings = { + "servers": { + "server1": { + "enabled_tools": ["toolA", "toolB"], + "add_to_agents": ["researcher"], + "transport": "http", + "command": "run", + "args": {}, + "url": "http://localhost", + "env": {}, + "other": "ignore", + } + } + } + return mock + + +@pytest.fixture +def mock_configurable_without_mcp(): + mock = MagicMock() + mock.mcp_settings = None + return mock + + +@pytest.fixture +def patch_config_from_runnable_config_with_mcp(mock_configurable_with_mcp): + with patch( + "src.graph.nodes.Configuration.from_runnable_config", + return_value=mock_configurable_with_mcp, + ): + yield + + +@pytest.fixture +def patch_config_from_runnable_config_without_mcp(mock_configurable_without_mcp): + with patch( + "src.graph.nodes.Configuration.from_runnable_config", + return_value=mock_configurable_without_mcp, + ): + yield + + +@pytest.fixture +def patch_create_agent(): + with patch("src.graph.nodes.create_agent") as mock: + yield mock + + +@pytest.fixture +def patch_execute_agent_step(): + async def fake_execute_agent_step(state, agent, agent_type): + return "EXECUTED" + + with patch( + "src.graph.nodes._execute_agent_step", side_effect=fake_execute_agent_step + ) as mock: + yield mock + + +@pytest.fixture +def patch_multiserver_mcp_client(): + # Patch MultiServerMCPClient as async context manager + class FakeTool: + def __init__(self, name, description="desc"): + self.name = name + self.description = description + + class FakeClient: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + pass + + def get_tools(self): + return [ + FakeTool("toolA", "descA"), + FakeTool("toolB", "descB"), + FakeTool("toolC", "descC"), + ] + + with patch( + "src.graph.nodes.MultiServerMCPClient", return_value=FakeClient() + ) as mock: + yield mock + + +@pytest.mark.asyncio +async def test_setup_and_execute_agent_step_with_mcp( + mock_state_with_steps, + mock_config, + patch_config_from_runnable_config_with_mcp, + patch_create_agent, + patch_execute_agent_step, + patch_multiserver_mcp_client, +): + # Should use MCP client, load tools, and call create_agent with correct tools + default_tools = [MagicMock(name="default_tool")] + agent_type = "researcher" + + result = await _setup_and_execute_agent_step( + mock_state_with_steps, + mock_config, + agent_type, + default_tools, + ) + # Should call create_agent with loaded_tools including toolA and toolB + args, kwargs = patch_create_agent.call_args + loaded_tools = args[2] + tool_names = [t.name for t in loaded_tools if hasattr(t, "name")] + assert "toolA" in tool_names + assert "toolB" in tool_names + # Should call _execute_agent_step + patch_execute_agent_step.assert_called_once() + assert result == "EXECUTED" + + +@pytest.mark.asyncio +async def test_setup_and_execute_agent_step_without_mcp( + mock_state_with_steps, + mock_config, + patch_config_from_runnable_config_without_mcp, + patch_create_agent, + patch_execute_agent_step, +): + # Should use default tools and not use MCP client + default_tools = [MagicMock(name="default_tool")] + agent_type = "coder" + + result = await _setup_and_execute_agent_step( + mock_state_with_steps, + mock_config, + agent_type, + default_tools, + ) + # Should call create_agent with default_tools + args, kwargs = patch_create_agent.call_args + assert args[2] == default_tools + patch_execute_agent_step.assert_called_once() + assert result == "EXECUTED" + + +@pytest.mark.asyncio +async def test_setup_and_execute_agent_step_with_mcp_no_enabled_tools( + mock_state_with_steps, + mock_config, + patch_create_agent, + patch_execute_agent_step, +): + # If mcp_settings present but no enabled_tools for agent_type, should fallback to default_tools + mcp_settings = { + "servers": { + "server1": { + "enabled_tools": ["toolA"], + "add_to_agents": ["other_agent"], + "transport": "http", + "command": "run", + "args": {}, + "url": "http://localhost", + "env": {}, + } + } + } + configurable = MagicMock() + configurable.mcp_settings = mcp_settings + with patch( + "src.graph.nodes.Configuration.from_runnable_config", + return_value=configurable, + ): + default_tools = [MagicMock(name="default_tool")] + agent_type = "researcher" + result = await _setup_and_execute_agent_step( + mock_state_with_steps, + mock_config, + agent_type, + default_tools, + ) + args, kwargs = patch_create_agent.call_args + assert args[2] == default_tools + patch_execute_agent_step.assert_called_once() + assert result == "EXECUTED" + + +@pytest.mark.asyncio +async def test_setup_and_execute_agent_step_with_mcp_tools_description_update( + mock_state_with_steps, + mock_config, + patch_config_from_runnable_config_with_mcp, + patch_create_agent, + patch_execute_agent_step, +): + # Should update tool.description with Powered by info + default_tools = [MagicMock(name="default_tool")] + agent_type = "researcher" + + # Patch MultiServerMCPClient to check description update + class FakeTool: + def __init__(self, name, description="desc"): + self.name = name + self.description = description + + class FakeClient: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + pass + + def get_tools(self): + return [FakeTool("toolA", "descA")] + + with patch("src.graph.nodes.MultiServerMCPClient", return_value=FakeClient()): + await _setup_and_execute_agent_step( + mock_state_with_steps, + mock_config, + agent_type, + default_tools, + ) + # The tool description should be updated + args, kwargs = patch_create_agent.call_args + loaded_tools = args[2] + found = False + for t in loaded_tools: + if hasattr(t, "name") and t.name == "toolA": + assert t.description.startswith("Powered by 'server1'.\n") + found = True + assert found + + +@pytest.fixture +def mock_state_with_resources(): + return {"resources": ["resource1", "resource2"], "other": "value"} + + +@pytest.fixture +def mock_state_without_resources(): + return {"other": "value"} + + +@pytest.fixture +def mock_config(): + return MagicMock() + + +@pytest.fixture +def mock_configurable(): + mock = MagicMock() + mock.max_search_results = 7 + return mock + + +@pytest.fixture +def patch_config_from_runnable_config(mock_configurable): + with patch( + "src.graph.nodes.Configuration.from_runnable_config", + return_value=mock_configurable, + ): + yield + + +@pytest.fixture +def patch_get_web_search_tool(): + with patch("src.graph.nodes.get_web_search_tool") as mock: + mock_tool = MagicMock(name="web_search_tool") + mock.return_value = mock_tool + yield mock + + +@pytest.fixture +def patch_crawl_tool(): + with patch("src.graph.nodes.crawl_tool", MagicMock(name="crawl_tool")): + yield + + +@pytest.fixture +def patch_get_retriever_tool(): + with patch("src.graph.nodes.get_retriever_tool") as mock: + yield mock + + +@pytest.fixture +def patch_setup_and_execute_agent_step(): + async def fake_setup_and_execute_agent_step(state, config, agent_type, tools): + return "RESEARCHER_RESULT" + + with patch( + "src.graph.nodes._setup_and_execute_agent_step", + side_effect=fake_setup_and_execute_agent_step, + ) as mock: + yield mock + + +@pytest.mark.asyncio +async def test_researcher_node_with_retriever_tool( + mock_state_with_resources, + mock_config, + patch_config_from_runnable_config, + patch_get_web_search_tool, + patch_crawl_tool, + patch_get_retriever_tool, + patch_setup_and_execute_agent_step, +): + # Simulate retriever_tool is returned + retriever_tool = MagicMock(name="retriever_tool") + patch_get_retriever_tool.return_value = retriever_tool + + result = await researcher_node(mock_state_with_resources, mock_config) + + # Should call get_web_search_tool with correct max_search_results + patch_get_web_search_tool.assert_called_once_with(7) + # Should call get_retriever_tool with resources + patch_get_retriever_tool.assert_called_once_with(["resource1", "resource2"]) + # Should call _setup_and_execute_agent_step with retriever_tool first + args, kwargs = patch_setup_and_execute_agent_step.call_args + tools = args[3] + assert tools[0] == retriever_tool + assert patch_get_web_search_tool.return_value in tools + assert result == "RESEARCHER_RESULT" + + +@pytest.mark.asyncio +async def test_researcher_node_without_retriever_tool( + mock_state_with_resources, + mock_config, + patch_config_from_runnable_config, + patch_get_web_search_tool, + patch_crawl_tool, + patch_get_retriever_tool, + patch_setup_and_execute_agent_step, +): + # Simulate retriever_tool is None + patch_get_retriever_tool.return_value = None + + result = await researcher_node(mock_state_with_resources, mock_config) + + patch_get_web_search_tool.assert_called_once_with(7) + patch_get_retriever_tool.assert_called_once_with(["resource1", "resource2"]) + args, kwargs = patch_setup_and_execute_agent_step.call_args + tools = args[3] + # Should not include retriever_tool + assert all(getattr(t, "name", None) != "retriever_tool" for t in tools) + assert patch_get_web_search_tool.return_value in tools + assert result == "RESEARCHER_RESULT" + + +@pytest.mark.asyncio +async def test_researcher_node_without_resources( + mock_state_without_resources, + mock_config, + patch_config_from_runnable_config, + patch_get_web_search_tool, + patch_crawl_tool, + patch_get_retriever_tool, + patch_setup_and_execute_agent_step, +): + patch_get_retriever_tool.return_value = None + + result = await researcher_node(mock_state_without_resources, mock_config) + + patch_get_web_search_tool.assert_called_once_with(7) + patch_get_retriever_tool.assert_called_once_with([]) + args, kwargs = patch_setup_and_execute_agent_step.call_args + tools = args[3] + assert patch_get_web_search_tool.return_value in tools + assert result == "RESEARCHER_RESULT" diff --git a/tests/unit/graph/test_builder.py b/tests/unit/graph/test_builder.py new file mode 100644 index 0000000..fbbdaa7 --- /dev/null +++ b/tests/unit/graph/test_builder.py @@ -0,0 +1,121 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +import pytest +from unittest.mock import MagicMock, patch +import importlib +import sys + +import src.graph.builder as builder_mod + + +@pytest.fixture +def mock_state(): + class Step: + def __init__(self, execution_res=None, step_type=None): + self.execution_res = execution_res + self.step_type = step_type + + class Plan: + def __init__(self, steps): + self.steps = steps + + return { + "Step": Step, + "Plan": Plan, + } + + +def test_continue_to_running_research_team_no_plan(mock_state): + state = {"current_plan": None} + assert builder_mod.continue_to_running_research_team(state) == "planner" + + +def test_continue_to_running_research_team_no_steps(mock_state): + state = {"current_plan": mock_state["Plan"](steps=[])} + assert builder_mod.continue_to_running_research_team(state) == "planner" + + +def test_continue_to_running_research_team_all_executed(mock_state): + Step = mock_state["Step"] + Plan = mock_state["Plan"] + steps = [Step(execution_res=True), Step(execution_res=True)] + state = {"current_plan": Plan(steps=steps)} + assert builder_mod.continue_to_running_research_team(state) == "planner" + + +def test_continue_to_running_research_team_next_researcher(mock_state): + Step = mock_state["Step"] + Plan = mock_state["Plan"] + steps = [ + Step(execution_res=True), + Step(execution_res=None, step_type=builder_mod.StepType.RESEARCH), + ] + state = {"current_plan": Plan(steps=steps)} + assert builder_mod.continue_to_running_research_team(state) == "researcher" + + +def test_continue_to_running_research_team_next_coder(mock_state): + Step = mock_state["Step"] + Plan = mock_state["Plan"] + steps = [ + Step(execution_res=True), + Step(execution_res=None, step_type=builder_mod.StepType.PROCESSING), + ] + state = {"current_plan": Plan(steps=steps)} + assert builder_mod.continue_to_running_research_team(state) == "coder" + + +def test_continue_to_running_research_team_default_planner(mock_state): + Step = mock_state["Step"] + Plan = mock_state["Plan"] + steps = [Step(execution_res=True), Step(execution_res=None, step_type=None)] + state = {"current_plan": Plan(steps=steps)} + assert builder_mod.continue_to_running_research_team(state) == "planner" + + +@patch("src.graph.builder.StateGraph") +def test_build_base_graph_adds_nodes_and_edges(MockStateGraph): + mock_builder = MagicMock() + MockStateGraph.return_value = mock_builder + + builder_mod._build_base_graph() + + # Check that all nodes and edges are added + assert mock_builder.add_edge.call_count >= 2 + assert mock_builder.add_node.call_count >= 8 + mock_builder.add_conditional_edges.assert_called_once() + + +@patch("src.graph.builder._build_base_graph") +@patch("src.graph.builder.MemorySaver") +def test_build_graph_with_memory_uses_memory(MockMemorySaver, mock_build_base_graph): + mock_builder = MagicMock() + mock_build_base_graph.return_value = mock_builder + mock_memory = MagicMock() + MockMemorySaver.return_value = mock_memory + + builder_mod.build_graph_with_memory() + + mock_builder.compile.assert_called_once_with(checkpointer=mock_memory) + + +@patch("src.graph.builder._build_base_graph") +def test_build_graph_without_memory(mock_build_base_graph): + mock_builder = MagicMock() + mock_build_base_graph.return_value = mock_builder + + builder_mod.build_graph() + + mock_builder.compile.assert_called_once_with() + + +def test_graph_is_compiled(): + # The graph object should be the result of build_graph() + with patch("src.graph.builder._build_base_graph") as mock_base: + mock_builder = MagicMock() + mock_base.return_value = mock_builder + mock_builder.compile.return_value = "compiled_graph" + # reload the module to re-run the graph assignment + importlib.reload(sys.modules["src.graph.builder"]) + assert builder_mod.graph is not None diff --git a/uv.lock b/uv.lock index 9809d48..de614af 100644 --- a/uv.lock +++ b/uv.lock @@ -403,6 +403,7 @@ dev = [ ] test = [ { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "pytest-cov" }, ] @@ -429,6 +430,7 @@ requires-dist = [ { name = "numpy", specifier = ">=2.2.3" }, { name = "pandas", specifier = ">=2.2.3" }, { name = "pytest", marker = "extra == 'test'", specifier = ">=7.4.0" }, + { name = "pytest-asyncio", marker = "extra == 'test'", specifier = ">=1.0.0" }, { name = "pytest-cov", marker = "extra == 'test'", specifier = ">=4.1.0" }, { name = "python-dotenv", specifier = ">=1.0.1" }, { name = "readabilipy", specifier = ">=0.3.0" }, @@ -1692,6 +1694,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/30/3d/64ad57c803f1fa1e963a7946b6e0fea4a70df53c1a7fed304586539c2bac/pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820", size = 343634, upload-time = "2025-03-02T12:54:52.069Z" }, ] +[[package]] +name = "pytest-asyncio" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d0/d4/14f53324cb1a6381bef29d698987625d80052bb33932d8e7cbf9b337b17c/pytest_asyncio-1.0.0.tar.gz", hash = "sha256:d15463d13f4456e1ead2594520216b225a16f781e144f8fdf6c5bb4667c48b3f", size = 46960 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/30/05/ce271016e351fddc8399e546f6e23761967ee09c8c568bbfbecb0c150171/pytest_asyncio-1.0.0-py3-none-any.whl", hash = "sha256:4f024da9f1ef945e680dc68610b52550e36590a67fd31bb3b4943979a1f90ef3", size = 15976 }, +] + [[package]] name = "pytest-cov" version = "6.0.0"