mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
101 lines
3.8 KiB
Python
101 lines
3.8 KiB
Python
|
|
"""Middleware that extends TodoListMiddleware with context-loss detection.
|
||
|
|
|
||
|
|
When the message history is truncated (e.g., by SummarizationMiddleware), the
|
||
|
|
original `write_todos` tool call and its ToolMessage can be scrolled out of the
|
||
|
|
active context window. This middleware detects that situation and injects a
|
||
|
|
reminder message so the model still knows about the outstanding todo list.
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from typing import Any, override
|
||
|
|
|
||
|
|
from langchain.agents.middleware import TodoListMiddleware
|
||
|
|
from langchain.agents.middleware.todo import PlanningState, Todo
|
||
|
|
from langchain_core.messages import AIMessage, HumanMessage
|
||
|
|
from langgraph.runtime import Runtime
|
||
|
|
|
||
|
|
|
||
|
|
def _todos_in_messages(messages: list[Any]) -> bool:
|
||
|
|
"""Return True if any AIMessage in *messages* contains a write_todos tool call."""
|
||
|
|
for msg in messages:
|
||
|
|
if isinstance(msg, AIMessage) and msg.tool_calls:
|
||
|
|
for tc in msg.tool_calls:
|
||
|
|
if tc.get("name") == "write_todos":
|
||
|
|
return True
|
||
|
|
return False
|
||
|
|
|
||
|
|
|
||
|
|
def _reminder_in_messages(messages: list[Any]) -> bool:
|
||
|
|
"""Return True if a todo_reminder HumanMessage is already present in *messages*."""
|
||
|
|
for msg in messages:
|
||
|
|
if isinstance(msg, HumanMessage) and getattr(msg, "name", None) == "todo_reminder":
|
||
|
|
return True
|
||
|
|
return False
|
||
|
|
|
||
|
|
|
||
|
|
def _format_todos(todos: list[Todo]) -> str:
|
||
|
|
"""Format a list of Todo items into a human-readable string."""
|
||
|
|
lines: list[str] = []
|
||
|
|
for todo in todos:
|
||
|
|
status = todo.get("status", "pending")
|
||
|
|
content = todo.get("content", "")
|
||
|
|
lines.append(f"- [{status}] {content}")
|
||
|
|
return "\n".join(lines)
|
||
|
|
|
||
|
|
|
||
|
|
class TodoMiddleware(TodoListMiddleware):
|
||
|
|
"""Extends TodoListMiddleware with `write_todos` context-loss detection.
|
||
|
|
|
||
|
|
When the original `write_todos` tool call has been truncated from the message
|
||
|
|
history (e.g., after summarization), the model loses awareness of the current
|
||
|
|
todo list. This middleware detects that gap in `before_model` / `abefore_model`
|
||
|
|
and injects a reminder message so the model can continue tracking progress.
|
||
|
|
"""
|
||
|
|
|
||
|
|
@override
|
||
|
|
def before_model(
|
||
|
|
self,
|
||
|
|
state: PlanningState,
|
||
|
|
runtime: Runtime, # noqa: ARG002
|
||
|
|
) -> dict[str, Any] | None:
|
||
|
|
"""Inject a todo-list reminder when write_todos has left the context window."""
|
||
|
|
todos: list[Todo] = state.get("todos") or [] # type: ignore[assignment]
|
||
|
|
if not todos:
|
||
|
|
return None
|
||
|
|
|
||
|
|
messages = state.get("messages") or []
|
||
|
|
if _todos_in_messages(messages):
|
||
|
|
# write_todos is still visible in context — nothing to do.
|
||
|
|
return None
|
||
|
|
|
||
|
|
if _reminder_in_messages(messages):
|
||
|
|
# A reminder was already injected and hasn't been truncated yet.
|
||
|
|
return None
|
||
|
|
|
||
|
|
# The todo list exists in state but the original write_todos call is gone.
|
||
|
|
# Inject a reminder as a HumanMessage so the model stays aware.
|
||
|
|
formatted = _format_todos(todos)
|
||
|
|
reminder = HumanMessage(
|
||
|
|
name="todo_reminder",
|
||
|
|
content=(
|
||
|
|
"<system_reminder>\n"
|
||
|
|
"Your todo list from earlier is no longer visible in the current context window, "
|
||
|
|
"but it is still active. Here is the current state:\n\n"
|
||
|
|
f"{formatted}\n\n"
|
||
|
|
"Continue tracking and updating this todo list as you work. "
|
||
|
|
"Call `write_todos` whenever the status of any item changes.\n"
|
||
|
|
"</system_reminder>"
|
||
|
|
),
|
||
|
|
)
|
||
|
|
return {"messages": [reminder]}
|
||
|
|
|
||
|
|
@override
|
||
|
|
async def abefore_model(
|
||
|
|
self,
|
||
|
|
state: PlanningState,
|
||
|
|
runtime: Runtime,
|
||
|
|
) -> dict[str, Any] | None:
|
||
|
|
"""Async version of before_model."""
|
||
|
|
return self.before_model(state, runtime)
|