feat: add DanglingToolCallMiddleware and SubagentLimitMiddleware

Add two new middlewares to improve robustness of the agent pipeline:
- DanglingToolCallMiddleware injects placeholder ToolMessages for
  interrupted tool calls, preventing LLM errors from malformed history
- SubagentLimitMiddleware truncates excess parallel task tool calls at
  the model response level, replacing the runtime check in task_tool

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
hetao
2026-02-09 13:21:58 +08:00
parent 2c3ddbb9e5
commit caf12da0f2
6 changed files with 155 additions and 32 deletions

View File

@@ -0,0 +1,74 @@
"""Middleware to fix dangling tool calls in message history.
A dangling tool call occurs when an AIMessage contains tool_calls but there are
no corresponding ToolMessages in the history (e.g., due to user interruption or
request cancellation). This causes LLM errors due to incomplete message format.
This middleware runs before the model call to detect and patch such gaps by
inserting synthetic ToolMessages with an error indicator.
"""
import logging
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import ToolMessage
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
    """Inserts placeholder ToolMessages for dangling tool calls before model invocation.

    Scans the message history for AIMessages whose tool_calls lack corresponding
    ToolMessages, and injects synthetic error responses so the LLM receives a
    well-formed conversation.
    """

    def _fix_dangling_tool_calls(self, state: AgentState) -> dict | None:
        """Return a state update patching dangling tool calls, or ``None`` if clean.

        A tool call is "dangling" when its id has no matching ToolMessage anywhere
        in the history (e.g., the run was interrupted before the tool finished).

        Args:
            state: Current agent state; only ``state["messages"]`` is read.

        Returns:
            ``{"messages": [...]}`` with one synthetic error ToolMessage per
            dangling tool call, or ``None`` when nothing needs patching.
        """
        messages = state.get("messages", [])
        if not messages:
            return None

        # IDs of tool calls that already have a ToolMessage response.
        answered_ids: set[str] = {
            msg.tool_call_id for msg in messages if isinstance(msg, ToolMessage)
        }

        patches: list[ToolMessage] = []
        for msg in messages:
            if getattr(msg, "type", None) != "ai":
                continue
            for tc in getattr(msg, "tool_calls", None) or ():
                tc_id = tc.get("id")
                if tc_id and tc_id not in answered_ids:
                    patches.append(
                        ToolMessage(
                            content="[Tool call was interrupted and did not return a result.]",
                            tool_call_id=tc_id,
                            name=tc.get("name", "unknown"),
                            status="error",
                        )
                    )
                    # Guard against duplicate tool-call ids yielding duplicate patches.
                    answered_ids.add(tc_id)

        if not patches:
            return None
        # Lazy %-args: the message is only formatted if the record is emitted.
        logger.warning(
            "Injecting %d placeholder ToolMessage(s) for dangling tool calls",
            len(patches),
        )
        # NOTE(review): the messages reducer appends these at the end of history,
        # not immediately after their originating AIMessage — confirm the target
        # LLM providers accept that ordering.
        return {"messages": patches}

    @override
    def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Sync hook: patch dangling tool calls before the model call."""
        return self._fix_dangling_tool_calls(state)

    @override
    async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Async hook: same behavior as :meth:`before_model`."""
        return self._fix_dangling_tool_calls(state)

View File

@@ -0,0 +1,61 @@
"""Middleware to enforce maximum concurrent subagent tool calls per model response."""
import logging
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from src.subagents.executor import MAX_CONCURRENT_SUBAGENTS
logger = logging.getLogger(__name__)
class SubagentLimitMiddleware(AgentMiddleware[AgentState]):
    """Truncates excess 'task' tool calls from a single model response.

    When an LLM generates more than MAX_CONCURRENT_SUBAGENTS parallel task tool calls
    in one response, this middleware keeps only the first MAX_CONCURRENT_SUBAGENTS and
    discards the rest. This is more reliable than prompt-based limits.
    """

    def _truncate_task_calls(self, state: AgentState) -> dict | None:
        """Return a state update with excess task tool calls removed, or ``None``.

        Only the last message (the fresh model response) is inspected. Non-task
        tool calls are always kept, and kept calls preserve their original order.

        Args:
            state: Current agent state; only ``state["messages"]`` is read.

        Returns:
            ``{"messages": [updated_ai_message]}`` when truncation happened,
            otherwise ``None``.
        """
        messages = state.get("messages", [])
        if not messages:
            return None

        last_msg = messages[-1]
        if getattr(last_msg, "type", None) != "ai":
            return None
        tool_calls = getattr(last_msg, "tool_calls", None)
        if not tool_calls:
            return None

        # Positions of 'task' tool calls within this single response.
        task_indices = [i for i, tc in enumerate(tool_calls) if tc.get("name") == "task"]
        if len(task_indices) <= MAX_CONCURRENT_SUBAGENTS:
            return None

        # Drop only the task calls beyond the limit; everything else survives.
        indices_to_drop = set(task_indices[MAX_CONCURRENT_SUBAGENTS:])
        kept_tool_calls = [tc for i, tc in enumerate(tool_calls) if i not in indices_to_drop]

        # Lazy %-args: the message is only formatted if the record is emitted.
        logger.warning(
            "Truncated %d excess task tool call(s) from model response (limit: %d)",
            len(indices_to_drop),
            MAX_CONCURRENT_SUBAGENTS,
        )
        # Replace the AIMessage with truncated tool_calls (same id triggers
        # replacement in the messages reducer). NOTE(review): if last_msg.id were
        # ever None this would append a duplicate instead — confirm ids are set.
        updated_msg = last_msg.model_copy(update={"tool_calls": kept_tool_calls})
        return {"messages": [updated_msg]}

    @override
    def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Sync hook: truncate excess task calls after the model responds."""
        return self._truncate_task_calls(state)

    @override
    async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Async hook: same behavior as :meth:`after_model`."""
        return self._truncate_task_calls(state)