mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
129 lines
4.1 KiB
Python
129 lines
4.1 KiB
Python
import json
|
||
import pytest
|
||
from unittest.mock import patch, MagicMock
|
||
|
||
# Mock out get_llm_by_type here to avoid a ValueError.
# The project modules below are imported while src.llms.llm.get_llm_by_type is
# patched, so any call to it made during these imports returns a MagicMock
# instead of a real LLM. The patch is undone when the `with` block exits.
# NOTE(review): presumably the ValueError comes from missing LLM configuration
# in the test environment — confirm.
with patch("src.llms.llm.get_llm_by_type", return_value=MagicMock()):
    from langgraph.types import Command
    from src.graph.nodes import background_investigation_node
    from src.config import SearchEngine
    from langchain_core.messages import HumanMessage
|
||
|
||
# Mock data: two canned search hits, each a dict with a numbered
# "title" and "content" string.
MOCK_SEARCH_RESULTS = [
    {"title": f"Test Title {idx}", "content": f"Test Content {idx}"}
    for idx in (1, 2)
]
|
||
|
||
|
||
@pytest.fixture
def mock_state():
    """Minimal graph state: a single user message and no prior investigation results."""
    user_message = HumanMessage(content="test query")
    state = dict(messages=[user_message])
    state["background_investigation_results"] = None
    return state
|
||
|
||
|
||
@pytest.fixture
def mock_configurable():
    """Stand-in configuration object whose ``max_search_results`` is 5."""
    return MagicMock(max_search_results=5)
|
||
|
||
|
||
@pytest.fixture
def mock_config():
    """Placeholder config argument passed to background_investigation_node.

    A bare MagicMock is enough for these tests; return a MagicMock or a dict
    with real settings instead if a test needs them.
    """
    runnable_config = MagicMock()
    return runnable_config
|
||
|
||
|
||
@pytest.fixture
def patch_config_from_runnable_config(mock_configurable):
    """Make ``Configuration.from_runnable_config`` in src.graph.nodes return ``mock_configurable``.

    The patch is active for the whole duration of the requesting test.
    """
    target = "src.graph.nodes.Configuration.from_runnable_config"
    with patch(target, return_value=mock_configurable):
        yield
|
||
|
||
|
||
@pytest.fixture
def mock_tavily_search():
    """Patch ``LoggedTavilySearch`` in src.graph.nodes with a MagicMock class.

    The instance returned by calling the patched class answers ``invoke``
    with a copy of MOCK_SEARCH_RESULTS. Yields the class mock, so tests reach
    that instance through ``mock_tavily_search.return_value``.
    """
    with patch("src.graph.nodes.LoggedTavilySearch") as tavily_cls:
        # Reuse the module-level test data instead of re-typing it, copying
        # each entry so nothing can mutate the shared constant.
        tavily_cls.return_value.invoke.return_value = [
            dict(item) for item in MOCK_SEARCH_RESULTS
        ]
        yield tavily_cls
|
||
|
||
|
||
@pytest.fixture
def mock_web_search_tool():
    """Patch ``get_web_search_tool`` in src.graph.nodes with a MagicMock.

    The tool returned by the patched factory answers ``invoke`` with a copy
    of MOCK_SEARCH_RESULTS. Yields the factory mock, so tests reach the tool
    through ``mock_web_search_tool.return_value``.
    """
    with patch("src.graph.nodes.get_web_search_tool") as tool_factory:
        # Reuse the module-level test data instead of re-typing it, copying
        # each entry so nothing can mutate the shared constant.
        tool_factory.return_value.invoke.return_value = [
            dict(item) for item in MOCK_SEARCH_RESULTS
        ]
        yield tool_factory
|
||
|
||
|
||
@pytest.mark.parametrize("search_engine", [SearchEngine.TAVILY.value, "other"])
def test_background_investigation_node_tavily(
    mock_state,
    mock_tavily_search,
    mock_web_search_tool,
    search_engine,
    patch_config_from_runnable_config,
    mock_config,
):
    """Test background_investigation_node for both search-engine branches.

    Parametrized over the Tavily engine and a non-Tavily value ("other"):
    - With Tavily, the node is expected to query LoggedTavilySearch.
    - With any other engine, it is expected to query the tool returned by
      get_web_search_tool.
    In both cases the node must route to "planner" and store a JSON-encoded
    list under "background_investigation_results".
    """
    with patch("src.graph.nodes.SELECTED_SEARCH_ENGINE", search_engine):
        result = background_investigation_node(mock_state, mock_config)

    # Verify the result structure
    assert isinstance(result, Command)
    assert result.goto == "planner"

    # Verify the update contains background_investigation_results
    update = result.update
    assert "background_investigation_results" in update

    # The results are stored as a JSON string; decode before inspecting.
    results = json.loads(update["background_investigation_results"])
    assert isinstance(results, list)

    if search_engine == SearchEngine.TAVILY.value:
        mock_tavily_search.return_value.invoke.assert_called_once_with("test query")
        assert len(results) == 2
        assert results[0]["title"] == "Test Title 1"
        assert results[0]["content"] == "Test Content 1"
    else:
        mock_web_search_tool.return_value.invoke.assert_called_once_with(
            "test query"
        )
        assert len(results) == 2
|
||
|
||
|
||
def test_background_investigation_node_malformed_response(
    mock_state, mock_tavily_search, patch_config_from_runnable_config, mock_config
):
    """A non-list Tavily response must be stored as JSON ``null``.

    Even with the malformed response, the node still returns a Command that
    routes to "planner".
    """
    tavily_instance = mock_tavily_search.return_value
    # Simulate Tavily handing back a bare string instead of a list of hits.
    tavily_instance.invoke.return_value = "invalid response"

    engine_patch = patch(
        "src.graph.nodes.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value
    )
    with engine_patch:
        command = background_investigation_node(mock_state, mock_config)

    assert isinstance(command, Command)
    assert command.goto == "planner"

    payload = command.update
    assert "background_investigation_results" in payload

    # json.loads("null") -> None: the node found nothing usable to store.
    assert json.loads(payload["background_investigation_results"]) is None
|