2025-05-20 14:25:35 +08:00
|
|
|
|
import json
|
|
|
|
|
|
import pytest
|
|
|
|
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
|
|
|
|
|
|
|
|
# Mock out get_llm_by_type while importing the modules below, to avoid a ValueError
|
|
|
|
|
|
with patch("src.llms.llm.get_llm_by_type", return_value=MagicMock()):
|
|
|
|
|
|
from langgraph.types import Command
|
|
|
|
|
|
from src.graph.nodes import background_investigation_node
|
|
|
|
|
|
from src.config import SearchEngine
|
|
|
|
|
|
from langchain_core.messages import HumanMessage
|
|
|
|
|
|
|
|
|
|
|
|
# Canned search hits shared by the search-tool fixtures and assertions:
# two entries, "Test Title N" / "Test Content N" for N = 1, 2.
MOCK_SEARCH_RESULTS = [
    {"title": f"Test Title {n}", "content": f"Test Content {n}"} for n in (1, 2)
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_state():
    """Build a minimal graph state for background_investigation_node.

    The state holds one human message and a research topic with the same
    text, and no background investigation results yet.
    """
    query = "test query"
    return dict(
        messages=[HumanMessage(content=query)],
        research_topic=query,
        background_investigation_results=None,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_configurable():
    """Return a stand-in Configuration object whose max_search_results is 5."""
    return MagicMock(max_search_results=5)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_config():
    """Provide a placeholder config argument for the node under test.

    A bare ``MagicMock`` is returned. Swap in a ``MagicMock`` or a plain
    dict carrying real values if a test needs specific configuration.
    """
    placeholder = MagicMock()
    return placeholder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def patch_config_from_runnable_config(mock_configurable):
    """Make ``Configuration.from_runnable_config`` return ``mock_configurable``.

    The patch is active while the requesting test runs and is removed
    afterwards, even if the test fails. The fixture yields no value.
    """
    patcher = patch(
        "src.graph.nodes.Configuration.from_runnable_config",
        return_value=mock_configurable,
    )
    patcher.start()
    try:
        yield
    finally:
        patcher.stop()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_tavily_search():
    """Patch ``LoggedTavilySearch`` in ``src.graph.nodes``.

    The patched class's ``invoke`` returns the hits in ``MOCK_SEARCH_RESULTS``.
    This keeps the canned data defined in one place instead of duplicating the
    literal here.

    Yields:
        The ``MagicMock`` standing in for the ``LoggedTavilySearch`` class.
        Configure ``.return_value.invoke`` to change the search response.
    """
    with patch("src.graph.nodes.LoggedTavilySearch") as mock:
        # Return fresh copies so a test mutating the response cannot alter
        # the shared module-level constant.
        mock.return_value.invoke.return_value = [
            dict(hit) for hit in MOCK_SEARCH_RESULTS
        ]
        yield mock
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_web_search_tool():
    """Patch ``get_web_search_tool`` in ``src.graph.nodes``.

    The tool returned by the patched factory has an ``invoke`` that returns
    the hits in ``MOCK_SEARCH_RESULTS``. This keeps the canned data defined
    in one place instead of duplicating the literal here.

    Yields:
        The ``MagicMock`` standing in for ``get_web_search_tool``.
        Configure ``.return_value.invoke`` to change the search response.
    """
    with patch("src.graph.nodes.get_web_search_tool") as mock:
        # Return fresh copies so a test mutating the response cannot alter
        # the shared module-level constant.
        mock.return_value.invoke.return_value = [
            dict(hit) for hit in MOCK_SEARCH_RESULTS
        ]
        yield mock
|
|
|
|
|
|
|
|
|
|
|
|
|
2025-05-27 23:05:34 -07:00
|
|
|
|
@pytest.mark.parametrize("search_engine", [SearchEngine.TAVILY.value, "other"])
def test_background_investigation_node_tavily(
    mock_state,
    mock_tavily_search,
    mock_web_search_tool,
    search_engine,
    patch_config_from_runnable_config,
    mock_config,
):
    """Test background_investigation_node for each search-engine branch.

    With Tavily selected, the node must query ``LoggedTavilySearch`` and
    return the hits rendered as Markdown sections. With any other engine,
    it must query the tool from ``get_web_search_tool`` and return the hits
    JSON-encoded. In both cases the other backend must stay unused.
    """
    with patch("src.graph.nodes.SELECTED_SEARCH_ENGINE", search_engine):
        result = background_investigation_node(mock_state, mock_config)

    # The node returns a dict update carrying the investigation results.
    assert isinstance(result, dict)
    assert "background_investigation_results" in result
    results = result["background_investigation_results"]

    tavily_invoke = mock_tavily_search.return_value.invoke
    web_invoke = mock_web_search_tool.return_value.invoke

    if search_engine == SearchEngine.TAVILY.value:
        tavily_invoke.assert_called_once_with("test query")
        web_invoke.assert_not_called()
        # Tavily hits are rendered as "## title\n\ncontent" sections.
        assert (
            results
            == "## Test Title 1\n\nTest Content 1\n\n## Test Title 2\n\nTest Content 2"
        )
    else:
        web_invoke.assert_called_once_with("test query")
        tavily_invoke.assert_not_called()
        # Other engines' hits come back JSON-encoded; check the full content,
        # not just the count.
        assert json.loads(results) == MOCK_SEARCH_RESULTS
|
2025-05-20 14:25:35 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_background_investigation_node_malformed_response(
    mock_state, mock_tavily_search, patch_config_from_runnable_config, mock_config
):
    """A non-list Tavily reply must yield a JSON ``null`` result.

    The Tavily mock returns a bare string instead of a list of hits. The
    node must still return a dict whose ``background_investigation_results``
    decodes to ``None``.
    """
    tavily_client = mock_tavily_search.return_value
    with patch("src.graph.nodes.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value):
        tavily_client.invoke.return_value = "invalid response"
        update = background_investigation_node(mock_state, mock_config)

    assert isinstance(update, dict)
    assert "background_investigation_results" in update

    # A malformed reply yields no results, serialized as JSON null.
    assert json.loads(update["background_investigation_results"]) is None
|