feat(middleware): introduce TodoMiddleware for context-loss detection in todo management (#1041)

* feat(middleware): introduce TodoMiddleware for context-loss detection in todo management

* Address PR #1041 review suggestions: todo reminder dedup, thread switching, artifact deselect, debug log (#8)

* Initial plan

* Handle all suggestions from PR #1041 review

Co-authored-by: foreleven <4785594+foreleven@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: foreleven <4785594+foreleven@users.noreply.github.com>

* fix(chat-box): prevent automatic deselection of artifacts when switching threads
fix(hooks): reset thread state on new thread creation

---------

Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: foreleven <4785594+foreleven@users.noreply.github.com>
This commit is contained in:
JeffJiang
2026-03-10 11:24:53 +08:00
committed by GitHub
parent cf1c4a68ea
commit f5bd691172
8 changed files with 125 additions and 19 deletions

View File

@@ -1,7 +1,7 @@
import logging
from langchain.agents import create_agent
from langchain.agents.middleware import SummarizationMiddleware, TodoListMiddleware
from langchain.agents.middleware import SummarizationMiddleware
from langchain_core.runnables import RunnableConfig
from src.agents.lead_agent.prompt import apply_prompt_template
@@ -11,6 +11,7 @@ from src.agents.middlewares.memory_middleware import MemoryMiddleware
from src.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
from src.agents.middlewares.title_middleware import TitleMiddleware
from src.agents.middlewares.todo_middleware import TodoMiddleware
from src.agents.middlewares.uploads_middleware import UploadsMiddleware
from src.agents.middlewares.view_image_middleware import ViewImageMiddleware
from src.agents.thread_state import ThreadState
@@ -80,14 +81,14 @@ def _create_summarization_middleware() -> SummarizationMiddleware | None:
return SummarizationMiddleware(**kwargs)
def _create_todo_list_middleware(is_plan_mode: bool) -> TodoListMiddleware | None:
def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None:
"""Create and configure the TodoList middleware.
Args:
is_plan_mode: Whether to enable plan mode with TodoList middleware.
Returns:
TodoListMiddleware instance if plan mode is enabled, None otherwise.
TodoMiddleware instance if plan mode is enabled, None otherwise.
"""
if not is_plan_mode:
return None
@@ -192,7 +193,7 @@ Being proactive with task management demonstrates thoroughness and ensures all r
**Remember**: If you only need a few tool calls to complete a task and it's clear what to do, it's better to just do the task directly and NOT use this tool at all.
"""
return TodoListMiddleware(system_prompt=system_prompt, tool_description=tool_description)
return TodoMiddleware(system_prompt=system_prompt, tool_description=tool_description)
# ThreadDataMiddleware must be before SandboxMiddleware to ensure thread_id is available

View File

@@ -0,0 +1,100 @@
"""Middleware that extends TodoListMiddleware with context-loss detection.
When the message history is truncated (e.g., by SummarizationMiddleware), the
original `write_todos` tool call and its ToolMessage can be scrolled out of the
active context window. This middleware detects that situation and injects a
reminder message so the model still knows about the outstanding todo list.
"""
from __future__ import annotations
from typing import Any, override
from langchain.agents.middleware import TodoListMiddleware
from langchain.agents.middleware.todo import PlanningState, Todo
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.runtime import Runtime
def _todos_in_messages(messages: list[Any]) -> bool:
    """Check whether any AIMessage in *messages* carries a `write_todos` tool call."""
    return any(
        call.get("name") == "write_todos"
        for message in messages
        if isinstance(message, AIMessage) and message.tool_calls
        for call in message.tool_calls
    )
def _reminder_in_messages(messages: list[Any]) -> bool:
    """Check whether a todo_reminder HumanMessage is already present in *messages*."""
    return any(
        isinstance(message, HumanMessage)
        and getattr(message, "name", None) == "todo_reminder"
        for message in messages
    )
def _format_todos(todos: list[Todo]) -> str:
"""Format a list of Todo items into a human-readable string."""
lines: list[str] = []
for todo in todos:
status = todo.get("status", "pending")
content = todo.get("content", "")
lines.append(f"- [{status}] {content}")
return "\n".join(lines)
class TodoMiddleware(TodoListMiddleware):
    """TodoListMiddleware variant that survives `write_todos` context loss.

    After history truncation (e.g. by summarization) the original
    `write_todos` tool call may no longer be visible to the model, which then
    forgets the active todo list. `before_model` / `abefore_model` detect that
    gap and re-surface the list via an injected reminder message.
    """

    @override
    def before_model(
        self,
        state: PlanningState,
        runtime: Runtime,  # noqa: ARG002
    ) -> dict[str, Any] | None:
        """Inject a todo-list reminder when write_todos has left the context window."""
        current_todos: list[Todo] = state.get("todos") or []  # type: ignore[assignment]
        history = state.get("messages") or []

        needs_reminder = (
            bool(current_todos)
            and not _todos_in_messages(history)
            and not _reminder_in_messages(history)
        )
        if not needs_reminder:
            # Either there are no todos, the write_todos call is still in
            # context, or a previously injected reminder remains visible.
            return None

        # The todo list lives on in state even though its originating message
        # was truncated — re-surface it as a HumanMessage the model will read.
        reminder = HumanMessage(
            name="todo_reminder",
            content=(
                "<system_reminder>\n"
                "Your todo list from earlier is no longer visible in the current context window, "
                "but it is still active. Here is the current state:\n\n"
                f"{_format_todos(current_todos)}\n\n"
                "Continue tracking and updating this todo list as you work. "
                "Call `write_todos` whenever the status of any item changes.\n"
                "</system_reminder>"
            ),
        )
        return {"messages": [reminder]}

    @override
    async def abefore_model(
        self,
        state: PlanningState,
        runtime: Runtime,
    ) -> dict[str, Any] | None:
        """Async counterpart; delegates to the synchronous `before_model`."""
        return self.before_model(state, runtime)

View File

@@ -99,7 +99,7 @@ async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> S
"- Output MUST be a JSON array of strings only.\n\n"
"Conversation:\n"
f"{conversation}\n"
).format(n=n, conversation=conversation)
)
try:
model = create_chat_model(name=request.model_name, thinking_enabled=False)

View File

@@ -58,8 +58,8 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
elif effective_wte.get("thinking", {}).get("type"):
# Native langchain_anthropic: thinking is a direct constructor parameter
kwargs.update({"thinking": {"type": "disabled"}})
if not model_config.supports_reasoning_effort:
kwargs.update({"reasoning_effort": None})
if not model_config.supports_reasoning_effort and "reasoning_effort" in kwargs:
del kwargs["reasoning_effort"]
model_instance = model_class(**kwargs, **model_settings_from_config)

View File

@@ -178,7 +178,6 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
raise SandboxRuntimeError("Thread ID not available in runtime context")
provider = get_sandbox_provider()
print(f"Lazy acquiring sandbox for thread {thread_id}")
sandbox_id = provider.acquire(thread_id)
# Update runtime state - this persists across tool calls

View File

@@ -50,13 +50,13 @@ const ChatBox: React.FC<{ children: React.ReactNode; threadId: string }> = ({
// Update artifacts from the current thread
setArtifacts(thread.values.artifacts);
// Deselect if the currently selected artifact no longer exists
if (
selectedArtifact &&
!thread.values.artifacts?.includes(selectedArtifact)
) {
deselect();
}
// DO NOT automatically deselect the artifact when switching threads, because artifact auto-discovery is not working right now.
// if (
// selectedArtifact &&
// !thread.values.artifacts?.includes(selectedArtifact)
// ) {
// deselect();
// }
if (
env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true" &&

View File

@@ -52,6 +52,10 @@ export function groupMessages<T>(
}
for (const message of messages) {
if (message.name === "todo_reminder") {
continue;
}
if (message.type === "human") {
groups.push({ id: message.id, type: "human", messages: [message] });
continue;

View File

@@ -60,11 +60,12 @@ export function useThreadStream({
useEffect(() => {
const normalizedThreadId = threadId ?? null;
if (threadIdRef.current !== normalizedThreadId) {
threadIdRef.current = normalizedThreadId;
startedRef.current = false; // Reset for new thread
if (!normalizedThreadId) {
// Just reset for new thread creation when threadId becomes null/undefined
startedRef.current = false;
setOnStreamThreadId(normalizedThreadId);
}
threadIdRef.current = normalizedThreadId;
}, [threadId]);
const _handleOnStart = useCallback((id: string) => {
@@ -77,7 +78,6 @@ export function useThreadStream({
const handleStreamStart = useCallback(
(_threadId: string) => {
threadIdRef.current = _threadId;
setOnStreamThreadId(_threadId);
_handleOnStart(_threadId);
},
[_handleOnStart],
@@ -85,6 +85,7 @@ export function useThreadStream({
const queryClient = useQueryClient();
const updateSubtask = useUpdateSubtask();
const thread = useStream<AgentThreadState>({
client: getAPIClient(isMock),
assistantId: "lead_agent",
@@ -93,6 +94,7 @@ export function useThreadStream({
fetchStateHistory: { limit: 1 },
onCreated(meta) {
handleStreamStart(meta.thread_id);
setOnStreamThreadId(meta.thread_id);
},
onLangChainEvent(event) {
if (event.event === "on_tool_end") {