fix(middleware): fix DanglingToolCallMiddleware inserting patches at wrong position (#904)

Previously used before_model which returned {"messages": patches}, causing
LangGraph's add_messages reducer to append patches at the end of the message
list. This resulted in invalid ordering (ToolMessage after a HumanMessage)
that LLMs reject with tool call ID mismatch errors.

Switch to wrap_model_call/awrap_model_call to insert synthetic ToolMessages
immediately after each dangling AIMessage before the request reaches the LLM,
without persisting the patches to state.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
DanielWalnut
2026-02-25 22:29:33 +08:00
committed by GitHub
parent 33595f0bac
commit d27a7a5f54

View File

@@ -4,17 +4,23 @@
A dangling tool call occurs when an AIMessage contains tool_calls but there are
no corresponding ToolMessages in the history (e.g., due to user interruption or
request cancellation). This causes LLM errors due to incomplete message format.

This middleware intercepts the model call to detect and patch such gaps by
inserting synthetic ToolMessages with an error indicator immediately after the
AIMessage that made the tool calls, ensuring correct message ordering.

Note: Uses wrap_model_call instead of before_model to ensure patches are inserted
at the correct positions (immediately after each dangling AIMessage), not appended
to the end of the message list as before_model + add_messages reducer would do.
"""
import logging
from collections.abc import Awaitable, Callable
from typing import override

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
from langchain_core.messages import ToolMessage

logger = logging.getLogger(__name__)
@@ -23,33 +29,51 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
    """Inserts placeholder ToolMessages for dangling tool calls before model invocation.

    Scans the message history for AIMessages whose tool_calls lack corresponding
    ToolMessages, and injects synthetic error responses immediately after the
    offending AIMessage so the LLM receives a well-formed conversation.
    """
def _fix_dangling_tool_calls(self, state: AgentState) -> dict | None: def _build_patched_messages(self, messages: list) -> list | None:
messages = state.get("messages", []) """Return a new message list with patches inserted at the correct positions.
if not messages:
return None
For each AIMessage with dangling tool_calls (no corresponding ToolMessage),
a synthetic ToolMessage is inserted immediately after that AIMessage.
Returns None if no patches are needed.
"""
# Collect IDs of all existing ToolMessages # Collect IDs of all existing ToolMessages
existing_tool_msg_ids: set[str] = set() existing_tool_msg_ids: set[str] = set()
for msg in messages: for msg in messages:
if isinstance(msg, ToolMessage): if isinstance(msg, ToolMessage):
existing_tool_msg_ids.add(msg.tool_call_id) existing_tool_msg_ids.add(msg.tool_call_id)
# Find dangling tool calls and build patch messages # Check if any patching is needed
patches: list[ToolMessage] = [] needs_patch = False
for msg in messages: for msg in messages:
if getattr(msg, "type", None) != "ai": if getattr(msg, "type", None) != "ai":
continue continue
tool_calls = getattr(msg, "tool_calls", None) for tc in getattr(msg, "tool_calls", None) or []:
if not tool_calls:
continue
for tc in tool_calls:
tc_id = tc.get("id") tc_id = tc.get("id")
if tc_id and tc_id not in existing_tool_msg_ids: if tc_id and tc_id not in existing_tool_msg_ids:
patches.append( needs_patch = True
break
if needs_patch:
break
if not needs_patch:
return None
# Build new list with patches inserted right after each dangling AIMessage
patched: list = []
patched_ids: set[str] = set()
patch_count = 0
for msg in messages:
patched.append(msg)
if getattr(msg, "type", None) != "ai":
continue
for tc in getattr(msg, "tool_calls", None) or []:
tc_id = tc.get("id")
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
patched.append(
ToolMessage( ToolMessage(
content="[Tool call was interrupted and did not return a result.]", content="[Tool call was interrupted and did not return a result.]",
tool_call_id=tc_id, tool_call_id=tc_id,
@@ -57,18 +81,30 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
status="error", status="error",
) )
) )
existing_tool_msg_ids.add(tc_id) patched_ids.add(tc_id)
patch_count += 1
if not patches: logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
return None return patched
logger.warning(f"Injecting {len(patches)} placeholder ToolMessage(s) for dangling tool calls")
return {"messages": patches}
@override @override
def before_model(self, state: AgentState, runtime: Runtime) -> dict | None: def wrap_model_call(
return self._fix_dangling_tool_calls(state) self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
patched = self._build_patched_messages(request.messages)
if patched is not None:
request = request.override(messages=patched)
return handler(request)
@override @override
async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict | None: async def awrap_model_call(
return self._fix_dangling_tool_calls(state) self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
patched = self._build_patched_messages(request.messages)
if patched is not None:
request = request.override(messages=patched)
return await handler(request)