fix: add sync after_model to TitleMiddleware (#1190)

This commit is contained in:
greatmengqi
2026-03-19 15:46:31 +08:00
committed by GitHub
parent f67c3d2c9e
commit accf5b5f8e
2 changed files with 120 additions and 35 deletions

View File

@@ -1,5 +1,6 @@
"""Middleware for automatic thread title generation.""" """Middleware for automatic thread title generation."""
import logging
from typing import NotRequired, override from typing import NotRequired, override
from langchain.agents import AgentState from langchain.agents import AgentState
@@ -9,6 +10,8 @@ from langgraph.runtime import Runtime
from deerflow.config.title_config import get_title_config from deerflow.config.title_config import get_title_config
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
class TitleMiddlewareState(AgentState): class TitleMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema.""" """Compatible with the `ThreadState` schema."""
@@ -62,49 +65,85 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
# Generate title after first complete exchange # Generate title after first complete exchange
return len(user_messages) == 1 and len(assistant_messages) >= 1 return len(user_messages) == 1 and len(assistant_messages) >= 1
def _build_title_prompt(self, state: TitleMiddlewareState) -> tuple[str, str]:
    """Assemble the title-generation prompt from the first exchange.

    Returns a ``(prompt, user_msg)`` pair; ``user_msg`` is the normalized
    first human message so callers can use it as a fallback title source.
    """
    cfg = get_title_config()
    history = state.get("messages", [])

    # Scan once for the first human turn and the first AI turn.
    raw_user = raw_ai = None
    for msg in history:
        if raw_user is None and msg.type == "human":
            raw_user = msg.content
        if raw_ai is None and msg.type == "ai":
            raw_ai = msg.content
        if raw_user is not None and raw_ai is not None:
            break

    user_text = self._normalize_content(raw_user if raw_user is not None else "")
    ai_text = self._normalize_content(raw_ai if raw_ai is not None else "")

    # Cap both sides at 500 chars to keep the prompt small and cheap.
    prompt_text = cfg.prompt_template.format(
        max_words=cfg.max_words,
        user_msg=user_text[:500],
        assistant_msg=ai_text[:500],
    )
    return prompt_text, user_text
def _parse_title(self, content: object) -> str:
    """Normalize raw model output into a clean, length-capped title."""
    cfg = get_title_config()
    # Drop surrounding whitespace and any quoting the model added.
    cleaned = self._normalize_content(content).strip().strip('"').strip("'")
    # Enforce the configured character budget.
    if len(cleaned) > cfg.max_chars:
        return cleaned[: cfg.max_chars]
    return cleaned
def _fallback_title(self, user_msg: str) -> str:
    """Derive a title from the user's message when model generation fails."""
    cfg = get_title_config()
    # Fallback titles never exceed 50 chars, even if max_chars allows more.
    budget = min(cfg.max_chars, 50)
    if len(user_msg) > budget:
        return user_msg[:budget].rstrip() + "..."
    return user_msg or "New Conversation"
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
    """Synchronously generate a title. Returns state update or None."""
    if not self._should_generate_title(state):
        return None

    prompt, user_msg = self._build_title_prompt(state)
    cfg = get_title_config()
    # Lightweight model: thinking disabled, dedicated title model name.
    title_model = create_chat_model(name=cfg.model_name, thinking_enabled=False)
    try:
        reply = title_model.invoke(prompt)
        # An empty parsed title falls through to the user-message fallback.
        title = self._parse_title(reply.content) or self._fallback_title(user_msg)
    except Exception:
        logger.exception("Failed to generate title (sync)")
        title = self._fallback_title(user_msg)
    # Stored in state; persisted by the checkpointer if one is configured.
    return {"title": title}
async def _agenerate_title_result(self, state: TitleMiddlewareState) -> dict | None:
    """Asynchronously generate a title. Returns state update or None."""
    if not self._should_generate_title(state):
        return None

    prompt, user_msg = self._build_title_prompt(state)
    cfg = get_title_config()
    # Lightweight model: thinking disabled, dedicated title model name.
    title_model = create_chat_model(name=cfg.model_name, thinking_enabled=False)
    try:
        reply = await title_model.ainvoke(prompt)
        # An empty parsed title falls through to the user-message fallback.
        title = self._parse_title(reply.content) or self._fallback_title(user_msg)
    except Exception:
        logger.exception("Failed to generate title (async)")
        title = self._fallback_title(user_msg)
    return {"title": title}
@override
def after_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
return self._generate_title_result(state)
@override
async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
return await self._agenerate_title_result(state)

View File

@@ -86,7 +86,8 @@ class TestTitleMiddlewareCoreLogic:
AIMessage(content="好的,先确认需求"), AIMessage(content="好的,先确认需求"),
] ]
} }
title = asyncio.run(middleware._generate_title(state)) result = asyncio.run(middleware._agenerate_title_result(state))
title = result["title"]
assert '"' not in title assert '"' not in title
assert "'" not in title assert "'" not in title
@@ -111,7 +112,8 @@ class TestTitleMiddlewareCoreLogic:
] ]
} }
title = asyncio.run(middleware._generate_title(state)) result = asyncio.run(middleware._agenerate_title_result(state))
title = result["title"]
prompt = fake_model.ainvoke.await_args.args[0] prompt = fake_model.ainvoke.await_args.args[0]
assert "请帮我总结这段代码" in prompt assert "请帮我总结这段代码" in prompt
@@ -135,20 +137,64 @@ class TestTitleMiddlewareCoreLogic:
AIMessage(content="收到"), AIMessage(content="收到"),
] ]
} }
title = asyncio.run(middleware._generate_title(state)) result = asyncio.run(middleware._agenerate_title_result(state))
title = result["title"]
# Assert behavior (truncated fallback + ellipsis) without overfitting exact text. # Assert behavior (truncated fallback + ellipsis) without overfitting exact text.
assert title.endswith("...") assert title.endswith("...")
assert title.startswith("这是一个非常长的问题描述") assert title.startswith("这是一个非常长的问题描述")
def test_aafter_model_delegates_to_async_helper(self, monkeypatch):
    """aafter_model forwards whatever the async helper returns, including None."""
    middleware = TitleMiddleware()

    monkeypatch.setattr(
        middleware,
        "_agenerate_title_result",
        AsyncMock(return_value={"title": "异步标题"}),
    )
    outcome = asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock()))
    assert outcome == {"title": "异步标题"}

    # A None result from the helper must pass through untouched.
    monkeypatch.setattr(
        middleware,
        "_agenerate_title_result",
        AsyncMock(return_value=None),
    )
    assert asyncio.run(middleware.aafter_model({"messages": []}, runtime=MagicMock())) is None
def test_after_model_sync_delegates_to_sync_helper(self, monkeypatch):
    """after_model forwards whatever the sync helper returns, including None."""
    middleware = TitleMiddleware()

    monkeypatch.setattr(
        middleware,
        "_generate_title_result",
        MagicMock(return_value={"title": "同步标题"}),
    )
    outcome = middleware.after_model({"messages": []}, runtime=MagicMock())
    assert outcome == {"title": "同步标题"}

    # A None result from the helper must pass through untouched.
    monkeypatch.setattr(
        middleware,
        "_generate_title_result",
        MagicMock(return_value=None),
    )
    assert middleware.after_model({"messages": []}, runtime=MagicMock()) is None
def test_sync_generate_title_with_model(self, monkeypatch):
    """Sync path calls model.invoke and produces a title."""
    _set_test_title_config(max_chars=20)
    middleware = TitleMiddleware()

    # Stub the chat model so invoke() returns a quoted title string.
    stub_model = MagicMock()
    stub_model.invoke = MagicMock(return_value=MagicMock(content='"同步生成的标题"'))
    monkeypatch.setattr(
        "deerflow.agents.middlewares.title_middleware.create_chat_model",
        lambda **kwargs: stub_model,
    )

    convo = {
        "messages": [
            HumanMessage(content="请帮我写测试"),
            AIMessage(content="好的"),
        ]
    }
    update = middleware._generate_title_result(convo)

    # Quotes stripped, exactly one model call.
    assert update == {"title": "同步生成的标题"}
    stub_model.invoke.assert_called_once()
def test_empty_title_falls_back(self, monkeypatch):
    """Empty model response triggers fallback title."""
    _set_test_title_config(max_chars=50)
    middleware = TitleMiddleware()

    # Model returns only whitespace, so the parsed title is empty.
    blank_model = MagicMock()
    blank_model.invoke = MagicMock(return_value=MagicMock(content=" "))
    monkeypatch.setattr(
        "deerflow.agents.middlewares.title_middleware.create_chat_model",
        lambda **kwargs: blank_model,
    )

    convo = {
        "messages": [
            HumanMessage(content="空标题测试"),
            AIMessage(content="回复"),
        ]
    }
    update = middleware._generate_title_result(convo)

    # Fallback reuses the user's message verbatim.
    assert update["title"] == "空标题测试"