mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-29 00:34:47 +08:00
test(backend): add core logic unit tests for task/title/mcp (#936)
* test(backend): add core logic unit tests for task/title/mcp * test(backend): fix lint issues in client test modules --------- Co-authored-by: songyaolun <songyaolun@bytedance.com>
This commit is contained in:
@@ -10,8 +10,8 @@ import pytest
|
|||||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage # noqa: F401
|
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage # noqa: F401
|
||||||
|
|
||||||
from src.client import DeerFlowClient
|
from src.client import DeerFlowClient
|
||||||
from src.gateway.routers.memory import MemoryConfigResponse, MemoryStatusResponse
|
|
||||||
from src.gateway.routers.mcp import McpConfigResponse
|
from src.gateway.routers.mcp import McpConfigResponse
|
||||||
|
from src.gateway.routers.memory import MemoryConfigResponse, MemoryStatusResponse
|
||||||
from src.gateway.routers.models import ModelResponse, ModelsListResponse
|
from src.gateway.routers.models import ModelResponse, ModelsListResponse
|
||||||
from src.gateway.routers.skills import SkillInstallResponse, SkillResponse, SkillsListResponse
|
from src.gateway.routers.skills import SkillInstallResponse, SkillResponse, SkillsListResponse
|
||||||
from src.gateway.routers.uploads import UploadResponse
|
from src.gateway.routers.uploads import UploadResponse
|
||||||
@@ -1011,8 +1011,6 @@ class TestScenarioAgentRecreation:
|
|||||||
|
|
||||||
def test_different_model_triggers_rebuild(self, client):
|
def test_different_model_triggers_rebuild(self, client):
|
||||||
"""Switching model_name between calls forces agent rebuild."""
|
"""Switching model_name between calls forces agent rebuild."""
|
||||||
mock_agent_1 = MagicMock(name="agent-v1")
|
|
||||||
mock_agent_2 = MagicMock(name="agent-v2")
|
|
||||||
agents_created = []
|
agents_created = []
|
||||||
|
|
||||||
def fake_create_agent(**kwargs):
|
def fake_create_agent(**kwargs):
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ from pathlib import Path
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from src.client import DeerFlowClient, StreamEvent
|
||||||
|
|
||||||
# Skip entire module in CI or when no config.yaml exists
|
# Skip entire module in CI or when no config.yaml exists
|
||||||
_skip_reason = None
|
_skip_reason = None
|
||||||
if os.environ.get("CI"):
|
if os.environ.get("CI"):
|
||||||
@@ -22,9 +24,6 @@ elif not Path(__file__).resolve().parents[2].joinpath("config.yaml").exists():
|
|||||||
if _skip_reason:
|
if _skip_reason:
|
||||||
pytest.skip(_skip_reason, allow_module_level=True)
|
pytest.skip(_skip_reason, allow_module_level=True)
|
||||||
|
|
||||||
from src.client import DeerFlowClient, StreamEvent
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Fixtures
|
# Fixtures
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
93
backend/tests/test_mcp_client_config.py
Normal file
93
backend/tests/test_mcp_client_config.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
"""Core behavior tests for MCP client server config building."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.config.extensions_config import ExtensionsConfig, McpServerConfig
|
||||||
|
from src.mcp.client import build_server_params, build_servers_config
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_server_params_stdio_success():
|
||||||
|
config = McpServerConfig(
|
||||||
|
type="stdio",
|
||||||
|
command="npx",
|
||||||
|
args=["-y", "my-mcp-server"],
|
||||||
|
env={"API_KEY": "secret"},
|
||||||
|
)
|
||||||
|
|
||||||
|
params = build_server_params("my-server", config)
|
||||||
|
|
||||||
|
assert params == {
|
||||||
|
"transport": "stdio",
|
||||||
|
"command": "npx",
|
||||||
|
"args": ["-y", "my-mcp-server"],
|
||||||
|
"env": {"API_KEY": "secret"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_server_params_stdio_requires_command():
|
||||||
|
config = McpServerConfig(type="stdio", command=None)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="requires 'command' field"):
|
||||||
|
build_server_params("broken-stdio", config)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("transport", ["sse", "http"])
|
||||||
|
def test_build_server_params_http_like_success(transport: str):
|
||||||
|
config = McpServerConfig(
|
||||||
|
type=transport,
|
||||||
|
url="https://example.com/mcp",
|
||||||
|
headers={"Authorization": "Bearer token"},
|
||||||
|
)
|
||||||
|
|
||||||
|
params = build_server_params("remote-server", config)
|
||||||
|
|
||||||
|
assert params == {
|
||||||
|
"transport": transport,
|
||||||
|
"url": "https://example.com/mcp",
|
||||||
|
"headers": {"Authorization": "Bearer token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("transport", ["sse", "http"])
|
||||||
|
def test_build_server_params_http_like_requires_url(transport: str):
|
||||||
|
config = McpServerConfig(type=transport, url=None)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="requires 'url' field"):
|
||||||
|
build_server_params("broken-remote", config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_server_params_rejects_unsupported_transport():
|
||||||
|
config = McpServerConfig(type="websocket")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="unsupported transport type"):
|
||||||
|
build_server_params("bad-transport", config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_servers_config_returns_empty_when_no_enabled_servers():
|
||||||
|
extensions = ExtensionsConfig(
|
||||||
|
mcp_servers={
|
||||||
|
"disabled-a": McpServerConfig(enabled=False, type="stdio", command="echo"),
|
||||||
|
"disabled-b": McpServerConfig(enabled=False, type="http", url="https://example.com"),
|
||||||
|
},
|
||||||
|
skills={},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert build_servers_config(extensions) == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_servers_config_skips_invalid_server_and_keeps_valid_ones():
|
||||||
|
extensions = ExtensionsConfig(
|
||||||
|
mcp_servers={
|
||||||
|
"valid-stdio": McpServerConfig(enabled=True, type="stdio", command="npx", args=["server"]),
|
||||||
|
"invalid-stdio": McpServerConfig(enabled=True, type="stdio", command=None),
|
||||||
|
"disabled-http": McpServerConfig(enabled=False, type="http", url="https://disabled.example.com"),
|
||||||
|
},
|
||||||
|
skills={},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = build_servers_config(extensions)
|
||||||
|
|
||||||
|
assert "valid-stdio" in result
|
||||||
|
assert result["valid-stdio"]["transport"] == "stdio"
|
||||||
|
assert "invalid-stdio" not in result
|
||||||
|
assert "disabled-http" not in result
|
||||||
241
backend/tests/test_task_tool_core_logic.py
Normal file
241
backend/tests/test_task_tool_core_logic.py
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
"""Core behavior tests for task tool orchestration."""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
from enum import Enum
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from src.subagents.config import SubagentConfig
|
||||||
|
|
||||||
|
# Use module import so tests can patch the exact symbols referenced inside task_tool().
|
||||||
|
task_tool_module = importlib.import_module("src.tools.builtins.task_tool")
|
||||||
|
|
||||||
|
|
||||||
|
class FakeSubagentStatus(Enum):
|
||||||
|
# Match production enum values so branch comparisons behave identically.
|
||||||
|
PENDING = "pending"
|
||||||
|
RUNNING = "running"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
TIMED_OUT = "timed_out"
|
||||||
|
|
||||||
|
|
||||||
|
def _make_runtime() -> SimpleNamespace:
|
||||||
|
# Minimal ToolRuntime-like object; task_tool only reads these three attributes.
|
||||||
|
return SimpleNamespace(
|
||||||
|
state={
|
||||||
|
"sandbox": {"sandbox_id": "local"},
|
||||||
|
"thread_data": {
|
||||||
|
"workspace_path": "/tmp/workspace",
|
||||||
|
"uploads_path": "/tmp/uploads",
|
||||||
|
"outputs_path": "/tmp/outputs",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
context={"thread_id": "thread-1"},
|
||||||
|
config={"metadata": {"model_name": "ark-model", "trace_id": "trace-1"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_subagent_config() -> SubagentConfig:
|
||||||
|
return SubagentConfig(
|
||||||
|
name="general-purpose",
|
||||||
|
description="General helper",
|
||||||
|
system_prompt="Base system prompt",
|
||||||
|
max_turns=50,
|
||||||
|
timeout_seconds=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_result(
|
||||||
|
status: FakeSubagentStatus,
|
||||||
|
*,
|
||||||
|
ai_messages: list[dict] | None = None,
|
||||||
|
result: str | None = None,
|
||||||
|
error: str | None = None,
|
||||||
|
) -> SimpleNamespace:
|
||||||
|
return SimpleNamespace(
|
||||||
|
status=status,
|
||||||
|
ai_messages=ai_messages or [],
|
||||||
|
result=result,
|
||||||
|
error=error,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_tool_returns_error_for_unknown_subagent(monkeypatch):
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: None)
|
||||||
|
|
||||||
|
result = task_tool_module.task_tool.func(
|
||||||
|
runtime=None,
|
||||||
|
description="执行任务",
|
||||||
|
prompt="do work",
|
||||||
|
subagent_type="general-purpose",
|
||||||
|
tool_call_id="tc-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.startswith("Error: Unknown subagent type")
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_tool_emits_running_and_completed_events(monkeypatch):
|
||||||
|
config = _make_subagent_config()
|
||||||
|
runtime = _make_runtime()
|
||||||
|
events = []
|
||||||
|
captured = {}
|
||||||
|
get_available_tools = MagicMock(return_value=["tool-a", "tool-b"])
|
||||||
|
|
||||||
|
class DummyExecutor:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
captured["executor_kwargs"] = kwargs
|
||||||
|
|
||||||
|
def execute_async(self, prompt, task_id=None):
|
||||||
|
captured["prompt"] = prompt
|
||||||
|
captured["task_id"] = task_id
|
||||||
|
return task_id or "generated-task-id"
|
||||||
|
|
||||||
|
# Simulate two polling rounds: first running (with one message), then completed.
|
||||||
|
responses = iter(
|
||||||
|
[
|
||||||
|
_make_result(FakeSubagentStatus.RUNNING, ai_messages=[{"id": "m1", "content": "phase-1"}]),
|
||||||
|
_make_result(
|
||||||
|
FakeSubagentStatus.COMPLETED,
|
||||||
|
ai_messages=[{"id": "m1", "content": "phase-1"}, {"id": "m2", "content": "phase-2"}],
|
||||||
|
result="all done",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||||
|
monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "Skills Appendix")
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: next(responses))
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||||
|
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None)
|
||||||
|
# task_tool lazily imports from src.tools at call time, so patch that module-level function.
|
||||||
|
monkeypatch.setattr("src.tools.get_available_tools", get_available_tools)
|
||||||
|
|
||||||
|
output = task_tool_module.task_tool.func(
|
||||||
|
runtime=runtime,
|
||||||
|
description="运行子任务",
|
||||||
|
prompt="collect diagnostics",
|
||||||
|
subagent_type="general-purpose",
|
||||||
|
tool_call_id="tc-123",
|
||||||
|
max_turns=7,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert output == "Task Succeeded. Result: all done"
|
||||||
|
assert captured["prompt"] == "collect diagnostics"
|
||||||
|
assert captured["task_id"] == "tc-123"
|
||||||
|
assert captured["executor_kwargs"]["thread_id"] == "thread-1"
|
||||||
|
assert captured["executor_kwargs"]["parent_model"] == "ark-model"
|
||||||
|
assert captured["executor_kwargs"]["config"].max_turns == 7
|
||||||
|
assert "Skills Appendix" in captured["executor_kwargs"]["config"].system_prompt
|
||||||
|
|
||||||
|
get_available_tools.assert_called_once_with(model_name="ark-model", subagent_enabled=False)
|
||||||
|
|
||||||
|
event_types = [e["type"] for e in events]
|
||||||
|
assert event_types == ["task_started", "task_running", "task_running", "task_completed"]
|
||||||
|
assert events[-1]["result"] == "all done"
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_tool_returns_failed_message(monkeypatch):
|
||||||
|
config = _make_subagent_config()
|
||||||
|
events = []
|
||||||
|
|
||||||
|
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
task_tool_module,
|
||||||
|
"SubagentExecutor",
|
||||||
|
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
|
||||||
|
monkeypatch.setattr(
|
||||||
|
task_tool_module,
|
||||||
|
"get_background_task_result",
|
||||||
|
lambda _: _make_result(FakeSubagentStatus.FAILED, error="subagent crashed"),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||||
|
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None)
|
||||||
|
monkeypatch.setattr("src.tools.get_available_tools", lambda **kwargs: [])
|
||||||
|
|
||||||
|
output = task_tool_module.task_tool.func(
|
||||||
|
runtime=_make_runtime(),
|
||||||
|
description="执行任务",
|
||||||
|
prompt="do fail",
|
||||||
|
subagent_type="general-purpose",
|
||||||
|
tool_call_id="tc-fail",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert output == "Task failed. Error: subagent crashed"
|
||||||
|
assert events[-1]["type"] == "task_failed"
|
||||||
|
assert events[-1]["error"] == "subagent crashed"
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_tool_returns_timed_out_message(monkeypatch):
|
||||||
|
config = _make_subagent_config()
|
||||||
|
events = []
|
||||||
|
|
||||||
|
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
task_tool_module,
|
||||||
|
"SubagentExecutor",
|
||||||
|
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
|
||||||
|
monkeypatch.setattr(
|
||||||
|
task_tool_module,
|
||||||
|
"get_background_task_result",
|
||||||
|
lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||||
|
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None)
|
||||||
|
monkeypatch.setattr("src.tools.get_available_tools", lambda **kwargs: [])
|
||||||
|
|
||||||
|
output = task_tool_module.task_tool.func(
|
||||||
|
runtime=_make_runtime(),
|
||||||
|
description="执行任务",
|
||||||
|
prompt="do timeout",
|
||||||
|
subagent_type="general-purpose",
|
||||||
|
tool_call_id="tc-timeout",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert output == "Task timed out. Error: timeout"
|
||||||
|
assert events[-1]["type"] == "task_timed_out"
|
||||||
|
assert events[-1]["error"] == "timeout"
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_tool_polling_safety_timeout(monkeypatch):
|
||||||
|
config = _make_subagent_config()
|
||||||
|
# Keep max_poll_count small for test speed: (1 + 60) // 5 = 12
|
||||||
|
config.timeout_seconds = 1
|
||||||
|
events = []
|
||||||
|
|
||||||
|
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
task_tool_module,
|
||||||
|
"SubagentExecutor",
|
||||||
|
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
|
||||||
|
monkeypatch.setattr(
|
||||||
|
task_tool_module,
|
||||||
|
"get_background_task_result",
|
||||||
|
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||||
|
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None)
|
||||||
|
monkeypatch.setattr("src.tools.get_available_tools", lambda **kwargs: [])
|
||||||
|
|
||||||
|
output = task_tool_module.task_tool.func(
|
||||||
|
runtime=_make_runtime(),
|
||||||
|
description="执行任务",
|
||||||
|
prompt="never finish",
|
||||||
|
subagent_type="general-purpose",
|
||||||
|
tool_call_id="tc-safety-timeout",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert output.startswith("Task polling timed out after 0 minutes")
|
||||||
|
assert events[0]["type"] == "task_started"
|
||||||
|
assert events[-1]["type"] == "task_timed_out"
|
||||||
123
backend/tests/test_title_middleware_core_logic.py
Normal file
123
backend/tests/test_title_middleware_core_logic.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
"""Core behavior tests for TitleMiddleware."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
|
from src.agents.middlewares.title_middleware import TitleMiddleware
|
||||||
|
from src.config.title_config import TitleConfig, get_title_config, set_title_config
|
||||||
|
|
||||||
|
|
||||||
|
def _clone_title_config(config: TitleConfig) -> TitleConfig:
|
||||||
|
# Avoid mutating shared global config objects across tests.
|
||||||
|
return TitleConfig(**config.model_dump())
|
||||||
|
|
||||||
|
|
||||||
|
def _set_test_title_config(**overrides) -> TitleConfig:
|
||||||
|
config = _clone_title_config(get_title_config())
|
||||||
|
for key, value in overrides.items():
|
||||||
|
setattr(config, key, value)
|
||||||
|
set_title_config(config)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
class TestTitleMiddlewareCoreLogic:
|
||||||
|
def setup_method(self):
|
||||||
|
# Title config is a global singleton; snapshot and restore for test isolation.
|
||||||
|
self._original = _clone_title_config(get_title_config())
|
||||||
|
|
||||||
|
def teardown_method(self):
|
||||||
|
set_title_config(self._original)
|
||||||
|
|
||||||
|
def test_should_generate_title_for_first_complete_exchange(self):
|
||||||
|
_set_test_title_config(enabled=True)
|
||||||
|
middleware = TitleMiddleware()
|
||||||
|
state = {
|
||||||
|
"messages": [
|
||||||
|
HumanMessage(content="帮我总结这段代码"),
|
||||||
|
AIMessage(content="好的,我先看结构"),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
assert middleware._should_generate_title(state) is True
|
||||||
|
|
||||||
|
def test_should_not_generate_title_when_disabled_or_already_set(self):
|
||||||
|
middleware = TitleMiddleware()
|
||||||
|
|
||||||
|
_set_test_title_config(enabled=False)
|
||||||
|
disabled_state = {
|
||||||
|
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
|
||||||
|
"title": None,
|
||||||
|
}
|
||||||
|
assert middleware._should_generate_title(disabled_state) is False
|
||||||
|
|
||||||
|
_set_test_title_config(enabled=True)
|
||||||
|
titled_state = {
|
||||||
|
"messages": [HumanMessage(content="Q"), AIMessage(content="A")],
|
||||||
|
"title": "Existing Title",
|
||||||
|
}
|
||||||
|
assert middleware._should_generate_title(titled_state) is False
|
||||||
|
|
||||||
|
def test_should_not_generate_title_after_second_user_turn(self):
|
||||||
|
_set_test_title_config(enabled=True)
|
||||||
|
middleware = TitleMiddleware()
|
||||||
|
state = {
|
||||||
|
"messages": [
|
||||||
|
HumanMessage(content="第一问"),
|
||||||
|
AIMessage(content="第一答"),
|
||||||
|
HumanMessage(content="第二问"),
|
||||||
|
AIMessage(content="第二答"),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
assert middleware._should_generate_title(state) is False
|
||||||
|
|
||||||
|
def test_generate_title_trims_quotes_and_respects_max_chars(self, monkeypatch):
|
||||||
|
_set_test_title_config(max_chars=12)
|
||||||
|
middleware = TitleMiddleware()
|
||||||
|
fake_model = MagicMock()
|
||||||
|
fake_model.invoke.return_value = MagicMock(content='"A very long generated title"')
|
||||||
|
monkeypatch.setattr("src.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model)
|
||||||
|
|
||||||
|
state = {
|
||||||
|
"messages": [
|
||||||
|
HumanMessage(content="请帮我写一个脚本"),
|
||||||
|
AIMessage(content="好的,先确认需求"),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
title = middleware._generate_title(state)
|
||||||
|
|
||||||
|
assert '"' not in title
|
||||||
|
assert "'" not in title
|
||||||
|
assert len(title) == 12
|
||||||
|
|
||||||
|
def test_generate_title_fallback_when_model_fails(self, monkeypatch):
|
||||||
|
_set_test_title_config(max_chars=20)
|
||||||
|
middleware = TitleMiddleware()
|
||||||
|
fake_model = MagicMock()
|
||||||
|
fake_model.invoke.side_effect = RuntimeError("LLM unavailable")
|
||||||
|
monkeypatch.setattr("src.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model)
|
||||||
|
|
||||||
|
state = {
|
||||||
|
"messages": [
|
||||||
|
HumanMessage(content="这是一个非常长的问题描述,需要被截断以形成fallback标题"),
|
||||||
|
AIMessage(content="收到"),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
title = middleware._generate_title(state)
|
||||||
|
|
||||||
|
# Assert behavior (truncated fallback + ellipsis) without overfitting exact text.
|
||||||
|
assert title.endswith("...")
|
||||||
|
assert title.startswith("这是一个非常长的问题描述")
|
||||||
|
|
||||||
|
def test_after_agent_returns_title_only_when_needed(self, monkeypatch):
|
||||||
|
middleware = TitleMiddleware()
|
||||||
|
monkeypatch.setattr(middleware, "_should_generate_title", lambda state: True)
|
||||||
|
monkeypatch.setattr(middleware, "_generate_title", lambda state: "核心逻辑回归")
|
||||||
|
|
||||||
|
result = middleware.after_agent({"messages": []}, runtime=MagicMock())
|
||||||
|
|
||||||
|
assert result == {"title": "核心逻辑回归"}
|
||||||
|
|
||||||
|
monkeypatch.setattr(middleware, "_should_generate_title", lambda state: False)
|
||||||
|
assert middleware.after_agent({"messages": []}, runtime=MagicMock()) is None
|
||||||
Reference in New Issue
Block a user