test(backend): add core logic unit tests for task/title/mcp (#936)

* test(backend): add core logic unit tests for task/title/mcp

* test(backend): fix lint issues in client test modules

---------

Co-authored-by: songyaolun <songyaolun@bytedance.com>
This commit is contained in:
YolenSong
2026-03-01 12:36:09 +08:00
committed by GitHub
parent f2123efdb9
commit 3d3ea84a57
5 changed files with 460 additions and 6 deletions

View File

@@ -10,8 +10,8 @@ import pytest
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage # noqa: F401
from src.client import DeerFlowClient
from src.gateway.routers.memory import MemoryConfigResponse, MemoryStatusResponse
from src.gateway.routers.mcp import McpConfigResponse
from src.gateway.routers.memory import MemoryConfigResponse, MemoryStatusResponse
from src.gateway.routers.models import ModelResponse, ModelsListResponse
from src.gateway.routers.skills import SkillInstallResponse, SkillResponse, SkillsListResponse
from src.gateway.routers.uploads import UploadResponse
@@ -1011,8 +1011,6 @@ class TestScenarioAgentRecreation:
def test_different_model_triggers_rebuild(self, client):
"""Switching model_name between calls forces agent rebuild."""
mock_agent_1 = MagicMock(name="agent-v1")
mock_agent_2 = MagicMock(name="agent-v2")
agents_created = []
def fake_create_agent(**kwargs):

View File

@@ -12,6 +12,8 @@ from pathlib import Path
import pytest
from src.client import DeerFlowClient, StreamEvent
# Skip entire module in CI or when no config.yaml exists
_skip_reason = None
if os.environ.get("CI"):
@@ -22,9 +24,6 @@ elif not Path(__file__).resolve().parents[2].joinpath("config.yaml").exists():
if _skip_reason:
pytest.skip(_skip_reason, allow_module_level=True)
from src.client import DeerFlowClient, StreamEvent
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

View File

@@ -0,0 +1,93 @@
"""Core behavior tests for MCP client server config building."""
import pytest
from src.config.extensions_config import ExtensionsConfig, McpServerConfig
from src.mcp.client import build_server_params, build_servers_config
def test_build_server_params_stdio_success():
    """A fully-specified stdio server config maps 1:1 onto transport params."""
    cfg = McpServerConfig(
        type="stdio",
        command="npx",
        args=["-y", "my-mcp-server"],
        env={"API_KEY": "secret"},
    )
    expected = {
        "transport": "stdio",
        "command": "npx",
        "args": ["-y", "my-mcp-server"],
        "env": {"API_KEY": "secret"},
    }
    assert build_server_params("my-server", cfg) == expected
def test_build_server_params_stdio_requires_command():
    """A stdio server without a command must be rejected with a clear error."""
    broken = McpServerConfig(type="stdio", command=None)
    with pytest.raises(ValueError, match="requires 'command' field"):
        build_server_params("broken-stdio", broken)
@pytest.mark.parametrize("transport", ["sse", "http"])
def test_build_server_params_http_like_success(transport: str):
    """sse/http servers produce url + headers transport params."""
    cfg = McpServerConfig(
        type=transport,
        url="https://example.com/mcp",
        headers={"Authorization": "Bearer token"},
    )
    expected = {
        "transport": transport,
        "url": "https://example.com/mcp",
        "headers": {"Authorization": "Bearer token"},
    }
    assert build_server_params("remote-server", cfg) == expected
@pytest.mark.parametrize("transport", ["sse", "http"])
def test_build_server_params_http_like_requires_url(transport: str):
    """sse/http servers without a URL must be rejected with a clear error."""
    broken = McpServerConfig(type=transport, url=None)
    with pytest.raises(ValueError, match="requires 'url' field"):
        build_server_params("broken-remote", broken)
def test_build_server_params_rejects_unsupported_transport():
    """Unknown transport types raise instead of being passed through silently."""
    with pytest.raises(ValueError, match="unsupported transport type"):
        build_server_params("bad-transport", McpServerConfig(type="websocket"))
def test_build_servers_config_returns_empty_when_no_enabled_servers():
    """When every server is disabled, the resulting config is empty."""
    servers = {
        "disabled-a": McpServerConfig(enabled=False, type="stdio", command="echo"),
        "disabled-b": McpServerConfig(enabled=False, type="http", url="https://example.com"),
    }
    extensions = ExtensionsConfig(mcp_servers=servers, skills={})
    assert build_servers_config(extensions) == {}
def test_build_servers_config_skips_invalid_server_and_keeps_valid_ones():
    """One misconfigured server must not break the rest; disabled ones are skipped."""
    extensions = ExtensionsConfig(
        mcp_servers={
            "valid-stdio": McpServerConfig(enabled=True, type="stdio", command="npx", args=["server"]),
            "invalid-stdio": McpServerConfig(enabled=True, type="stdio", command=None),
            "disabled-http": McpServerConfig(enabled=False, type="http", url="https://disabled.example.com"),
        },
        skills={},
    )
    result = build_servers_config(extensions)
    assert "valid-stdio" in result
    assert result["valid-stdio"]["transport"] == "stdio"
    for rejected in ("invalid-stdio", "disabled-http"):
        assert rejected not in result

View File

@@ -0,0 +1,241 @@
"""Core behavior tests for task tool orchestration."""
import importlib
from enum import Enum
from types import SimpleNamespace
from unittest.mock import MagicMock
from src.subagents.config import SubagentConfig
# Use module import so tests can patch the exact symbols referenced inside task_tool().
task_tool_module = importlib.import_module("src.tools.builtins.task_tool")
# Stand-in for the production SubagentStatus enum, built via the Enum
# functional API. Values mirror the real enum exactly so that status
# comparisons inside task_tool() behave identically when this fake is
# patched in.
FakeSubagentStatus = Enum(
    "FakeSubagentStatus",
    {
        "PENDING": "pending",
        "RUNNING": "running",
        "COMPLETED": "completed",
        "FAILED": "failed",
        "TIMED_OUT": "timed_out",
    },
)
def _make_runtime() -> SimpleNamespace:
# Minimal ToolRuntime-like object; task_tool only reads these three attributes.
return SimpleNamespace(
state={
"sandbox": {"sandbox_id": "local"},
"thread_data": {
"workspace_path": "/tmp/workspace",
"uploads_path": "/tmp/uploads",
"outputs_path": "/tmp/outputs",
},
},
context={"thread_id": "thread-1"},
config={"metadata": {"model_name": "ark-model", "trace_id": "trace-1"}},
)
def _make_subagent_config() -> SubagentConfig:
    """Return a small, deterministic SubagentConfig for the task_tool tests."""
    fields = dict(
        name="general-purpose",
        description="General helper",
        system_prompt="Base system prompt",
        max_turns=50,
        timeout_seconds=10,
    )
    return SubagentConfig(**fields)
def _make_result(
    status: FakeSubagentStatus,
    *,
    ai_messages: list[dict] | None = None,
    result: str | None = None,
    error: str | None = None,
) -> SimpleNamespace:
    """Mimic the background-task result object that task_tool() polls for."""
    # Falsy ai_messages (None or empty) is normalized to a fresh empty list,
    # matching the original `ai_messages or []` behavior.
    messages = ai_messages or []
    return SimpleNamespace(status=status, ai_messages=messages, result=result, error=error)
def test_task_tool_returns_error_for_unknown_subagent(monkeypatch):
    """An unregistered subagent_type yields an error string, not an exception."""
    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: None)
    output = task_tool_module.task_tool.func(
        runtime=None,
        description="执行任务",
        prompt="do work",
        subagent_type="general-purpose",
        tool_call_id="tc-1",
    )
    assert output.startswith("Error: Unknown subagent type")
def test_task_tool_emits_running_and_completed_events(monkeypatch):
    """Happy path: task_tool streams started/running/completed events and
    returns the subagent's final result string.

    Also verifies the executor wiring: prompt/task_id pass-through, thread and
    model metadata read from the runtime, the max_turns override, and the
    skills appendix being folded into the subagent system prompt.
    """
    config = _make_subagent_config()
    runtime = _make_runtime()
    events = []       # stream-writer sink: every emitted event dict lands here
    captured = {}     # records what task_tool passes into the executor
    get_available_tools = MagicMock(return_value=["tool-a", "tool-b"])

    class DummyExecutor:
        # Stands in for SubagentExecutor; only records its inputs.
        def __init__(self, **kwargs):
            captured["executor_kwargs"] = kwargs

        def execute_async(self, prompt, task_id=None):
            captured["prompt"] = prompt
            captured["task_id"] = task_id
            return task_id or "generated-task-id"

    # Simulate two polling rounds: first running (with one message), then completed.
    responses = iter(
        [
            _make_result(FakeSubagentStatus.RUNNING, ai_messages=[{"id": "m1", "content": "phase-1"}]),
            _make_result(
                FakeSubagentStatus.COMPLETED,
                ai_messages=[{"id": "m1", "content": "phase-1"}, {"id": "m2", "content": "phase-2"}],
                result="all done",
            ),
        ]
    )
    monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
    monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
    monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "Skills Appendix")
    monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: next(responses))
    monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
    # No real waiting between polling rounds.
    monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None)
    # task_tool lazily imports from src.tools at call time, so patch that module-level function.
    monkeypatch.setattr("src.tools.get_available_tools", get_available_tools)
    output = task_tool_module.task_tool.func(
        runtime=runtime,
        description="运行子任务",
        prompt="collect diagnostics",
        subagent_type="general-purpose",
        tool_call_id="tc-123",
        max_turns=7,  # caller override; must win over config.max_turns (50)
    )
    assert output == "Task Succeeded. Result: all done"
    assert captured["prompt"] == "collect diagnostics"
    # tool_call_id is reused as the background task id.
    assert captured["task_id"] == "tc-123"
    assert captured["executor_kwargs"]["thread_id"] == "thread-1"
    assert captured["executor_kwargs"]["parent_model"] == "ark-model"
    assert captured["executor_kwargs"]["config"].max_turns == 7
    assert "Skills Appendix" in captured["executor_kwargs"]["config"].system_prompt
    get_available_tools.assert_called_once_with(model_name="ark-model", subagent_enabled=False)
    event_types = [e["type"] for e in events]
    # Two task_running events for two polling rounds — presumably one per
    # round/newly observed ai_message (m1, then m2); confirm against task_tool.
    assert event_types == ["task_started", "task_running", "task_running", "task_completed"]
    assert events[-1]["result"] == "all done"
def test_task_tool_returns_failed_message(monkeypatch):
    """A FAILED result surfaces the subagent error in the return string and final event."""
    config = _make_subagent_config()
    events = []

    class DummyExecutor:
        def __init__(self, **kwargs):
            pass

        def execute_async(self, prompt, task_id=None):
            return task_id

    patches = {
        "SubagentStatus": FakeSubagentStatus,
        "SubagentExecutor": DummyExecutor,
        "get_subagent_config": lambda _: config,
        "get_skills_prompt_section": lambda: "",
        "get_background_task_result": lambda _: _make_result(FakeSubagentStatus.FAILED, error="subagent crashed"),
        "get_stream_writer": lambda: events.append,
    }
    for attr, replacement in patches.items():
        monkeypatch.setattr(task_tool_module, attr, replacement)
    monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None)
    monkeypatch.setattr("src.tools.get_available_tools", lambda **kwargs: [])

    output = task_tool_module.task_tool.func(
        runtime=_make_runtime(),
        description="执行任务",
        prompt="do fail",
        subagent_type="general-purpose",
        tool_call_id="tc-fail",
    )

    assert output == "Task failed. Error: subagent crashed"
    last_event = events[-1]
    assert last_event["type"] == "task_failed"
    assert last_event["error"] == "subagent crashed"
def test_task_tool_returns_timed_out_message(monkeypatch):
    """A TIMED_OUT result surfaces the timeout error in the return string and final event."""
    config = _make_subagent_config()
    events = []

    class DummyExecutor:
        def __init__(self, **kwargs):
            pass

        def execute_async(self, prompt, task_id=None):
            return task_id

    patches = {
        "SubagentStatus": FakeSubagentStatus,
        "SubagentExecutor": DummyExecutor,
        "get_subagent_config": lambda _: config,
        "get_skills_prompt_section": lambda: "",
        "get_background_task_result": lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
        "get_stream_writer": lambda: events.append,
    }
    for attr, replacement in patches.items():
        monkeypatch.setattr(task_tool_module, attr, replacement)
    monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None)
    monkeypatch.setattr("src.tools.get_available_tools", lambda **kwargs: [])

    output = task_tool_module.task_tool.func(
        runtime=_make_runtime(),
        description="执行任务",
        prompt="do timeout",
        subagent_type="general-purpose",
        tool_call_id="tc-timeout",
    )

    assert output == "Task timed out. Error: timeout"
    last_event = events[-1]
    assert last_event["type"] == "task_timed_out"
    assert last_event["error"] == "timeout"
def test_task_tool_polling_safety_timeout(monkeypatch):
    """If the subagent never leaves RUNNING, the polling loop bails out on its own."""
    config = _make_subagent_config()
    # Keep max_poll_count small for test speed: (1 + 60) // 5 = 12
    config.timeout_seconds = 1
    events = []

    class DummyExecutor:
        def __init__(self, **kwargs):
            pass

        def execute_async(self, prompt, task_id=None):
            return task_id

    def always_running(_):
        # Every polling round reports RUNNING with no new messages.
        return _make_result(FakeSubagentStatus.RUNNING, ai_messages=[])

    patches = {
        "SubagentStatus": FakeSubagentStatus,
        "SubagentExecutor": DummyExecutor,
        "get_subagent_config": lambda _: config,
        "get_skills_prompt_section": lambda: "",
        "get_background_task_result": always_running,
        "get_stream_writer": lambda: events.append,
    }
    for attr, replacement in patches.items():
        monkeypatch.setattr(task_tool_module, attr, replacement)
    monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None)
    monkeypatch.setattr("src.tools.get_available_tools", lambda **kwargs: [])

    output = task_tool_module.task_tool.func(
        runtime=_make_runtime(),
        description="执行任务",
        prompt="never finish",
        subagent_type="general-purpose",
        tool_call_id="tc-safety-timeout",
    )

    assert output.startswith("Task polling timed out after 0 minutes")
    assert events[0]["type"] == "task_started"
    assert events[-1]["type"] == "task_timed_out"

View File

@@ -0,0 +1,123 @@
"""Core behavior tests for TitleMiddleware."""
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage, HumanMessage
from src.agents.middlewares.title_middleware import TitleMiddleware
from src.config.title_config import TitleConfig, get_title_config, set_title_config
def _clone_title_config(config: TitleConfig) -> TitleConfig:
    """Return an independent copy so tests never mutate the shared global config."""
    snapshot = config.model_dump()
    return TitleConfig(**snapshot)
def _set_test_title_config(**overrides) -> TitleConfig:
    """Install a copy of the current global title config with *overrides* applied.

    Returns the installed config so callers can inspect it.
    """
    config = _clone_title_config(get_title_config())
    for name in overrides:
        setattr(config, name, overrides[name])
    set_title_config(config)
    return config
class TestTitleMiddlewareCoreLogic:
    """Behavioral tests for TitleMiddleware's title-generation decisions.

    Covers when a title should be generated, the model-backed generation path
    (quote trimming + max_chars), the fallback when the model fails, and the
    after_agent state-update contract.
    """

    def setup_method(self):
        # Title config is a global singleton; snapshot and restore for test isolation.
        self._original = _clone_title_config(get_title_config())

    def teardown_method(self):
        # Restore the pre-test global config even when a test body fails.
        set_title_config(self._original)

    def test_should_generate_title_for_first_complete_exchange(self):
        """First human+AI exchange with titling enabled triggers generation."""
        _set_test_title_config(enabled=True)
        middleware = TitleMiddleware()
        state = {
            "messages": [
                HumanMessage(content="帮我总结这段代码"),
                AIMessage(content="好的,我先看结构"),
            ]
        }
        assert middleware._should_generate_title(state) is True

    def test_should_not_generate_title_when_disabled_or_already_set(self):
        """No generation when the feature is disabled or a title already exists."""
        middleware = TitleMiddleware()
        _set_test_title_config(enabled=False)
        disabled_state = {
            "messages": [HumanMessage(content="Q"), AIMessage(content="A")],
            "title": None,
        }
        assert middleware._should_generate_title(disabled_state) is False
        # Re-enable, but with a title already present in state.
        _set_test_title_config(enabled=True)
        titled_state = {
            "messages": [HumanMessage(content="Q"), AIMessage(content="A")],
            "title": "Existing Title",
        }
        assert middleware._should_generate_title(titled_state) is False

    def test_should_not_generate_title_after_second_user_turn(self):
        """Only the first exchange gets a title; later turns do not re-trigger it."""
        _set_test_title_config(enabled=True)
        middleware = TitleMiddleware()
        state = {
            "messages": [
                HumanMessage(content="第一问"),
                AIMessage(content="第一答"),
                HumanMessage(content="第二问"),
                AIMessage(content="第二答"),
            ]
        }
        assert middleware._should_generate_title(state) is False

    def test_generate_title_trims_quotes_and_respects_max_chars(self, monkeypatch):
        """Model output is stripped of quotes and truncated to max_chars."""
        _set_test_title_config(max_chars=12)
        middleware = TitleMiddleware()
        fake_model = MagicMock()
        # Model answer is wrapped in double quotes and longer than max_chars.
        fake_model.invoke.return_value = MagicMock(content='"A very long generated title"')
        monkeypatch.setattr("src.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model)
        state = {
            "messages": [
                HumanMessage(content="请帮我写一个脚本"),
                AIMessage(content="好的,先确认需求"),
            ]
        }
        title = middleware._generate_title(state)
        assert '"' not in title
        assert "'" not in title
        assert len(title) == 12

    def test_generate_title_fallback_when_model_fails(self, monkeypatch):
        """When the model raises, the title falls back to a truncated user message."""
        _set_test_title_config(max_chars=20)
        middleware = TitleMiddleware()
        fake_model = MagicMock()
        fake_model.invoke.side_effect = RuntimeError("LLM unavailable")
        monkeypatch.setattr("src.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model)
        state = {
            "messages": [
                HumanMessage(content="这是一个非常长的问题描述需要被截断以形成fallback标题"),
                AIMessage(content="收到"),
            ]
        }
        title = middleware._generate_title(state)
        # Assert behavior (truncated fallback + ellipsis) without overfitting exact text.
        assert title.endswith("...")
        assert title.startswith("这是一个非常长的问题描述")

    def test_after_agent_returns_title_only_when_needed(self, monkeypatch):
        """after_agent returns a state update only when a title should be generated."""
        middleware = TitleMiddleware()
        monkeypatch.setattr(middleware, "_should_generate_title", lambda state: True)
        monkeypatch.setattr(middleware, "_generate_title", lambda state: "核心逻辑回归")
        result = middleware.after_agent({"messages": []}, runtime=MagicMock())
        assert result == {"title": "核心逻辑回归"}
        # When generation is not needed, after_agent returns None (no state update).
        monkeypatch.setattr(middleware, "_should_generate_title", lambda state: False)
        assert middleware.after_agent({"messages": []}, runtime=MagicMock()) is None