mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-02 22:02:13 +08:00
feat(middleware): introduce TodoMiddleware for context-loss detection in todo management (#1041)
* feat(middleware): introduce TodoMiddleware for context-loss detection in todo management * Address PR #1041 review suggestions: todo reminder dedup, thread switching, artifact deselect, debug log (#8) * Initial plan * Handle all suggestions from PR #1041 review Co-authored-by: foreleven <4785594+foreleven@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: foreleven <4785594+foreleven@users.noreply.github.com> * fix(chat-box): prevent automatic deselection of artifacts when switching threads fix(hooks): reset thread state on new thread creation --------- Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: foreleven <4785594+foreleven@users.noreply.github.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import SummarizationMiddleware, TodoListMiddleware
|
||||
from langchain.agents.middleware import SummarizationMiddleware
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from src.agents.lead_agent.prompt import apply_prompt_template
|
||||
@@ -11,6 +11,7 @@ from src.agents.middlewares.memory_middleware import MemoryMiddleware
|
||||
from src.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
|
||||
from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||
from src.agents.middlewares.title_middleware import TitleMiddleware
|
||||
from src.agents.middlewares.todo_middleware import TodoMiddleware
|
||||
from src.agents.middlewares.uploads_middleware import UploadsMiddleware
|
||||
from src.agents.middlewares.view_image_middleware import ViewImageMiddleware
|
||||
from src.agents.thread_state import ThreadState
|
||||
@@ -80,14 +81,14 @@ def _create_summarization_middleware() -> SummarizationMiddleware | None:
|
||||
return SummarizationMiddleware(**kwargs)
|
||||
|
||||
|
||||
def _create_todo_list_middleware(is_plan_mode: bool) -> TodoListMiddleware | None:
|
||||
def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None:
|
||||
"""Create and configure the TodoList middleware.
|
||||
|
||||
Args:
|
||||
is_plan_mode: Whether to enable plan mode with TodoList middleware.
|
||||
|
||||
Returns:
|
||||
TodoListMiddleware instance if plan mode is enabled, None otherwise.
|
||||
TodoMiddleware instance if plan mode is enabled, None otherwise.
|
||||
"""
|
||||
if not is_plan_mode:
|
||||
return None
|
||||
@@ -192,7 +193,7 @@ Being proactive with task management demonstrates thoroughness and ensures all r
|
||||
**Remember**: If you only need a few tool calls to complete a task and it's clear what to do, it's better to just do the task directly and NOT use this tool at all.
|
||||
"""
|
||||
|
||||
return TodoListMiddleware(system_prompt=system_prompt, tool_description=tool_description)
|
||||
return TodoMiddleware(system_prompt=system_prompt, tool_description=tool_description)
|
||||
|
||||
|
||||
# ThreadDataMiddleware must be before SandboxMiddleware to ensure thread_id is available
|
||||
|
||||
100
backend/src/agents/middlewares/todo_middleware.py
Normal file
100
backend/src/agents/middlewares/todo_middleware.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Middleware that extends TodoListMiddleware with context-loss detection.
|
||||
|
||||
When the message history is truncated (e.g., by SummarizationMiddleware), the
|
||||
original `write_todos` tool call and its ToolMessage can be scrolled out of the
|
||||
active context window. This middleware detects that situation and injects a
|
||||
reminder message so the model still knows about the outstanding todo list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, override
|
||||
|
||||
from langchain.agents.middleware import TodoListMiddleware
|
||||
from langchain.agents.middleware.todo import PlanningState, Todo
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
|
||||
def _todos_in_messages(messages: list[Any]) -> bool:
|
||||
"""Return True if any AIMessage in *messages* contains a write_todos tool call."""
|
||||
for msg in messages:
|
||||
if isinstance(msg, AIMessage) and msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
if tc.get("name") == "write_todos":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _reminder_in_messages(messages: list[Any]) -> bool:
|
||||
"""Return True if a todo_reminder HumanMessage is already present in *messages*."""
|
||||
for msg in messages:
|
||||
if isinstance(msg, HumanMessage) and getattr(msg, "name", None) == "todo_reminder":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _format_todos(todos: list[Todo]) -> str:
|
||||
"""Format a list of Todo items into a human-readable string."""
|
||||
lines: list[str] = []
|
||||
for todo in todos:
|
||||
status = todo.get("status", "pending")
|
||||
content = todo.get("content", "")
|
||||
lines.append(f"- [{status}] {content}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class TodoMiddleware(TodoListMiddleware):
|
||||
"""Extends TodoListMiddleware with `write_todos` context-loss detection.
|
||||
|
||||
When the original `write_todos` tool call has been truncated from the message
|
||||
history (e.g., after summarization), the model loses awareness of the current
|
||||
todo list. This middleware detects that gap in `before_model` / `abefore_model`
|
||||
and injects a reminder message so the model can continue tracking progress.
|
||||
"""
|
||||
|
||||
@override
|
||||
def before_model(
|
||||
self,
|
||||
state: PlanningState,
|
||||
runtime: Runtime, # noqa: ARG002
|
||||
) -> dict[str, Any] | None:
|
||||
"""Inject a todo-list reminder when write_todos has left the context window."""
|
||||
todos: list[Todo] = state.get("todos") or [] # type: ignore[assignment]
|
||||
if not todos:
|
||||
return None
|
||||
|
||||
messages = state.get("messages") or []
|
||||
if _todos_in_messages(messages):
|
||||
# write_todos is still visible in context — nothing to do.
|
||||
return None
|
||||
|
||||
if _reminder_in_messages(messages):
|
||||
# A reminder was already injected and hasn't been truncated yet.
|
||||
return None
|
||||
|
||||
# The todo list exists in state but the original write_todos call is gone.
|
||||
# Inject a reminder as a HumanMessage so the model stays aware.
|
||||
formatted = _format_todos(todos)
|
||||
reminder = HumanMessage(
|
||||
name="todo_reminder",
|
||||
content=(
|
||||
"<system_reminder>\n"
|
||||
"Your todo list from earlier is no longer visible in the current context window, "
|
||||
"but it is still active. Here is the current state:\n\n"
|
||||
f"{formatted}\n\n"
|
||||
"Continue tracking and updating this todo list as you work. "
|
||||
"Call `write_todos` whenever the status of any item changes.\n"
|
||||
"</system_reminder>"
|
||||
),
|
||||
)
|
||||
return {"messages": [reminder]}
|
||||
|
||||
@override
|
||||
async def abefore_model(
|
||||
self,
|
||||
state: PlanningState,
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async version of before_model."""
|
||||
return self.before_model(state, runtime)
|
||||
@@ -99,7 +99,7 @@ async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> S
|
||||
"- Output MUST be a JSON array of strings only.\n\n"
|
||||
"Conversation:\n"
|
||||
f"{conversation}\n"
|
||||
).format(n=n, conversation=conversation)
|
||||
)
|
||||
|
||||
try:
|
||||
model = create_chat_model(name=request.model_name, thinking_enabled=False)
|
||||
|
||||
@@ -58,8 +58,8 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
|
||||
elif effective_wte.get("thinking", {}).get("type"):
|
||||
# Native langchain_anthropic: thinking is a direct constructor parameter
|
||||
kwargs.update({"thinking": {"type": "disabled"}})
|
||||
if not model_config.supports_reasoning_effort:
|
||||
kwargs.update({"reasoning_effort": None})
|
||||
if not model_config.supports_reasoning_effort and "reasoning_effort" in kwargs:
|
||||
del kwargs["reasoning_effort"]
|
||||
|
||||
model_instance = model_class(**kwargs, **model_settings_from_config)
|
||||
|
||||
|
||||
@@ -178,7 +178,6 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
|
||||
raise SandboxRuntimeError("Thread ID not available in runtime context")
|
||||
|
||||
provider = get_sandbox_provider()
|
||||
print(f"Lazy acquiring sandbox for thread {thread_id}")
|
||||
sandbox_id = provider.acquire(thread_id)
|
||||
|
||||
# Update runtime state - this persists across tool calls
|
||||
|
||||
@@ -50,13 +50,13 @@ const ChatBox: React.FC<{ children: React.ReactNode; threadId: string }> = ({
|
||||
// Update artifacts from the current thread
|
||||
setArtifacts(thread.values.artifacts);
|
||||
|
||||
// Deselect if the currently selected artifact no longer exists
|
||||
if (
|
||||
selectedArtifact &&
|
||||
!thread.values.artifacts?.includes(selectedArtifact)
|
||||
) {
|
||||
deselect();
|
||||
}
|
||||
// DO NOT automatically deselect the artifact when switching threads, because the artifacts auto discovering is not work now.
|
||||
// if (
|
||||
// selectedArtifact &&
|
||||
// !thread.values.artifacts?.includes(selectedArtifact)
|
||||
// ) {
|
||||
// deselect();
|
||||
// }
|
||||
|
||||
if (
|
||||
env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true" &&
|
||||
|
||||
@@ -52,6 +52,10 @@ export function groupMessages<T>(
|
||||
}
|
||||
|
||||
for (const message of messages) {
|
||||
if (message.name === "todo_reminder") {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (message.type === "human") {
|
||||
groups.push({ id: message.id, type: "human", messages: [message] });
|
||||
continue;
|
||||
|
||||
@@ -60,11 +60,12 @@ export function useThreadStream({
|
||||
|
||||
useEffect(() => {
|
||||
const normalizedThreadId = threadId ?? null;
|
||||
if (threadIdRef.current !== normalizedThreadId) {
|
||||
threadIdRef.current = normalizedThreadId;
|
||||
startedRef.current = false; // Reset for new thread
|
||||
if (!normalizedThreadId) {
|
||||
// Just reset for new thread creation when threadId becomes null/undefined
|
||||
startedRef.current = false;
|
||||
setOnStreamThreadId(normalizedThreadId);
|
||||
}
|
||||
threadIdRef.current = normalizedThreadId;
|
||||
}, [threadId]);
|
||||
|
||||
const _handleOnStart = useCallback((id: string) => {
|
||||
@@ -77,7 +78,6 @@ export function useThreadStream({
|
||||
const handleStreamStart = useCallback(
|
||||
(_threadId: string) => {
|
||||
threadIdRef.current = _threadId;
|
||||
setOnStreamThreadId(_threadId);
|
||||
_handleOnStart(_threadId);
|
||||
},
|
||||
[_handleOnStart],
|
||||
@@ -85,6 +85,7 @@ export function useThreadStream({
|
||||
|
||||
const queryClient = useQueryClient();
|
||||
const updateSubtask = useUpdateSubtask();
|
||||
|
||||
const thread = useStream<AgentThreadState>({
|
||||
client: getAPIClient(isMock),
|
||||
assistantId: "lead_agent",
|
||||
@@ -93,6 +94,7 @@ export function useThreadStream({
|
||||
fetchStateHistory: { limit: 1 },
|
||||
onCreated(meta) {
|
||||
handleStreamStart(meta.thread_id);
|
||||
setOnStreamThreadId(meta.thread_id);
|
||||
},
|
||||
onLangChainEvent(event) {
|
||||
if (event.event === "on_tool_end") {
|
||||
|
||||
Reference in New Issue
Block a user