fix: add sync after_model to TitleMiddleware (#1190)

This commit is contained in:
greatmengqi
2026-03-19 15:46:31 +08:00
committed by GitHub
parent f67c3d2c9e
commit accf5b5f8e
2 changed files with 120 additions and 35 deletions

View File

@@ -1,5 +1,6 @@
"""Middleware for automatic thread title generation.""" """Middleware for automatic thread title generation."""
import logging
from typing import NotRequired, override from typing import NotRequired, override
from langchain.agents import AgentState from langchain.agents import AgentState
@@ -9,6 +10,8 @@ from langgraph.runtime import Runtime
from deerflow.config.title_config import get_title_config from deerflow.config.title_config import get_title_config
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
class TitleMiddlewareState(AgentState): class TitleMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema.""" """Compatible with the `ThreadState` schema."""
@@ -62,49 +65,85 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
# Generate title after first complete exchange # Generate title after first complete exchange
return len(user_messages) == 1 and len(assistant_messages) >= 1 return len(user_messages) == 1 and len(assistant_messages) >= 1
async def _generate_title(self, state: TitleMiddlewareState) -> str: def _build_title_prompt(self, state: TitleMiddlewareState) -> tuple[str, str]:
"""Generate a concise title based on the conversation.""" """Extract user/assistant messages and build the title prompt.
Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
"""
config = get_title_config() config = get_title_config()
messages = state.get("messages", []) messages = state.get("messages", [])
# Get first user message and first assistant response
user_msg_content = next((m.content for m in messages if m.type == "human"), "") user_msg_content = next((m.content for m in messages if m.type == "human"), "")
assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "") assistant_msg_content = next((m.content for m in messages if m.type == "ai"), "")
user_msg = self._normalize_content(user_msg_content) user_msg = self._normalize_content(user_msg_content)
assistant_msg = self._normalize_content(assistant_msg_content) assistant_msg = self._normalize_content(assistant_msg_content)
# Use a lightweight model to generate title
model = create_chat_model(thinking_enabled=False)
prompt = config.prompt_template.format( prompt = config.prompt_template.format(
max_words=config.max_words, max_words=config.max_words,
user_msg=user_msg[:500], user_msg=user_msg[:500],
assistant_msg=assistant_msg[:500], assistant_msg=assistant_msg[:500],
) )
return prompt, user_msg
def _parse_title(self, content: object) -> str:
"""Normalize model output into a clean title string."""
config = get_title_config()
title_content = self._normalize_content(content)
title = title_content.strip().strip('"').strip("'")
return title[: config.max_chars] if len(title) > config.max_chars else title
def _fallback_title(self, user_msg: str) -> str:
config = get_title_config()
fallback_chars = min(config.max_chars, 50)
if len(user_msg) > fallback_chars:
return user_msg[:fallback_chars].rstrip() + "..."
return user_msg if user_msg else "New Conversation"
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
"""Synchronously generate a title. Returns state update or None."""
if not self._should_generate_title(state):
return None
prompt, user_msg = self._build_title_prompt(state)
config = get_title_config()
model = create_chat_model(name=config.model_name, thinking_enabled=False)
try:
response = model.invoke(prompt)
title = self._parse_title(response.content)
if not title:
title = self._fallback_title(user_msg)
except Exception:
logger.exception("Failed to generate title (sync)")
title = self._fallback_title(user_msg)
return {"title": title}
async def _agenerate_title_result(self, state: TitleMiddlewareState) -> dict | None:
"""Asynchronously generate a title. Returns state update or None."""
if not self._should_generate_title(state):
return None
prompt, user_msg = self._build_title_prompt(state)
config = get_title_config()
model = create_chat_model(name=config.model_name, thinking_enabled=False)
try: try:
response = await model.ainvoke(prompt) response = await model.ainvoke(prompt)
title_content = self._normalize_content(response.content) title = self._parse_title(response.content)
title = title_content.strip().strip('"').strip("'") if not title:
# Limit to max characters title = self._fallback_title(user_msg)
return title[: config.max_chars] if len(title) > config.max_chars else title except Exception:
except Exception as e: logger.exception("Failed to generate title (async)")
print(f"Failed to generate title: {e}") title = self._fallback_title(user_msg)
# Fallback: use first part of user message (by character count)
fallback_chars = min(config.max_chars, 50) # Use max_chars or 50, whichever is smaller return {"title": title}
if len(user_msg) > fallback_chars:
return user_msg[:fallback_chars].rstrip() + "..." @override
return user_msg if user_msg else "New Conversation" def after_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
return self._generate_title_result(state)
@override @override
async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None: async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
"""Generate and set thread title after the first agent response.""" return await self._agenerate_title_result(state)
if self._should_generate_title(state):
title = await self._generate_title(state)
print(f"Generated thread title: {title}")
# Store title in state (will be persisted by checkpointer if configured)
return {"title": title}
return None

View File

@@ -86,7 +86,8 @@ class TestTitleMiddlewareCoreLogic:
AIMessage(content="好的,先确认需求"), AIMessage(content="好的,先确认需求"),
] ]
} }
title = asyncio.run(middleware._generate_title(state)) result = asyncio.run(middleware._agenerate_title_result(state))
title = result["title"]
assert '"' not in title assert '"' not in title
assert "'" not in title assert "'" not in title
@@ -111,7 +112,8 @@ class TestTitleMiddlewareCoreLogic:
] ]
} }
title = asyncio.run(middleware._generate_title(state)) result = asyncio.run(middleware._agenerate_title_result(state))
title = result["title"]
prompt = fake_model.ainvoke.await_args.args[0] prompt = fake_model.ainvoke.await_args.args[0]
assert "请帮我总结这段代码" in prompt assert "请帮我总结这段代码" in prompt
@@ -135,20 +137,64 @@ class TestTitleMiddlewareCoreLogic:
AIMessage(content="收到"), AIMessage(content="收到"),
] ]
} }
title = asyncio.run(middleware._generate_title(state)) result = asyncio.run(middleware._agenerate_title_result(state))
title = result["title"]
# Assert behavior (truncated fallback + ellipsis) without overfitting exact text. # Assert behavior (truncated fallback + ellipsis) without overfitting exact text.
assert title.endswith("...") assert title.endswith("...")
assert title.startswith("这是一个非常长的问题描述") assert title.startswith("这是一个非常长的问题描述")
def test_after_agent_returns_title_only_when_needed(self, monkeypatch): def test_aafter_model_delegates_to_async_helper(self, monkeypatch):
middleware = TitleMiddleware() middleware = TitleMiddleware()
monkeypatch.setattr(middleware, "_should_generate_title", lambda state: True)
monkeypatch.setattr(middleware, "_generate_title", AsyncMock(return_value="核心逻辑回归"))
monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value={"title": "异步标题"}))
result = asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock())) result = asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock()))
assert result == {"title": "异步标题"}
assert result == {"title": "核心逻辑回归"} monkeypatch.setattr(middleware, "_agenerate_title_result", AsyncMock(return_value=None))
monkeypatch.setattr(middleware, "_should_generate_title", lambda state: False)
assert asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock())) is None assert asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock())) is None
def test_after_model_sync_delegates_to_sync_helper(self, monkeypatch):
middleware = TitleMiddleware()
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value={"title": "同步标题"}))
result = middleware.after_model({"messages": []}, runtime=MagicMock())
assert result == {"title": "同步标题"}
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value=None))
assert middleware.after_model({"messages": []}, runtime=MagicMock()) is None
def test_sync_generate_title_with_model(self, monkeypatch):
"""Sync path calls model.invoke and produces a title."""
_set_test_title_config(max_chars=20)
middleware = TitleMiddleware()
fake_model = MagicMock()
fake_model.invoke = MagicMock(return_value=MagicMock(content='"同步生成的标题"'))
monkeypatch.setattr("deerflow.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model)
state = {
"messages": [
HumanMessage(content="请帮我写测试"),
AIMessage(content="好的"),
]
}
result = middleware._generate_title_result(state)
assert result == {"title": "同步生成的标题"}
fake_model.invoke.assert_called_once()
def test_empty_title_falls_back(self, monkeypatch):
"""Empty model response triggers fallback title."""
_set_test_title_config(max_chars=50)
middleware = TitleMiddleware()
fake_model = MagicMock()
fake_model.invoke = MagicMock(return_value=MagicMock(content=" "))
monkeypatch.setattr("deerflow.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model)
state = {
"messages": [
HumanMessage(content="空标题测试"),
AIMessage(content="回复"),
]
}
result = middleware._generate_title_result(state)
assert result["title"] == "空标题测试"