Mirror of https://gitee.com/wanwujie/deer-flow (synced 2026-04-15 11:04:44 +08:00)
test(backend): add core logic unit tests for task/title/mcp (#936)
* test(backend): add core logic unit tests for task/title/mcp
* test(backend): fix lint issues in client test modules

Co-authored-by: songyaolun <songyaolun@bytedance.com>
@@ -10,8 +10,8 @@ import pytest
 from langchain_core.messages import AIMessage, HumanMessage, ToolMessage  # noqa: F401

 from src.client import DeerFlowClient
-from src.gateway.routers.memory import MemoryConfigResponse, MemoryStatusResponse
 from src.gateway.routers.mcp import McpConfigResponse
+from src.gateway.routers.memory import MemoryConfigResponse, MemoryStatusResponse
 from src.gateway.routers.models import ModelResponse, ModelsListResponse
 from src.gateway.routers.skills import SkillInstallResponse, SkillResponse, SkillsListResponse
 from src.gateway.routers.uploads import UploadResponse
@@ -1011,8 +1011,6 @@ class TestScenarioAgentRecreation:

     def test_different_model_triggers_rebuild(self, client):
         """Switching model_name between calls forces agent rebuild."""
         mock_agent_1 = MagicMock(name="agent-v1")
         mock_agent_2 = MagicMock(name="agent-v2")
         agents_created = []

         def fake_create_agent(**kwargs):
@@ -12,6 +12,8 @@ from pathlib import Path

 import pytest

+from src.client import DeerFlowClient, StreamEvent
+
 # Skip entire module in CI or when no config.yaml exists
 _skip_reason = None
 if os.environ.get("CI"):
@@ -22,9 +24,6 @@ elif not Path(__file__).resolve().parents[2].joinpath("config.yaml").exists():
 if _skip_reason:
     pytest.skip(_skip_reason, allow_module_level=True)

-from src.client import DeerFlowClient, StreamEvent
-
-
 # ---------------------------------------------------------------------------
 # Fixtures
 # ---------------------------------------------------------------------------
backend/tests/test_mcp_client_config.py (new file, +93 lines)
@@ -0,0 +1,93 @@
"""Core behavior tests for MCP client server config building."""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.config.extensions_config import ExtensionsConfig, McpServerConfig
|
||||
from src.mcp.client import build_server_params, build_servers_config
|
||||
|
||||
|
||||
def test_build_server_params_stdio_success():
|
||||
config = McpServerConfig(
|
||||
type="stdio",
|
||||
command="npx",
|
||||
args=["-y", "my-mcp-server"],
|
||||
env={"API_KEY": "secret"},
|
||||
)
|
||||
|
||||
params = build_server_params("my-server", config)
|
||||
|
||||
assert params == {
|
||||
"transport": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["-y", "my-mcp-server"],
|
||||
"env": {"API_KEY": "secret"},
|
||||
}
|
||||
|
||||
|
||||
def test_build_server_params_stdio_requires_command():
|
||||
config = McpServerConfig(type="stdio", command=None)
|
||||
|
||||
with pytest.raises(ValueError, match="requires 'command' field"):
|
||||
build_server_params("broken-stdio", config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("transport", ["sse", "http"])
|
||||
def test_build_server_params_http_like_success(transport: str):
|
||||
config = McpServerConfig(
|
||||
type=transport,
|
||||
url="https://example.com/mcp",
|
||||
headers={"Authorization": "Bearer token"},
|
||||
)
|
||||
|
||||
params = build_server_params("remote-server", config)
|
||||
|
||||
assert params == {
|
||||
"transport": transport,
|
||||
"url": "https://example.com/mcp",
|
||||
"headers": {"Authorization": "Bearer token"},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("transport", ["sse", "http"])
|
||||
def test_build_server_params_http_like_requires_url(transport: str):
|
||||
config = McpServerConfig(type=transport, url=None)
|
||||
|
||||
with pytest.raises(ValueError, match="requires 'url' field"):
|
||||
build_server_params("broken-remote", config)
|
||||
|
||||
|
||||
def test_build_server_params_rejects_unsupported_transport():
|
||||
config = McpServerConfig(type="websocket")
|
||||
|
||||
with pytest.raises(ValueError, match="unsupported transport type"):
|
||||
build_server_params("bad-transport", config)
|
||||
|
||||
|
||||
def test_build_servers_config_returns_empty_when_no_enabled_servers():
|
||||
extensions = ExtensionsConfig(
|
||||
mcp_servers={
|
||||
"disabled-a": McpServerConfig(enabled=False, type="stdio", command="echo"),
|
||||
"disabled-b": McpServerConfig(enabled=False, type="http", url="https://example.com"),
|
||||
},
|
||||
skills={},
|
||||
)
|
||||
|
||||
assert build_servers_config(extensions) == {}
|
||||
|
||||
|
||||
def test_build_servers_config_skips_invalid_server_and_keeps_valid_ones():
|
||||
extensions = ExtensionsConfig(
|
||||
mcp_servers={
|
||||
"valid-stdio": McpServerConfig(enabled=True, type="stdio", command="npx", args=["server"]),
|
||||
"invalid-stdio": McpServerConfig(enabled=True, type="stdio", command=None),
|
||||
"disabled-http": McpServerConfig(enabled=False, type="http", url="https://disabled.example.com"),
|
||||
},
|
||||
skills={},
|
||||
)
|
||||
|
||||
result = build_servers_config(extensions)
|
||||
|
||||
assert "valid-stdio" in result
|
||||
assert result["valid-stdio"]["transport"] == "stdio"
|
||||
assert "invalid-stdio" not in result
|
||||
assert "disabled-http" not in result
|
||||
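The module above pins down the whole contract: stdio servers must supply a command, sse/http servers must supply a url, any other transport is rejected, and build_servers_config drops disabled servers and skips invalid ones rather than failing outright. As a reading aid, here is a minimal sketch that would satisfy these assertions; the function names mirror src.mcp.client, but the bodies (including the exact error wording) are inferred from the tests, not copied from the source:

# Hypothetical sketch inferred from the assertions above; the real functions live in
# src.mcp.client and may differ in wording and detail.
def build_server_params(name: str, config) -> dict:
    """Translate one McpServerConfig into transport params, or raise ValueError."""
    if config.type == "stdio":
        if not config.command:
            raise ValueError(f"MCP server '{name}' requires 'command' field")
        return {
            "transport": "stdio",
            "command": config.command,
            "args": config.args,
            "env": config.env,
        }
    if config.type in ("sse", "http"):
        if not config.url:
            raise ValueError(f"MCP server '{name}' requires 'url' field")
        return {
            "transport": config.type,
            "url": config.url,
            "headers": config.headers,
        }
    raise ValueError(f"MCP server '{name}' has unsupported transport type: {config.type}")


def build_servers_config(extensions) -> dict:
    """Collect params for enabled servers, skipping any that fail validation."""
    servers = {}
    for name, config in extensions.mcp_servers.items():
        if not config.enabled:
            continue  # disabled servers never make it into the config
        try:
            servers[name] = build_server_params(name, config)
        except ValueError:
            continue  # one bad server must not block the valid ones
    return servers

The skip-and-continue loop is what the last test exercises: a misconfigured server degrades gracefully instead of taking down every MCP integration at startup.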
backend/tests/test_task_tool_core_logic.py (new file, +241 lines)
@@ -0,0 +1,241 @@
"""Core behavior tests for task tool orchestration."""
|
||||
|
||||
import importlib
|
||||
from enum import Enum
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from src.subagents.config import SubagentConfig
|
||||
|
||||
# Use module import so tests can patch the exact symbols referenced inside task_tool().
|
||||
task_tool_module = importlib.import_module("src.tools.builtins.task_tool")
|
||||
|
||||
|
||||
class FakeSubagentStatus(Enum):
|
||||
# Match production enum values so branch comparisons behave identically.
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
TIMED_OUT = "timed_out"
|
||||
|
||||
|
||||
def _make_runtime() -> SimpleNamespace:
|
||||
# Minimal ToolRuntime-like object; task_tool only reads these three attributes.
|
||||
return SimpleNamespace(
|
||||
state={
|
||||
"sandbox": {"sandbox_id": "local"},
|
||||
"thread_data": {
|
||||
"workspace_path": "/tmp/workspace",
|
||||
"uploads_path": "/tmp/uploads",
|
||||
"outputs_path": "/tmp/outputs",
|
||||
},
|
||||
},
|
||||
context={"thread_id": "thread-1"},
|
||||
config={"metadata": {"model_name": "ark-model", "trace_id": "trace-1"}},
|
||||
)
|
||||
|
||||
|
||||
def _make_subagent_config() -> SubagentConfig:
|
||||
return SubagentConfig(
|
||||
name="general-purpose",
|
||||
description="General helper",
|
||||
system_prompt="Base system prompt",
|
||||
max_turns=50,
|
||||
timeout_seconds=10,
|
||||
)
|
||||
|
||||
|
||||
def _make_result(
|
||||
status: FakeSubagentStatus,
|
||||
*,
|
||||
ai_messages: list[dict] | None = None,
|
||||
result: str | None = None,
|
||||
error: str | None = None,
|
||||
) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
status=status,
|
||||
ai_messages=ai_messages or [],
|
||||
result=result,
|
||||
error=error,
|
||||
)
|
||||
|
||||
|
||||
def test_task_tool_returns_error_for_unknown_subagent(monkeypatch):
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: None)
|
||||
|
||||
result = task_tool_module.task_tool.func(
|
||||
runtime=None,
|
||||
description="执行任务",
|
||||
prompt="do work",
|
||||
subagent_type="general-purpose",
|
||||
tool_call_id="tc-1",
|
||||
)
|
||||
|
||||
assert result.startswith("Error: Unknown subagent type")
|
||||
|
||||
|
||||
def test_task_tool_emits_running_and_completed_events(monkeypatch):
|
||||
config = _make_subagent_config()
|
||||
runtime = _make_runtime()
|
||||
events = []
|
||||
captured = {}
|
||||
get_available_tools = MagicMock(return_value=["tool-a", "tool-b"])
|
||||
|
||||
class DummyExecutor:
|
||||
def __init__(self, **kwargs):
|
||||
captured["executor_kwargs"] = kwargs
|
||||
|
||||
def execute_async(self, prompt, task_id=None):
|
||||
captured["prompt"] = prompt
|
||||
captured["task_id"] = task_id
|
||||
return task_id or "generated-task-id"
|
||||
|
||||
# Simulate two polling rounds: first running (with one message), then completed.
|
||||
responses = iter(
|
||||
[
|
||||
_make_result(FakeSubagentStatus.RUNNING, ai_messages=[{"id": "m1", "content": "phase-1"}]),
|
||||
_make_result(
|
||||
FakeSubagentStatus.COMPLETED,
|
||||
ai_messages=[{"id": "m1", "content": "phase-1"}, {"id": "m2", "content": "phase-2"}],
|
||||
result="all done",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "Skills Appendix")
|
||||
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: next(responses))
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None)
|
||||
# task_tool lazily imports from src.tools at call time, so patch that module-level function.
|
||||
monkeypatch.setattr("src.tools.get_available_tools", get_available_tools)
|
||||
|
||||
output = task_tool_module.task_tool.func(
|
||||
runtime=runtime,
|
||||
description="运行子任务",
|
||||
prompt="collect diagnostics",
|
||||
subagent_type="general-purpose",
|
||||
tool_call_id="tc-123",
|
||||
max_turns=7,
|
||||
)
|
||||
|
||||
assert output == "Task Succeeded. Result: all done"
|
||||
assert captured["prompt"] == "collect diagnostics"
|
||||
assert captured["task_id"] == "tc-123"
|
||||
assert captured["executor_kwargs"]["thread_id"] == "thread-1"
|
||||
assert captured["executor_kwargs"]["parent_model"] == "ark-model"
|
||||
assert captured["executor_kwargs"]["config"].max_turns == 7
|
||||
assert "Skills Appendix" in captured["executor_kwargs"]["config"].system_prompt
|
||||
|
||||
get_available_tools.assert_called_once_with(model_name="ark-model", subagent_enabled=False)
|
||||
|
||||
event_types = [e["type"] for e in events]
|
||||
assert event_types == ["task_started", "task_running", "task_running", "task_completed"]
|
||||
assert events[-1]["result"] == "all done"
|
||||
|
||||
|
||||
def test_task_tool_returns_failed_message(monkeypatch):
|
||||
config = _make_subagent_config()
|
||||
events = []
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.FAILED, error="subagent crashed"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None)
|
||||
monkeypatch.setattr("src.tools.get_available_tools", lambda **kwargs: [])
|
||||
|
||||
output = task_tool_module.task_tool.func(
|
||||
runtime=_make_runtime(),
|
||||
description="执行任务",
|
||||
prompt="do fail",
|
||||
subagent_type="general-purpose",
|
||||
tool_call_id="tc-fail",
|
||||
)
|
||||
|
||||
assert output == "Task failed. Error: subagent crashed"
|
||||
assert events[-1]["type"] == "task_failed"
|
||||
assert events[-1]["error"] == "subagent crashed"
|
||||
|
||||
|
||||
def test_task_tool_returns_timed_out_message(monkeypatch):
|
||||
config = _make_subagent_config()
|
||||
events = []
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None)
|
||||
monkeypatch.setattr("src.tools.get_available_tools", lambda **kwargs: [])
|
||||
|
||||
output = task_tool_module.task_tool.func(
|
||||
runtime=_make_runtime(),
|
||||
description="执行任务",
|
||||
prompt="do timeout",
|
||||
subagent_type="general-purpose",
|
||||
tool_call_id="tc-timeout",
|
||||
)
|
||||
|
||||
assert output == "Task timed out. Error: timeout"
|
||||
assert events[-1]["type"] == "task_timed_out"
|
||||
assert events[-1]["error"] == "timeout"
|
||||
|
||||
|
||||
def test_task_tool_polling_safety_timeout(monkeypatch):
|
||||
config = _make_subagent_config()
|
||||
# Keep max_poll_count small for test speed: (1 + 60) // 5 = 12
|
||||
config.timeout_seconds = 1
|
||||
events = []
|
||||
|
||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"SubagentExecutor",
|
||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_background_task_result",
|
||||
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
||||
)
|
||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None)
|
||||
monkeypatch.setattr("src.tools.get_available_tools", lambda **kwargs: [])
|
||||
|
||||
output = task_tool_module.task_tool.func(
|
||||
runtime=_make_runtime(),
|
||||
description="执行任务",
|
||||
prompt="never finish",
|
||||
subagent_type="general-purpose",
|
||||
tool_call_id="tc-safety-timeout",
|
||||
)
|
||||
|
||||
assert output.startswith("Task polling timed out after 0 minutes")
|
||||
assert events[0]["type"] == "task_started"
|
||||
assert events[-1]["type"] == "task_timed_out"
|
||||
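Taken together, these four tests describe a background-execution contract: execute_async kicks off the subagent, the tool then polls get_background_task_result, streams task_started / task_running / task_completed (or task_failed / task_timed_out) events, and gives up after a bounded number of polls even when the subagent never reaches a terminal state. A simplified, self-contained sketch of such a loop follows; the 5-second interval, the (timeout_seconds + 60) // 5 poll budget, and the event payloads are assumptions reverse-engineered from the assertions, not the actual src.tools.builtins.task_tool implementation:

# Hypothetical sketch of task_tool's polling loop; every detail below is inferred
# from the tests above rather than copied from the production module.
import time
from enum import Enum
from typing import Callable


class SubagentStatus(Enum):
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    TIMED_OUT = "timed_out"


POLL_INTERVAL_SECONDS = 5  # implied by the "(1 + 60) // 5 = 12" comment in the tests


def poll_task(
    task_id: str,
    timeout_seconds: int,
    get_background_task_result: Callable,
    writer: Callable[[dict], None],
) -> str:
    writer({"type": "task_started", "task_id": task_id})
    seen = 0  # number of subagent messages already streamed
    # Safety net: bound the loop even if the subagent never reports a terminal status.
    max_poll_count = (timeout_seconds + 60) // POLL_INTERVAL_SECONDS
    for _ in range(max_poll_count):
        result = get_background_task_result(task_id)
        # Stream newly produced messages before checking for a terminal state; this is
        # why the completed round in the tests still yields one more "task_running"
        # event ahead of "task_completed".
        new_messages = result.ai_messages[seen:]
        if new_messages:
            writer({"type": "task_running", "messages": new_messages})
            seen = len(result.ai_messages)
        if result.status == SubagentStatus.COMPLETED:
            writer({"type": "task_completed", "result": result.result})
            return f"Task Succeeded. Result: {result.result}"
        if result.status == SubagentStatus.FAILED:
            writer({"type": "task_failed", "error": result.error})
            return f"Task failed. Error: {result.error}"
        if result.status == SubagentStatus.TIMED_OUT:
            writer({"type": "task_timed_out", "error": result.error})
            return f"Task timed out. Error: {result.error}"
        time.sleep(POLL_INTERVAL_SECONDS)
    writer({"type": "task_timed_out", "error": "polling limit reached"})
    return f"Task polling timed out after {timeout_seconds // 60} minutes"

Run against the fixtures above (RUNNING with one message, then COMPLETED with two), this loop emits exactly the event sequence the test asserts, and with timeout_seconds=1 it exhausts its 12-poll budget and returns the "after 0 minutes" message.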
backend/tests/test_title_middleware_core_logic.py (new file, +123 lines)
@@ -0,0 +1,123 @@
"""Core behavior tests for TitleMiddleware."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from src.agents.middlewares.title_middleware import TitleMiddleware
|
||||
from src.config.title_config import TitleConfig, get_title_config, set_title_config
|
||||
|
||||
|
||||
def _clone_title_config(config: TitleConfig) -> TitleConfig:
|
||||
# Avoid mutating shared global config objects across tests.
|
||||
return TitleConfig(**config.model_dump())
|
||||
|
||||
|
||||
def _set_test_title_config(**overrides) -> TitleConfig:
|
||||
config = _clone_title_config(get_title_config())
|
||||
for key, value in overrides.items():
|
||||
setattr(config, key, value)
|
||||
set_title_config(config)
|
||||
return config
|
||||
|
||||
|
||||
class TestTitleMiddlewareCoreLogic:
|
||||
def setup_method(self):
|
||||
# Title config is a global singleton; snapshot and restore for test isolation.
|
||||
self._original = _clone_title_config(get_title_config())
|
||||
|
||||
def teardown_method(self):
|
||||
set_title_config(self._original)
|
||||
|
||||
def test_should_generate_title_for_first_complete_exchange(self):
|
||||
_set_test_title_config(enabled=True)
|
||||
middleware = TitleMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="帮我总结这段代码"),
|
||||
AIMessage(content="好的,我先看结构"),
|
||||
]
|
||||
}
|
||||
|
||||
assert middleware._should_generate_title(state) is True
|
||||
|
||||
def test_should_not_generate_title_when_disabled_or_already_set(self):
|
||||
middleware = TitleMiddleware()
|
||||
|
||||
_set_test_title_config(enabled=False)
|
||||
disabled_state = {
|
||||
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
|
||||
"title": None,
|
||||
}
|
||||
assert middleware._should_generate_title(disabled_state) is False
|
||||
|
||||
_set_test_title_config(enabled=True)
|
||||
titled_state = {
|
||||
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
|
||||
"title": "Existing Title",
|
||||
}
|
||||
assert middleware._should_generate_title(titled_state) is False
|
||||
|
||||
def test_should_not_generate_title_after_second_user_turn(self):
|
||||
_set_test_title_config(enabled=True)
|
||||
middleware = TitleMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="第一问"),
|
||||
AIMessage(content="第一答"),
|
||||
HumanMessage(content="第二问"),
|
||||
AIMessage(content="第二答"),
|
||||
]
|
||||
}
|
||||
|
||||
assert middleware._should_generate_title(state) is False
|
||||
|
||||
def test_generate_title_trims_quotes_and_respects_max_chars(self, monkeypatch):
|
||||
_set_test_title_config(max_chars=12)
|
||||
middleware = TitleMiddleware()
|
||||
fake_model = MagicMock()
|
||||
fake_model.invoke.return_value = MagicMock(content='"A very long generated title"')
|
||||
monkeypatch.setattr("src.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="请帮我写一个脚本"),
|
||||
AIMessage(content="好的,先确认需求"),
|
||||
]
|
||||
}
|
||||
title = middleware._generate_title(state)
|
||||
|
||||
assert '"' not in title
|
||||
assert "'" not in title
|
||||
assert len(title) == 12
|
||||
|
||||
def test_generate_title_fallback_when_model_fails(self, monkeypatch):
|
||||
_set_test_title_config(max_chars=20)
|
||||
middleware = TitleMiddleware()
|
||||
fake_model = MagicMock()
|
||||
fake_model.invoke.side_effect = RuntimeError("LLM unavailable")
|
||||
monkeypatch.setattr("src.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="这是一个非常长的问题描述,需要被截断以形成fallback标题"),
|
||||
AIMessage(content="收到"),
|
||||
]
|
||||
}
|
||||
title = middleware._generate_title(state)
|
||||
|
||||
# Assert behavior (truncated fallback + ellipsis) without overfitting exact text.
|
||||
assert title.endswith("...")
|
||||
assert title.startswith("这是一个非常长的问题描述")
|
||||
|
||||
def test_after_agent_returns_title_only_when_needed(self, monkeypatch):
|
||||
middleware = TitleMiddleware()
|
||||
monkeypatch.setattr(middleware, "_should_generate_title", lambda state: True)
|
||||
monkeypatch.setattr(middleware, "_generate_title", lambda state: "核心逻辑回归")
|
||||
|
||||
result = middleware.after_agent({"messages": []}, runtime=MagicMock())
|
||||
|
||||
assert result == {"title": "核心逻辑回归"}
|
||||
|
||||
monkeypatch.setattr(middleware, "_should_generate_title", lambda state: False)
|
||||
assert middleware.after_agent({"messages": []}, runtime=MagicMock()) is None
|
||||
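These tests encode three rules: a title is generated only when the feature is enabled, only when no title exists yet, and only for the first complete user/assistant exchange; generation strips wrapping quotes and enforces max_chars, falling back to a truncated copy of the user's question when the model call fails. A hypothetical sketch of those rules, written as standalone functions rather than the real TitleMiddleware methods (the exact quote-stripping and fallback-truncation details are assumptions consistent with, but not copied from, the source):

# Hypothetical sketch of the rules the tests above pin down; the real methods live
# in src.agents.middlewares.title_middleware and may differ in detail.
from langchain_core.messages import AIMessage, HumanMessage


def should_generate_title(state: dict, enabled: bool) -> bool:
    if not enabled or state.get("title"):
        return False  # feature off, or a title was already generated
    messages = state.get("messages", [])
    human_count = sum(isinstance(m, HumanMessage) for m in messages)
    ai_count = sum(isinstance(m, AIMessage) for m in messages)
    # Only the first complete exchange (exactly one user turn plus a reply) qualifies.
    return human_count == 1 and ai_count >= 1


def generate_title(state: dict, model, max_chars: int) -> str:
    first_user = next(m.content for m in state["messages"] if isinstance(m, HumanMessage))
    try:
        raw = model.invoke(state["messages"]).content
        # Strip wrapping quotes the LLM tends to add, then enforce the length budget.
        return raw.strip().strip('"').strip("'")[:max_chars]
    except Exception:
        # Fallback: truncate the user's question and mark the cut with an ellipsis.
        return first_user[:max_chars] + "..."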