fix: add sync after_model to TitleMiddleware (#1190)

This commit is contained in:
greatmengqi
2026-03-19 15:46:31 +08:00
committed by GitHub
parent f67c3d2c9e
commit accf5b5f8e
2 changed files with 120 additions and 35 deletions

View File

@@ -1,5 +1,6 @@
"""Middleware for automatic thread title generation."""
import logging
from typing import NotRequired, override
from langchain.agents import AgentState
@@ -9,6 +10,8 @@ from langgraph.runtime import Runtime
from deerflow.config.title_config import get_title_config
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
class TitleMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
@@ -62,49 +65,85 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
# Generate title after first complete exchange
return len(user_messages) == 1 and len(assistant_messages) >= 1
async def _generate_title(self, state: TitleMiddlewareState) -> str:
"""Generate a concise title based on the conversation."""
def _build_title_prompt(self, state: TitleMiddlewareState) -> tuple[str, str]:
"""Extract user/assistant messages and build the title prompt.
Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
"""
config = get_title_config()
messages = state.get("messages", [])
# Get first user message and first assistant response
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "")
user_msg = self._normalize_content(user_msg_content)
assistant_msg = self._normalize_content(assistant_msg_content)
# Use a lightweight model to generate title
model = create_chat_model(thinking_enabled=False)
prompt = config.prompt_template.format(
max_words=config.max_words,
user_msg=user_msg[:500],
assistant_msg=assistant_msg[:500],
)
return prompt, user_msg
def _parse_title(self, content: object) -> str:
"""Normalize model output into a clean title string."""
config = get_title_config()
title_content = self._normalize_content(content)
title = title_content.strip().strip('"').strip("'")
return title[: config.max_chars] if len(title) > config.max_chars else title
def _fallback_title(self, user_msg: str) -> str:
config = get_title_config()
fallback_chars = min(config.max_chars, 50)
if len(user_msg) > fallback_chars:
return user_msg[:fallback_chars].rstrip() + "..."
return user_msg if user_msg else "New Conversation"
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
"""Synchronously generate a title. Returns state update or None."""
if not self._should_generate_title(state):
return None
prompt, user_msg = self._build_title_prompt(state)
config = get_title_config()
model = create_chat_model(name=config.model_name, thinking_enabled=False)
try:
response = model.invoke(prompt)
title = self._parse_title(response.content)
if not title:
title = self._fallback_title(user_msg)
except Exception:
logger.exception("Failed to generate title (sync)")
title = self._fallback_title(user_msg)
return {"title": title}
async def _agenerate_title_result(self, state: TitleMiddlewareState) -> dict | None:
"""Asynchronously generate a title. Returns state update or None."""
if not self._should_generate_title(state):
return None
prompt, user_msg = self._build_title_prompt(state)
config = get_title_config()
model = create_chat_model(name=config.model_name, thinking_enabled=False)
try:
response = await model.ainvoke(prompt)
title_content = self._normalize_content(response.content)
title = title_content.strip().strip('"').strip("'")
# Limit to max characters
return title[: config.max_chars] if len(title) > config.max_chars else title
except Exception as e:
print(f"Failed to generate title: {e}")
# Fallback: use first part of user message (by character count)
fallback_chars = min(config.max_chars, 50) # Use max_chars or 50, whichever is smaller
if len(user_msg) > fallback_chars:
return user_msg[:fallback_chars].rstrip() + "..."
return user_msg if user_msg else "New Conversation"
title = self._parse_title(response.content)
if not title:
title = self._fallback_title(user_msg)
except Exception:
logger.exception("Failed to generate title (async)")
title = self._fallback_title(user_msg)
return {"title": title}
@override
def after_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
return self._generate_title_result(state)
@override
async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
"""Generate and set thread title after the first agent response."""
if self._should_generate_title(state):
title = await self._generate_title(state)
print(f"Generated thread title: {title}")
# Store title in state (will be persisted by checkpointer if configured)
return {"title": title}
return None
return await self._agenerate_title_result(state)

View File

@@ -86,7 +86,8 @@ class TestTitleMiddlewareCoreLogic:
AIMessage(content="好的,先确认需求"),
]
}
title = asyncio.run(middleware._generate_title(state))
result = asyncio.run(middleware._agenerate_title_result(state))
title = result["title"]
assert '"' not in title
assert "'" not in title
@@ -111,7 +112,8 @@ class TestTitleMiddlewareCoreLogic:
]
}
title = asyncio.run(middleware._generate_title(state))
result = asyncio.run(middleware._agenerate_title_result(state))
title = result["title"]
prompt = fake_model.ainvoke.await_args.args[0]
assert "请帮我总结这段代码" in prompt
@@ -135,20 +137,64 @@ class TestTitleMiddlewareCoreLogic:
AIMessage(content="收到"),
]
}
title = asyncio.run(middleware._generate_title(state))
result = asyncio.run(middleware._agenerate_title_result(state))
title = result["title"]
# Assert behavior (truncated fallback + ellipsis) without overfitting exact text.
assert title.endswith("...")
assert title.startswith("这是一个非常长的问题描述")
def test_after_agent_returns_title_only_when_needed(self, monkeypatch):
def test_aafter_model_delegates_to_async_helper(self, monkeypatch):
middleware = TitleMiddleware()
monkeypatch.setattr(middleware, "_should_generate_title", lambda state: True)
monkeypatch.setattr(middleware, "_generate_title", AsyncMock(return_value="核心逻辑回归"))
monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value={"title": "异步标题"}))
result = asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock()))
assert result == {"title": "异步标题"}
assert result == {"title": "核心逻辑回归"}
monkeypatch.setattr(middleware, "_should_generate_title", lambda state: False)
monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value=None))
assert asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock())) is None
def test_after_model_sync_delegates_to_sync_helper(self, monkeypatch):
middleware = TitleMiddleware()
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value={"title": "同步标题"}))
result = middleware.after_model({"messages": []}, runtime=MagicMock())
assert result == {"title": "同步标题"}
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value=None))
assert middleware.after_model({"messages": []}, runtime=MagicMock()) is None
def test_sync_generate_title_with_model(self, monkeypatch):
"""Sync path calls model.invoke and produces a title."""
_set_test_title_config(max_chars=20)
middleware = TitleMiddleware()
fake_model = MagicMock()
fake_model.invoke = MagicMock(return_value=MagicMock(content='"同步生成的标题"'))
monkeypatch.setattr("deerflow.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model)
state = {
"messages": [
HumanMessage(content="请帮我写测试"),
AIMessage(content="好的"),
]
}
result = middleware._generate_title_result(state)
assert result == {"title": "同步生成的标题"}
fake_model.invoke.assert_called_once()
def test_empty_title_falls_back(self, monkeypatch):
"""Empty model response triggers fallback title."""
_set_test_title_config(max_chars=50)
middleware = TitleMiddleware()
fake_model = MagicMock()
fake_model.invoke = MagicMock(return_value=MagicMock(content=" "))
monkeypatch.setattr("deerflow.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model)
state = {
"messages": [
HumanMessage(content="空标题测试"),
AIMessage(content="回复"),
]
}
result = middleware._generate_title_result(state)
assert result["title"] == "空标题测试"