mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-26 07:14:47 +08:00
fix(middleware): fix DanglingToolCallMiddleware inserting patches at wrong position (#904)
Previously used before_model which returned {"messages": patches}, causing
LangGraph's add_messages reducer to append patches at the end of the message
list. This resulted in invalid ordering (ToolMessage after a HumanMessage)
that LLMs reject with tool call ID mismatch errors.
Switch to wrap_model_call/awrap_model_call to insert synthetic ToolMessages
immediately after each dangling AIMessage before the request reaches the LLM,
without persisting the patches to state.
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -4,17 +4,23 @@ A dangling tool call occurs when an AIMessage contains tool_calls but there are
|
|||||||
no corresponding ToolMessages in the history (e.g., due to user interruption or
|
no corresponding ToolMessages in the history (e.g., due to user interruption or
|
||||||
request cancellation). This causes LLM errors due to incomplete message format.
|
request cancellation). This causes LLM errors due to incomplete message format.
|
||||||
|
|
||||||
This middleware runs before the model call to detect and patch such gaps by
|
This middleware intercepts the model call to detect and patch such gaps by
|
||||||
inserting synthetic ToolMessages with an error indicator.
|
inserting synthetic ToolMessages with an error indicator immediately after the
|
||||||
|
AIMessage that made the tool calls, ensuring correct message ordering.
|
||||||
|
|
||||||
|
Note: Uses wrap_model_call instead of before_model to ensure patches are inserted
|
||||||
|
at the correct positions (immediately after each dangling AIMessage), not appended
|
||||||
|
to the end of the message list as before_model + add_messages reducer would do.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import override
|
from typing import override
|
||||||
|
|
||||||
from langchain.agents import AgentState
|
from langchain.agents import AgentState
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
|
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
|
||||||
from langchain_core.messages import ToolMessage
|
from langchain_core.messages import ToolMessage
|
||||||
from langgraph.runtime import Runtime
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -23,33 +29,51 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
|||||||
"""Inserts placeholder ToolMessages for dangling tool calls before model invocation.
|
"""Inserts placeholder ToolMessages for dangling tool calls before model invocation.
|
||||||
|
|
||||||
Scans the message history for AIMessages whose tool_calls lack corresponding
|
Scans the message history for AIMessages whose tool_calls lack corresponding
|
||||||
ToolMessages, and injects synthetic error responses so the LLM receives a
|
ToolMessages, and injects synthetic error responses immediately after the
|
||||||
well-formed conversation.
|
offending AIMessage so the LLM receives a well-formed conversation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _fix_dangling_tool_calls(self, state: AgentState) -> dict | None:
|
def _build_patched_messages(self, messages: list) -> list | None:
|
||||||
messages = state.get("messages", [])
|
"""Return a new message list with patches inserted at the correct positions.
|
||||||
if not messages:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
For each AIMessage with dangling tool_calls (no corresponding ToolMessage),
|
||||||
|
a synthetic ToolMessage is inserted immediately after that AIMessage.
|
||||||
|
Returns None if no patches are needed.
|
||||||
|
"""
|
||||||
# Collect IDs of all existing ToolMessages
|
# Collect IDs of all existing ToolMessages
|
||||||
existing_tool_msg_ids: set[str] = set()
|
existing_tool_msg_ids: set[str] = set()
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if isinstance(msg, ToolMessage):
|
if isinstance(msg, ToolMessage):
|
||||||
existing_tool_msg_ids.add(msg.tool_call_id)
|
existing_tool_msg_ids.add(msg.tool_call_id)
|
||||||
|
|
||||||
# Find dangling tool calls and build patch messages
|
# Check if any patching is needed
|
||||||
patches: list[ToolMessage] = []
|
needs_patch = False
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if getattr(msg, "type", None) != "ai":
|
if getattr(msg, "type", None) != "ai":
|
||||||
continue
|
continue
|
||||||
tool_calls = getattr(msg, "tool_calls", None)
|
for tc in getattr(msg, "tool_calls", None) or []:
|
||||||
if not tool_calls:
|
|
||||||
continue
|
|
||||||
for tc in tool_calls:
|
|
||||||
tc_id = tc.get("id")
|
tc_id = tc.get("id")
|
||||||
if tc_id and tc_id not in existing_tool_msg_ids:
|
if tc_id and tc_id not in existing_tool_msg_ids:
|
||||||
patches.append(
|
needs_patch = True
|
||||||
|
break
|
||||||
|
if needs_patch:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not needs_patch:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Build new list with patches inserted right after each dangling AIMessage
|
||||||
|
patched: list = []
|
||||||
|
patched_ids: set[str] = set()
|
||||||
|
patch_count = 0
|
||||||
|
for msg in messages:
|
||||||
|
patched.append(msg)
|
||||||
|
if getattr(msg, "type", None) != "ai":
|
||||||
|
continue
|
||||||
|
for tc in getattr(msg, "tool_calls", None) or []:
|
||||||
|
tc_id = tc.get("id")
|
||||||
|
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
|
||||||
|
patched.append(
|
||||||
ToolMessage(
|
ToolMessage(
|
||||||
content="[Tool call was interrupted and did not return a result.]",
|
content="[Tool call was interrupted and did not return a result.]",
|
||||||
tool_call_id=tc_id,
|
tool_call_id=tc_id,
|
||||||
@@ -57,18 +81,30 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
|||||||
status="error",
|
status="error",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
existing_tool_msg_ids.add(tc_id)
|
patched_ids.add(tc_id)
|
||||||
|
patch_count += 1
|
||||||
|
|
||||||
if not patches:
|
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
|
||||||
return None
|
return patched
|
||||||
|
|
||||||
logger.warning(f"Injecting {len(patches)} placeholder ToolMessage(s) for dangling tool calls")
|
|
||||||
return {"messages": patches}
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
def wrap_model_call(
|
||||||
return self._fix_dangling_tool_calls(state)
|
self,
|
||||||
|
request: ModelRequest,
|
||||||
|
handler: Callable[[ModelRequest], ModelResponse],
|
||||||
|
) -> ModelCallResult:
|
||||||
|
patched = self._build_patched_messages(request.messages)
|
||||||
|
if patched is not None:
|
||||||
|
request = request.override(messages=patched)
|
||||||
|
return handler(request)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
async def awrap_model_call(
|
||||||
return self._fix_dangling_tool_calls(state)
|
self,
|
||||||
|
request: ModelRequest,
|
||||||
|
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||||
|
) -> ModelCallResult:
|
||||||
|
patched = self._build_patched_messages(request.messages)
|
||||||
|
if patched is not None:
|
||||||
|
request = request.override(messages=patched)
|
||||||
|
return await handler(request)
|
||||||
|
|||||||
Reference in New Issue
Block a user