From d27a7a5f54f01f4c96db517d57473bdd56261b13 Mon Sep 17 00:00:00 2001 From: DanielWalnut <45447813+hetaoBackend@users.noreply.github.com> Date: Wed, 25 Feb 2026 22:29:33 +0800 Subject: [PATCH] fix(middleware): fix DanglingToolCallMiddleware inserting patches at wrong position (#904) Previously used before_model which returned {"messages": patches}, causing LangGraph's add_messages reducer to append patches at the end of the message list. This resulted in invalid ordering (ToolMessage after a HumanMessage) that LLMs reject with tool call ID mismatch errors. Switch to wrap_model_call/awrap_model_call to insert synthetic ToolMessages immediately after each dangling AIMessage before the request reaches the LLM, without persisting the patches to state. Co-authored-by: Claude Sonnet 4.6 --- .../dangling_tool_call_middleware.py | 88 +++++++++++++------ 1 file changed, 62 insertions(+), 26 deletions(-) diff --git a/backend/src/agents/middlewares/dangling_tool_call_middleware.py b/backend/src/agents/middlewares/dangling_tool_call_middleware.py index 7d3104d..5516ffb 100644 --- a/backend/src/agents/middlewares/dangling_tool_call_middleware.py +++ b/backend/src/agents/middlewares/dangling_tool_call_middleware.py @@ -4,17 +4,23 @@ A dangling tool call occurs when an AIMessage contains tool_calls but there are no corresponding ToolMessages in the history (e.g., due to user interruption or request cancellation). This causes LLM errors due to incomplete message format. -This middleware runs before the model call to detect and patch such gaps by -inserting synthetic ToolMessages with an error indicator. +This middleware intercepts the model call to detect and patch such gaps by +inserting synthetic ToolMessages with an error indicator immediately after the +AIMessage that made the tool calls, ensuring correct message ordering. + +Note: Uses wrap_model_call instead of before_model to ensure patches are inserted +at the correct positions (immediately after each dangling AIMessage), not appended +to the end of the message list as before_model + add_messages reducer would do. """ import logging +from collections.abc import Awaitable, Callable from typing import override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware +from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse from langchain_core.messages import ToolMessage -from langgraph.runtime import Runtime logger = logging.getLogger(__name__) @@ -23,33 +29,51 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]): """Inserts placeholder ToolMessages for dangling tool calls before model invocation. Scans the message history for AIMessages whose tool_calls lack corresponding - ToolMessages, and injects synthetic error responses so the LLM receives a - well-formed conversation. + ToolMessages, and injects synthetic error responses immediately after the + offending AIMessage so the LLM receives a well-formed conversation. """ - def _fix_dangling_tool_calls(self, state: AgentState) -> dict | None: - messages = state.get("messages", []) - if not messages: - return None + def _build_patched_messages(self, messages: list) -> list | None: + """Return a new message list with patches inserted at the correct positions. + For each AIMessage with dangling tool_calls (no corresponding ToolMessage), + a synthetic ToolMessage is inserted immediately after that AIMessage. + Returns None if no patches are needed. + """ # Collect IDs of all existing ToolMessages existing_tool_msg_ids: set[str] = set() for msg in messages: if isinstance(msg, ToolMessage): existing_tool_msg_ids.add(msg.tool_call_id) - # Find dangling tool calls and build patch messages - patches: list[ToolMessage] = [] + # Check if any patching is needed + needs_patch = False for msg in messages: if getattr(msg, "type", None) != "ai": continue - tool_calls = getattr(msg, "tool_calls", None) - if not tool_calls: - continue - for tc in tool_calls: + for tc in getattr(msg, "tool_calls", None) or []: tc_id = tc.get("id") if tc_id and tc_id not in existing_tool_msg_ids: - patches.append( + needs_patch = True + break + if needs_patch: + break + + if not needs_patch: + return None + + # Build new list with patches inserted right after each dangling AIMessage + patched: list = [] + patched_ids: set[str] = set() + patch_count = 0 + for msg in messages: + patched.append(msg) + if getattr(msg, "type", None) != "ai": + continue + for tc in getattr(msg, "tool_calls", None) or []: + tc_id = tc.get("id") + if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids: + patched.append( ToolMessage( content="[Tool call was interrupted and did not return a result.]", tool_call_id=tc_id, @@ -57,18 +81,30 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]): status="error", ) ) - existing_tool_msg_ids.add(tc_id) + patched_ids.add(tc_id) + patch_count += 1 - if not patches: - return None - - logger.warning(f"Injecting {len(patches)} placeholder ToolMessage(s) for dangling tool calls") - return {"messages": patches} + logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls") + return patched @override - def before_model(self, state: AgentState, runtime: Runtime) -> dict | None: - return self._fix_dangling_tool_calls(state) + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelCallResult: + patched = self._build_patched_messages(request.messages) + if patched is not None: + request = request.override(messages=patched) + return handler(request) @override - async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict | None: - return self._fix_dangling_tool_calls(state) + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ModelCallResult: + patched = self._build_patched_messages(request.messages) + if patched is not None: + request = request.override(messages=patched) + return await handler(request)