Files
deer-flow/backend/src/agents/middlewares/subagent_limit_middleware.py
hetao 48e3039055 feat: add DanglingToolCallMiddleware and SubagentLimitMiddleware
Add two new middlewares to improve robustness of the agent pipeline:
- DanglingToolCallMiddleware injects placeholder ToolMessages for
  interrupted tool calls, preventing LLM errors from malformed history
- SubagentLimitMiddleware truncates excess parallel task tool calls at
  the model response level, replacing the runtime check in task_tool

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 13:22:49 +08:00

62 lines
2.3 KiB
Python

"""Middleware to enforce maximum concurrent subagent tool calls per model response."""
import logging
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from src.subagents.executor import MAX_CONCURRENT_SUBAGENTS
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class SubagentLimitMiddleware(AgentMiddleware[AgentState]):
    """Cap the number of 'task' tool calls carried by one model response.

    If the most recent AI message contains more than MAX_CONCURRENT_SUBAGENTS
    tool calls named "task", only the first MAX_CONCURRENT_SUBAGENTS of them
    are kept; tool calls with any other name are never removed. Enforcing the
    cap here is more reliable than asking the model to respect it via the
    prompt.
    """

    def _truncate_task_calls(self, state: AgentState) -> dict | None:
        """Return a state update that drops surplus 'task' tool calls.

        Args:
            state: Agent state; only its "messages" entry is read.

        Returns:
            ``{"messages": [replacement]}`` where ``replacement`` is a copy of
            the last message with a shortened ``tool_calls`` list, or ``None``
            when there is nothing to change (no messages, last message is not
            of type "ai", no tool calls, or the task-call count is within the
            limit).
        """
        history = state.get("messages", [])
        if not history:
            return None

        latest = history[-1]
        if getattr(latest, "type", None) != "ai":
            return None

        calls = getattr(latest, "tool_calls", None)
        if not calls:
            return None

        # Record where each "task" call sits so non-task calls keep their place.
        task_positions = []
        for position, call in enumerate(calls):
            if call.get("name") == "task":
                task_positions.append(position)

        if len(task_positions) <= MAX_CONCURRENT_SUBAGENTS:
            return None

        # Everything past the first MAX_CONCURRENT_SUBAGENTS task calls is surplus.
        surplus = frozenset(task_positions[MAX_CONCURRENT_SUBAGENTS:])
        kept_calls = [
            call for position, call in enumerate(calls) if position not in surplus
        ]

        dropped_count = len(surplus)
        logger.warning(
            f"Truncated {dropped_count} excess task tool call(s) from model response "
            f"(limit: {MAX_CONCURRENT_SUBAGENTS})"
        )

        # The copy only overrides tool_calls, so it carries the same id as the
        # original AIMessage; the same id triggers replacement of that message.
        replacement = latest.model_copy(update={"tool_calls": kept_calls})
        return {"messages": [replacement]}

    @override
    def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Synchronous hook: apply the task-call cap to the latest response."""
        return self._truncate_task_calls(state)

    @override
    async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Asynchronous hook: apply the task-call cap to the latest response."""
        return self._truncate_task_calls(state)