mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-02 22:02:13 +08:00
feat: add DanglingToolCallMiddleware and SubagentLimitMiddleware
Add two new middlewares to improve robustness of the agent pipeline: - DanglingToolCallMiddleware injects placeholder ToolMessages for interrupted tool calls, preventing LLM errors from malformed history - SubagentLimitMiddleware truncates excess parallel task tool calls at the model response level, replacing the runtime check in task_tool Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -24,7 +24,7 @@ deer-flow/
|
||||
│ ├── src/
|
||||
│ │ ├── agents/ # LangGraph agent system
|
||||
│ │ │ ├── lead_agent/ # Main agent (factory + system prompt)
|
||||
│ │ │ ├── middlewares/ # 9 middleware components
|
||||
│ │ │ ├── middlewares/ # 10 middleware components
|
||||
│ │ │ ├── memory/ # Memory extraction, queue, prompts
|
||||
│ │ │ └── thread_state.py # ThreadState schema
|
||||
│ │ ├── gateway/ # FastAPI Gateway API
|
||||
@@ -112,12 +112,14 @@ Middlewares execute in strict order in `src/agents/lead_agent/agent.py`:
|
||||
1. **ThreadDataMiddleware** - Creates per-thread directories (`backend/.deer-flow/threads/{thread_id}/user-data/{workspace,uploads,outputs}`)
|
||||
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
|
||||
3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
|
||||
4. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
|
||||
5. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
|
||||
6. **TitleMiddleware** - Auto-generates thread title after first complete exchange
|
||||
7. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
|
||||
8. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
|
||||
9. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
|
||||
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption)
|
||||
5. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
|
||||
6. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
|
||||
7. **TitleMiddleware** - Auto-generates thread title after first complete exchange
|
||||
8. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
|
||||
9. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
|
||||
10. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if subagent_enabled)
|
||||
11. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
|
||||
|
||||
### Configuration System
|
||||
|
||||
@@ -185,7 +187,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
|
||||
|
||||
**Built-in Agents**: `general-purpose` (all tools except `task`) and `bash` (command specialist)
|
||||
**Execution**: Dual thread pool - `_scheduler_pool` (3 workers) + `_execution_pool` (3 workers)
|
||||
**Concurrency**: `MAX_CONCURRENT_SUBAGENTS = 3` per trace, 15-minute timeout
|
||||
**Concurrency**: `MAX_CONCURRENT_SUBAGENTS = 3` enforced by `SubagentLimitMiddleware` (truncates excess tool calls in `after_model`), 15-minute timeout
|
||||
**Flow**: `task()` tool → `SubagentExecutor` → background thread → poll 5s → SSE events → result
|
||||
**Events**: `task_started`, `task_running`, `task_completed`/`task_failed`/`task_timed_out`
|
||||
|
||||
|
||||
@@ -4,7 +4,9 @@ from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from src.agents.lead_agent.prompt import apply_prompt_template
|
||||
from src.agents.middlewares.clarification_middleware import ClarificationMiddleware
|
||||
from src.agents.middlewares.dangling_tool_call_middleware import DanglingToolCallMiddleware
|
||||
from src.agents.middlewares.memory_middleware import MemoryMiddleware
|
||||
from src.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
|
||||
from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||
from src.agents.middlewares.title_middleware import TitleMiddleware
|
||||
from src.agents.middlewares.uploads_middleware import UploadsMiddleware
|
||||
@@ -174,6 +176,7 @@ Being proactive with task management demonstrates thoroughness and ensures all r
|
||||
|
||||
# ThreadDataMiddleware must be before SandboxMiddleware to ensure thread_id is available
|
||||
# UploadsMiddleware should be after ThreadDataMiddleware to access thread_id
|
||||
# DanglingToolCallMiddleware patches missing ToolMessages before model sees the history
|
||||
# SummarizationMiddleware should be early to reduce context before other processing
|
||||
# TodoListMiddleware should be before ClarificationMiddleware to allow todo management
|
||||
# TitleMiddleware generates title after first exchange
|
||||
@@ -189,7 +192,7 @@ def _build_middlewares(config: RunnableConfig):
|
||||
Returns:
|
||||
List of middleware instances.
|
||||
"""
|
||||
middlewares = [ThreadDataMiddleware(), UploadsMiddleware(), SandboxMiddleware()]
|
||||
middlewares = [ThreadDataMiddleware(), UploadsMiddleware(), SandboxMiddleware(), DanglingToolCallMiddleware()]
|
||||
|
||||
# Add summarization middleware if enabled
|
||||
summarization_middleware = _create_summarization_middleware()
|
||||
@@ -221,6 +224,11 @@ def _build_middlewares(config: RunnableConfig):
|
||||
if model_config is not None and model_config.supports_vision:
|
||||
middlewares.append(ViewImageMiddleware())
|
||||
|
||||
# Add SubagentLimitMiddleware to truncate excess parallel task calls
|
||||
subagent_enabled = config.get("configurable", {}).get("subagent_enabled", False)
|
||||
if subagent_enabled:
|
||||
middlewares.append(SubagentLimitMiddleware())
|
||||
|
||||
# ClarificationMiddleware should always be last
|
||||
middlewares.append(ClarificationMiddleware())
|
||||
return middlewares
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
"""Middleware to fix dangling tool calls in message history.
|
||||
|
||||
A dangling tool call occurs when an AIMessage contains tool_calls but there are
|
||||
no corresponding ToolMessages in the history (e.g., due to user interruption or
|
||||
request cancellation). This causes LLM errors due to incomplete message format.
|
||||
|
||||
This middleware runs before the model call to detect and patch such gaps by
|
||||
inserting synthetic ToolMessages with an error indicator.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Inserts placeholder ToolMessages for dangling tool calls before model invocation.
|
||||
|
||||
Scans the message history for AIMessages whose tool_calls lack corresponding
|
||||
ToolMessages, and injects synthetic error responses so the LLM receives a
|
||||
well-formed conversation.
|
||||
"""
|
||||
|
||||
def _fix_dangling_tool_calls(self, state: AgentState) -> dict | None:
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Collect IDs of all existing ToolMessages
|
||||
existing_tool_msg_ids: set[str] = set()
|
||||
for msg in messages:
|
||||
if isinstance(msg, ToolMessage):
|
||||
existing_tool_msg_ids.add(msg.tool_call_id)
|
||||
|
||||
# Find dangling tool calls and build patch messages
|
||||
patches: list[ToolMessage] = []
|
||||
for msg in messages:
|
||||
if getattr(msg, "type", None) != "ai":
|
||||
continue
|
||||
tool_calls = getattr(msg, "tool_calls", None)
|
||||
if not tool_calls:
|
||||
continue
|
||||
for tc in tool_calls:
|
||||
tc_id = tc.get("id")
|
||||
if tc_id and tc_id not in existing_tool_msg_ids:
|
||||
patches.append(
|
||||
ToolMessage(
|
||||
content="[Tool call was interrupted and did not return a result.]",
|
||||
tool_call_id=tc_id,
|
||||
name=tc.get("name", "unknown"),
|
||||
status="error",
|
||||
)
|
||||
)
|
||||
existing_tool_msg_ids.add(tc_id)
|
||||
|
||||
if not patches:
|
||||
return None
|
||||
|
||||
logger.warning(f"Injecting {len(patches)} placeholder ToolMessage(s) for dangling tool calls")
|
||||
return {"messages": patches}
|
||||
|
||||
@override
|
||||
def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._fix_dangling_tool_calls(state)
|
||||
|
||||
@override
|
||||
async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._fix_dangling_tool_calls(state)
|
||||
61
backend/src/agents/middlewares/subagent_limit_middleware.py
Normal file
61
backend/src/agents/middlewares/subagent_limit_middleware.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Middleware to enforce maximum concurrent subagent tool calls per model response."""
|
||||
|
||||
import logging
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from src.subagents.executor import MAX_CONCURRENT_SUBAGENTS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SubagentLimitMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Truncates excess 'task' tool calls from a single model response.
|
||||
|
||||
When an LLM generates more than MAX_CONCURRENT_SUBAGENTS parallel task tool calls
|
||||
in one response, this middleware keeps only the first MAX_CONCURRENT_SUBAGENTS and
|
||||
discards the rest. This is more reliable than prompt-based limits.
|
||||
"""
|
||||
|
||||
def _truncate_task_calls(self, state: AgentState) -> dict | None:
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last_msg = messages[-1]
|
||||
if getattr(last_msg, "type", None) != "ai":
|
||||
return None
|
||||
|
||||
tool_calls = getattr(last_msg, "tool_calls", None)
|
||||
if not tool_calls:
|
||||
return None
|
||||
|
||||
# Count task tool calls
|
||||
task_indices = [i for i, tc in enumerate(tool_calls) if tc.get("name") == "task"]
|
||||
if len(task_indices) <= MAX_CONCURRENT_SUBAGENTS:
|
||||
return None
|
||||
|
||||
# Build set of indices to drop (excess task calls beyond the limit)
|
||||
indices_to_drop = set(task_indices[MAX_CONCURRENT_SUBAGENTS:])
|
||||
truncated_tool_calls = [tc for i, tc in enumerate(tool_calls) if i not in indices_to_drop]
|
||||
|
||||
dropped_count = len(indices_to_drop)
|
||||
logger.warning(
|
||||
f"Truncated {dropped_count} excess task tool call(s) from model response "
|
||||
f"(limit: {MAX_CONCURRENT_SUBAGENTS})"
|
||||
)
|
||||
|
||||
# Replace the AIMessage with truncated tool_calls (same id triggers replacement)
|
||||
updated_msg = last_msg.model_copy(update={"tool_calls": truncated_tool_calls})
|
||||
return {"messages": [updated_msg]}
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._truncate_task_calls(state)
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._truncate_task_calls(state)
|
||||
@@ -392,23 +392,6 @@ class SubagentExecutor:
|
||||
MAX_CONCURRENT_SUBAGENTS = 3
|
||||
|
||||
|
||||
def count_active_tasks_by_trace(trace_id: str) -> int:
|
||||
"""Count active (PENDING or RUNNING) background tasks for a given trace_id.
|
||||
|
||||
Args:
|
||||
trace_id: The trace ID linking tasks to a parent invocation.
|
||||
|
||||
Returns:
|
||||
Number of active tasks with the given trace_id.
|
||||
"""
|
||||
with _background_tasks_lock:
|
||||
return sum(
|
||||
1
|
||||
for task in _background_tasks.values()
|
||||
if task.trace_id == trace_id and task.status in (SubagentStatus.PENDING, SubagentStatus.RUNNING)
|
||||
)
|
||||
|
||||
|
||||
def get_background_task_result(task_id: str) -> SubagentResult | None:
|
||||
"""Get the result of a background task.
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from langgraph.typing import ContextT
|
||||
|
||||
from src.agents.thread_state import ThreadState
|
||||
from src.subagents import SubagentExecutor, get_subagent_config
|
||||
from src.subagents.executor import MAX_CONCURRENT_SUBAGENTS, SubagentStatus, count_active_tasks_by_trace, get_background_task_result
|
||||
from src.subagents.executor import SubagentStatus, get_background_task_result
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -86,11 +86,6 @@ def task_tool(
|
||||
# Get or generate trace_id for distributed tracing
|
||||
trace_id = metadata.get("trace_id") or str(uuid.uuid4())[:8]
|
||||
|
||||
# Check sub-agent limit before creating a new one
|
||||
if trace_id and count_active_tasks_by_trace(trace_id) >= MAX_CONCURRENT_SUBAGENTS:
|
||||
logger.warning(f"[trace={trace_id}] Sub-agent limit reached ({MAX_CONCURRENT_SUBAGENTS}). Rejecting new task: {description}")
|
||||
return f"Error: Maximum number of concurrent sub-agents ({MAX_CONCURRENT_SUBAGENTS}) reached. Please wait for existing tasks to complete before launching new ones."
|
||||
|
||||
# Get available tools (excluding task tool to prevent nesting)
|
||||
# Lazy import to avoid circular dependency
|
||||
from src.tools import get_available_tools
|
||||
|
||||
Reference in New Issue
Block a user