Files
deer-flow/backend/src/agents/middlewares/dangling_tool_call_middleware.py
hetao caf12da0f2 feat: add DanglingToolCallMiddleware and SubagentLimitMiddleware
Add two new middlewares to improve robustness of the agent pipeline:
- DanglingToolCallMiddleware injects placeholder ToolMessages for
  interrupted tool calls, preventing LLM errors from malformed history
- SubagentLimitMiddleware truncates excess parallel task tool calls at
  the model response level, replacing the runtime check in task_tool

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 13:22:49 +08:00

75 lines
2.8 KiB
Python

"""Middleware to fix dangling tool calls in message history.
A dangling tool call occurs when an AIMessage contains tool_calls but there are
no corresponding ToolMessages in the history (e.g., due to user interruption or
request cancellation). This causes LLM errors due to incomplete message format.
This middleware runs before the model call to detect and patch such gaps by
inserting synthetic ToolMessages with an error indicator.
"""
import logging
from typing import override

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import RemoveMessage, ToolMessage
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
# Module-level logger named after this module; the middleware below uses it
# to warn whenever placeholder ToolMessages are injected.
logger = logging.getLogger(__name__)
class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
    """Inserts placeholder ToolMessages for dangling tool calls before model invocation.

    Scans the message history for AIMessages whose tool_calls lack corresponding
    ToolMessages, and injects synthetic error responses directly after the
    AIMessage that issued them, so the LLM receives a well-formed conversation.
    """

    def _fix_dangling_tool_calls(self, state: AgentState) -> dict | None:
        """Build a state update that answers every unanswered tool call.

        Args:
            state: Agent state; only its ``messages`` entry is read.

        Returns:
            ``None`` when the history is empty or no tool call is dangling.
            Otherwise a ``{"messages": [...]}`` update made of a remove-all
            marker followed by the complete history, with each placeholder
            ToolMessage placed right after the AIMessage whose call it answers.
        """
        messages = state.get("messages", [])
        if not messages:
            return None

        # IDs of tool calls that already have a ToolMessage anywhere in history.
        answered_ids: set[str] = {
            msg.tool_call_id for msg in messages if isinstance(msg, ToolMessage)
        }

        # Rebuild the history, slotting a placeholder in right behind the
        # AIMessage that issued each dangling call.
        patched_history: list = []
        patch_count = 0
        for msg in messages:
            patched_history.append(msg)
            if getattr(msg, "type", None) != "ai":
                continue
            for tc in getattr(msg, "tool_calls", None) or []:
                tc_id = tc.get("id")
                if not tc_id or tc_id in answered_ids:
                    continue
                patched_history.append(
                    ToolMessage(
                        content="[Tool call was interrupted and did not return a result.]",
                        tool_call_id=tc_id,
                        name=tc.get("name", "unknown"),
                        status="error",
                    )
                )
                # Guard against patching the same call ID twice.
                answered_ids.add(tc_id)
                patch_count += 1

        if not patch_count:
            return None

        logger.warning(
            "Injecting %d placeholder ToolMessage(s) for dangling tool calls",
            patch_count,
        )
        # Returning only the placeholders would leave their position to the
        # messages reducer; assuming LangGraph's add_messages (as used by
        # langchain's AgentState), they would be appended at the END of the
        # history -- e.g. after a later HumanMessage -- which providers that
        # require tool results to follow their AIMessage still reject.
        # The remove-all marker makes the reducer replace the list with the
        # correctly ordered history instead.
        return {
            "messages": [RemoveMessage(id=REMOVE_ALL_MESSAGES), *patched_history],
        }

    @override
    def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Patch dangling tool calls before a synchronous model call.

        ``runtime`` is accepted for the hook signature but not used.
        """
        return self._fix_dangling_tool_calls(state)

    @override
    async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Async counterpart of ``before_model``; performs the same patching."""
        return self._fix_dangling_tool_calls(state)