diff --git a/backend/packages/harness/deerflow/agents/middlewares/title_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/title_middleware.py index 8e33743..20cb02f 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/title_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/title_middleware.py @@ -1,5 +1,6 @@ """Middleware for automatic thread title generation.""" +import logging from typing import NotRequired, override from langchain.agents import AgentState @@ -9,6 +10,8 @@ from langgraph.runtime import Runtime from deerflow.config.title_config import get_title_config from deerflow.models import create_chat_model +logger = logging.getLogger(__name__) + class TitleMiddlewareState(AgentState): """Compatible with the `ThreadState` schema.""" @@ -62,49 +65,85 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]): # Generate title after first complete exchange return len(user_messages) == 1 and len(assistant_messages) >= 1 - async def _generate_title(self, state: TitleMiddlewareState) -> str: - """Generate a concise title based on the conversation.""" + def _build_title_prompt(self, state: TitleMiddlewareState) -> tuple[str, str]: + """Extract user/assistant messages and build the title prompt. + + Returns (prompt_string, user_msg) so callers can use user_msg as fallback. + """ config = get_title_config() messages = state.get("messages", []) - # Get first user message and first assistant response user_msg_content = next((m.content for m in messages if m.type == "human"), "") assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "") user_msg = self._normalize_content(user_msg_content) assistant_msg = self._normalize_content(assistant_msg_content) - # Use a lightweight model to generate title - model = create_chat_model(thinking_enabled=False) - prompt = config.prompt_template.format( max_words=config.max_words, user_msg=user_msg[:500], assistant_msg=assistant_msg[:500], ) + return prompt, user_msg + + def _parse_title(self, content: object) -> str: + """Normalize model output into a clean title string.""" + config = get_title_config() + title_content = self._normalize_content(content) + title = title_content.strip().strip('"').strip("'") + return title[: config.max_chars] if len(title) > config.max_chars else title + + def _fallback_title(self, user_msg: str) -> str: + config = get_title_config() + fallback_chars = min(config.max_chars, 50) + if len(user_msg) > fallback_chars: + return user_msg[:fallback_chars].rstrip() + "..." + return user_msg if user_msg else "New Conversation" + + def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None: + """Synchronously generate a title. Returns state update or None.""" + if not self._should_generate_title(state): + return None + + prompt, user_msg = self._build_title_prompt(state) + config = get_title_config() + model = create_chat_model(name=config.model_name, thinking_enabled=False) + + try: + response = model.invoke(prompt) + title = self._parse_title(response.content) + if not title: + title = self._fallback_title(user_msg) + except Exception: + logger.exception("Failed to generate title (sync)") + title = self._fallback_title(user_msg) + + return {"title": title} + + async def _agenerate_title_result(self, state: TitleMiddlewareState) -> dict | None: + """Asynchronously generate a title. Returns state update or None.""" + if not self._should_generate_title(state): + return None + + prompt, user_msg = self._build_title_prompt(state) + config = get_title_config() + model = create_chat_model(name=config.model_name, thinking_enabled=False) try: response = await model.ainvoke(prompt) - title_content = self._normalize_content(response.content) - title = title_content.strip().strip('"').strip("'") - # Limit to max characters - return title[: config.max_chars] if len(title) > config.max_chars else title - except Exception as e: - print(f"Failed to generate title: {e}") - # Fallback: use first part of user message (by character count) - fallback_chars = min(config.max_chars, 50) # Use max_chars or 50, whichever is smaller - if len(user_msg) > fallback_chars: - return user_msg[:fallback_chars].rstrip() + "..." - return user_msg if user_msg else "New Conversation" + title = self._parse_title(response.content) + if not title: + title = self._fallback_title(user_msg) + except Exception: + logger.exception("Failed to generate title (async)") + title = self._fallback_title(user_msg) + + return {"title": title} + + @override + def after_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None: + return self._generate_title_result(state) @override async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None: - """Generate and set thread title after the first agent response.""" - if self._should_generate_title(state): - title = await self._generate_title(state) - print(f"Generated thread title: {title}") - - # Store title in state (will be persisted by checkpointer if configured) - return {"title": title} - - return None + return await self._agenerate_title_result(state) diff --git a/backend/tests/test_title_middleware_core_logic.py b/backend/tests/test_title_middleware_core_logic.py index 906e68c..f2552e3 100644 --- a/backend/tests/test_title_middleware_core_logic.py +++ b/backend/tests/test_title_middleware_core_logic.py @@ -86,7 +86,8 @@ class TestTitleMiddlewareCoreLogic: AIMessage(content="好的,先确认需求"), ] } - title = asyncio.run(middleware._generate_title(state)) + result = asyncio.run(middleware._agenerate_title_result(state)) + title = result["title"] assert '"' not in title assert "'" not in title @@ -111,7 +112,8 @@ class TestTitleMiddlewareCoreLogic: ] } - title = asyncio.run(middleware._generate_title(state)) + result = asyncio.run(middleware._agenerate_title_result(state)) + title = result["title"] prompt = fake_model.ainvoke.await_args.args[0] assert "请帮我总结这段代码" in prompt @@ -135,20 +137,64 @@ class TestTitleMiddlewareCoreLogic: AIMessage(content="收到"), ] } - title = asyncio.run(middleware._generate_title(state)) + result = asyncio.run(middleware._agenerate_title_result(state)) + title = result["title"] # Assert behavior (truncated fallback + ellipsis) without overfitting exact text. assert title.endswith("...") assert title.startswith("这是一个非常长的问题描述") - def test_after_agent_returns_title_only_when_needed(self, monkeypatch): + def test_aafter_model_delegates_to_async_helper(self, monkeypatch): middleware = TitleMiddleware() - monkeypatch.setattr(middleware, "_should_generate_title", lambda state: True) - monkeypatch.setattr(middleware, "_generate_title", AsyncMock(return_value="核心逻辑回归")) + monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value={"title": "异步标题"})) result = asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock())) + assert result == {"title": "异步标题"} - assert result == {"title": "核心逻辑回归"} - - monkeypatch.setattr(middleware, "_should_generate_title", lambda state: False) + monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value=None)) assert asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock())) is None + + def test_after_model_sync_delegates_to_sync_helper(self, monkeypatch): + middleware = TitleMiddleware() + + monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value={"title": "同步标题"})) + result = middleware.after_model({"messages": []}, runtime=MagicMock()) + assert result == {"title": "同步标题"} + + monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value=None)) + assert middleware.after_model({"messages": []}, runtime=MagicMock()) is None + + def test_sync_generate_title_with_model(self, monkeypatch): + """Sync path calls model.invoke and produces a title.""" + _set_test_title_config(max_chars=20) + middleware = TitleMiddleware() + fake_model = MagicMock() + fake_model.invoke = MagicMock(return_value=MagicMock(content='"同步生成的标题"')) + monkeypatch.setattr("deerflow.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model) + + state = { + "messages": [ + HumanMessage(content="请帮我写测试"), + AIMessage(content="好的"), + ] + } + result = middleware._generate_title_result(state) + assert result == {"title": "同步生成的标题"} + fake_model.invoke.assert_called_once() + + def test_empty_title_falls_back(self, monkeypatch): + """Empty model response triggers fallback title.""" + _set_test_title_config(max_chars=50) + middleware = TitleMiddleware() + fake_model = MagicMock() + fake_model.invoke = MagicMock(return_value=MagicMock(content=" ")) + monkeypatch.setattr("deerflow.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model) + + state = { + "messages": [ + HumanMessage(content="空标题测试"), + AIMessage(content="回复"), + ] + } + result = middleware._generate_title_result(state) + assert result["title"] == "空标题测试"