fix(middleware): fix DanglingToolCallMiddleware inserting patches at wrong position (#904)

Previously used before_model which returned {"messages": patches}, causing
LangGraph's add_messages reducer to append patches at the end of the message
list. This resulted in invalid ordering (ToolMessage after a HumanMessage)
that LLMs reject with tool call ID mismatch errors.

Switch to wrap_model_call/awrap_model_call to insert synthetic ToolMessages
immediately after each dangling AIMessage before the request reaches the LLM,
without persisting the patches to state.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
DanielWalnut
2026-02-25 22:29:33 +08:00
committed by GitHub
parent 33595f0bac
commit d27a7a5f54

View File

@@ -4,17 +4,23 @@ A dangling tool call occurs when an AIMessage contains tool_calls but there are
no corresponding ToolMessages in the history (e.g., due to user interruption or
request cancellation). This causes LLM errors due to incomplete message format.
This middleware runs before the model call to detect and patch such gaps by
inserting synthetic ToolMessages with an error indicator.
This middleware intercepts the model call to detect and patch such gaps by
inserting synthetic ToolMessages with an error indicator immediately after the
AIMessage that made the tool calls, ensuring correct message ordering.
Note: Uses wrap_model_call instead of before_model to ensure patches are inserted
at the correct positions (immediately after each dangling AIMessage), not appended
to the end of the message list as before_model + add_messages reducer would do.
"""
import logging
from collections.abc import Awaitable, Callable
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
from langchain_core.messages import ToolMessage
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
@@ -23,33 +29,51 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
"""Inserts placeholder ToolMessages for dangling tool calls before model invocation.
Scans the message history for AIMessages whose tool_calls lack corresponding
ToolMessages, and injects synthetic error responses so the LLM receives a
well-formed conversation.
ToolMessages, and injects synthetic error responses immediately after the
offending AIMessage so the LLM receives a well-formed conversation.
"""
def _fix_dangling_tool_calls(self, state: AgentState) -> dict | None:
messages = state.get("messages", [])
if not messages:
return None
def _build_patched_messages(self, messages: list) -> list | None:
"""Return a new message list with patches inserted at the correct positions.
For each AIMessage with dangling tool_calls (no corresponding ToolMessage),
a synthetic ToolMessage is inserted immediately after that AIMessage.
Returns None if no patches are needed.
"""
# Collect IDs of all existing ToolMessages
existing_tool_msg_ids: set[str] = set()
for msg in messages:
if isinstance(msg, ToolMessage):
existing_tool_msg_ids.add(msg.tool_call_id)
# Find dangling tool calls and build patch messages
patches: list[ToolMessage] = []
# Check if any patching is needed
needs_patch = False
for msg in messages:
if getattr(msg, "type", None) != "ai":
continue
tool_calls = getattr(msg, "tool_calls", None)
if not tool_calls:
continue
for tc in tool_calls:
for tc in getattr(msg, "tool_calls", None) or []:
tc_id = tc.get("id")
if tc_id and tc_id not in existing_tool_msg_ids:
patches.append(
needs_patch = True
break
if needs_patch:
break
if not needs_patch:
return None
# Build new list with patches inserted right after each dangling AIMessage
patched: list = []
patched_ids: set[str] = set()
patch_count = 0
for msg in messages:
patched.append(msg)
if getattr(msg, "type", None) != "ai":
continue
for tc in getattr(msg, "tool_calls", None) or []:
tc_id = tc.get("id")
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
patched.append(
ToolMessage(
content="[Tool call was interrupted and did not return a result.]",
tool_call_id=tc_id,
@@ -57,18 +81,30 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
status="error",
)
)
existing_tool_msg_ids.add(tc_id)
patched_ids.add(tc_id)
patch_count += 1
if not patches:
return None
logger.warning(f"Injecting {len(patches)} placeholder ToolMessage(s) for dangling tool calls")
return {"messages": patches}
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
return patched
@override
def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._fix_dangling_tool_calls(state)
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
patched = self._build_patched_messages(request.messages)
if patched is not None:
request = request.override(messages=patched)
return handler(request)
@override
async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._fix_dangling_tool_calls(state)
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
patched = self._build_patched_messages(request.messages)
if patched is not None:
request = request.override(messages=patched)
return await handler(request)