Files
deer-flow/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py

228 lines
8.2 KiB
Python
Raw Normal View History

feat: add LoopDetectionMiddleware to break repetitive tool call loops (#1056) * feat: add LoopDetectionMiddleware to break repetitive tool call loops Adds a new AgentMiddleware that detects when the agent is stuck calling the same tools with the same arguments repeatedly, which currently runs until the recursion limit kills the run. Detection: per-thread sliding window of tool call hashes (name + args). - Warn threshold (default 3): injects a "wrap up" system message - Hard limit (default 5): strips tool_calls, forcing final text output Includes 13 unit tests covering hashing, thresholds, window sliding, reset, and edge cases. Closes #1055 * fix: address PR #1056 review feedback for LoopDetectionMiddleware - Remove unused imports (Awaitable, Callable, ModelCallResult, ModelRequest, ModelResponse, AIMessage) from loop_detection_middleware - Remove unused pytest import from test file - Fix _hash_tool_calls sort key: sort by (name, serialized args) for deterministic hashing when multiple calls share the same tool name - Revert subagent_enabled default to False in agent.py to match DeerFlowClient and channel defaults - Remove unrelated SearxNG tools and Next.js rewrite changes from PR Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: address 2nd round review feedback on PR #1056 - Inject loop warning only once per thread (prevents context bloat) - Add threading.Lock for thread-safe history mutations - Use runtime.context thread_id instead of workspace_path - Add LRU eviction for per-thread history (max 100 threads) - Add 5 new tests covering warn-once, LRU eviction, thread isolation, fallback thread_id, and lock presence Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: resolve lint errors in loop detection middleware tests Sort imports (I001) and remove unused _WARNING_MSG import (F401) to fix ruff lint failures in CI. 
* Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Willem Jiang <willem.jiang@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-14 16:17:54 +02:00
"""Middleware to detect and break repetitive tool call loops.

P0 safety: prevents the agent from calling the same tool with the same
arguments indefinitely until the recursion limit kills the run.

Detection strategy:
  1. After each model response, hash the tool calls (name + args).
  2. Track recent hashes in a sliding window.
  3. If the same hash appears >= warn_threshold times, inject a
     "you are repeating yourself — wrap up" system message (once per hash).
  4. If it appears >= hard_limit times, strip all tool_calls from the
     response so the agent is forced to produce a final text answer.
"""
import hashlib
import json
import logging
import threading
from collections import OrderedDict, defaultdict
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import SystemMessage
from langgraph.runtime import Runtime
# Module-level logger; records are emitted under this module's dotted name.
logger = logging.getLogger(__name__)
# Defaults — can be overridden via the LoopDetectionMiddleware constructor.
# All counts below refer to occurrences of the same tool-call hash inside the
# sliding window of a single thread.
_DEFAULT_WARN_THRESHOLD = 3 # inject a one-time warning once a hash is seen 3 times
_DEFAULT_HARD_LIMIT = 5 # strip tool calls (force-stop) once a hash is seen 5 times
_DEFAULT_WINDOW_SIZE = 20 # keep the last N hashes; one hash is recorded per model response
_DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit for per-thread histories
def _hash_tool_calls(tool_calls: list[dict]) -> str:
    """Return a short deterministic fingerprint of a set of tool calls.

    Only the ``name`` and ``args`` of each call contribute to the result;
    other keys (for example call ids) are ignored. The fingerprint is
    order-independent: the same multiset of calls always yields the same
    value regardless of the order in which the calls are listed.

    Args:
        tool_calls: Tool call dicts. Missing ``name`` defaults to ``""`` and
            missing ``args`` defaults to ``{}``.

    Returns:
        The first 12 hex characters of the MD5 digest of the canonical JSON
        serialization of the normalized calls.
    """
    # Reduce every call to the two fields that define "the same call".
    normalized = [
        {"name": tc.get("name", ""), "args": tc.get("args", {})}
        for tc in tool_calls
    ]
    # Sort by name and by a canonical serialization of args so permutations
    # of the same multiset of calls end up in the same order.
    normalized.sort(
        key=lambda tc: (
            tc["name"],
            json.dumps(tc["args"], sort_keys=True, default=str),
        )
    )
    blob = json.dumps(normalized, sort_keys=True, default=str)
    # MD5 is used purely as a fingerprint, never for security. Declaring that
    # keeps the call working on FIPS-restricted builds where plain md5() is
    # rejected; the digest value itself is unchanged.
    return hashlib.md5(blob.encode(), usedforsecurity=False).hexdigest()[:12]
# Text of the system message injected the first time a tool-call hash reaches
# the warn threshold for a thread.
_WARNING_MSG = (
"[LOOP DETECTED] You are repeating the same tool calls. "
"Stop calling tools and produce your final answer now. "
"If you cannot complete the task, summarize what you accomplished so far."
)
# Text appended to the last AI message's content when the hard limit is hit
# and its tool calls are stripped.
_HARD_STOP_MSG = (
"[FORCED STOP] Repeated tool calls exceeded the safety limit. "
"Producing final answer with results collected so far."
)
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
    """Detects and breaks repetitive tool call loops.

    After each model response, the tool calls on the last AI message are
    hashed and appended to a per-thread sliding window. When the newest hash
    occurs ``warn_threshold`` times in that window, a warning system message
    is injected (once per hash and thread). When it occurs ``hard_limit``
    times, the tool calls are stripped from the response so the agent has to
    finish with a text answer.

    Args:
        warn_threshold: Number of identical tool call sets before injecting
            a warning message. Default: 3.
        hard_limit: Number of identical tool call sets before stripping
            tool_calls entirely. Default: 5.
        window_size: Size of the sliding window for tracking calls.
            Default: 20.
        max_tracked_threads: Maximum number of threads to track before
            evicting the least recently used. Default: 100.
    """

    def __init__(
        self,
        warn_threshold: int = _DEFAULT_WARN_THRESHOLD,
        hard_limit: int = _DEFAULT_HARD_LIMIT,
        window_size: int = _DEFAULT_WINDOW_SIZE,
        max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
    ):
        super().__init__()
        self.warn_threshold = warn_threshold
        self.hard_limit = hard_limit
        self.window_size = window_size
        self.max_tracked_threads = max_tracked_threads
        # Guards every read/write of _history and _warned.
        self._lock = threading.Lock()
        # thread_id -> recent call hashes; insertion order doubles as LRU order.
        self._history: OrderedDict[str, list[str]] = OrderedDict()
        # thread_id -> hashes for which a warning was already injected.
        self._warned: dict[str, set[str]] = defaultdict(set)

    def _get_thread_id(self, runtime: Runtime) -> str:
        """Extract thread_id from runtime context for per-thread tracking.

        Falls back to ``"default"`` when the runtime carries no context, the
        context has no ``thread_id``, or the ``thread_id`` is falsy.
        """
        context = getattr(runtime, "context", None)
        if not context:
            # No context was supplied for this run; calling .get() on None
            # would raise AttributeError.
            return "default"
        if hasattr(context, "get"):
            thread_id = context.get("thread_id")
        else:
            # Attribute-style context objects (e.g. a dataclass schema).
            thread_id = getattr(context, "thread_id", None)
        return thread_id or "default"

    def _evict_if_needed(self) -> None:
        """Evict least recently used threads if over the limit.

        Must be called while holding self._lock.
        """
        while len(self._history) > self.max_tracked_threads:
            evicted_id, _ = self._history.popitem(last=False)
            # Drop the warn-once bookkeeping together with the history.
            self._warned.pop(evicted_id, None)
            logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id)

    def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]:
        """Track tool calls and check for loops.

        Only acts when the last message in ``state["messages"]`` is an AI
        message that carries tool calls; otherwise nothing is recorded.

        Returns:
            (warning_message_or_none, should_hard_stop)
        """
        messages = state.get("messages", [])
        if not messages:
            return None, False

        last_msg = messages[-1]
        if getattr(last_msg, "type", None) != "ai":
            return None, False

        tool_calls = getattr(last_msg, "tool_calls", None)
        if not tool_calls:
            return None, False

        thread_id = self._get_thread_id(runtime)
        call_hash = _hash_tool_calls(tool_calls)

        with self._lock:
            # Touch / create entry (move to end for LRU)
            if thread_id in self._history:
                self._history.move_to_end(thread_id)
            else:
                self._history[thread_id] = []
                self._evict_if_needed()

            history = self._history[thread_id]
            history.append(call_hash)
            # Keep only the newest window_size hashes (in place, so the list
            # object stored in _history stays the same).
            if len(history) > self.window_size:
                history[:] = history[-self.window_size:]

            count = history.count(call_hash)
            tool_names = [tc.get("name", "?") for tc in tool_calls]

            # The hard limit is checked first so it wins when both
            # thresholds are met.
            if count >= self.hard_limit:
                logger.error(
                    "Loop hard limit reached — forcing stop",
                    extra={
                        "thread_id": thread_id,
                        "call_hash": call_hash,
                        "count": count,
                        "tools": tool_names,
                    },
                )
                return _HARD_STOP_MSG, True

            if count >= self.warn_threshold:
                warned = self._warned[thread_id]
                if call_hash not in warned:
                    warned.add(call_hash)
                    logger.warning(
                        "Repetitive tool calls detected — injecting warning",
                        extra={
                            "thread_id": thread_id,
                            "call_hash": call_hash,
                            "count": count,
                            "tools": tool_names,
                        },
                    )
                    return _WARNING_MSG, False
                # Warning already injected for this hash — suppress
                return None, False

        return None, False

    @staticmethod
    def _append_text(content: str | list | None, text: str) -> str | list:
        """Return message ``content`` with ``text`` appended after a blank line.

        Message content may be a plain string or a list of content blocks.
        A non-empty list gets an extra text block; anything else is treated
        as a string, with falsy content treated as empty.
        """
        suffix = f"\n\n{text}"
        if isinstance(content, list) and content:
            # list + str would raise TypeError; add a new text block instead.
            return [*content, {"type": "text", "text": suffix}]
        if not content:
            return suffix
        if isinstance(content, str):
            return content + suffix
        # Unexpected type: coerce rather than crash the safety stop.
        return str(content) + suffix

    def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Run loop detection and build the state update, if any.

        Returns:
            ``None`` when no action is needed, otherwise a dict with a
            ``"messages"`` list holding either the stripped AI message
            (hard stop) or a warning ``SystemMessage``.
        """
        warning, hard_stop = self._track_and_check(state, runtime)

        if hard_stop:
            # Strip tool_calls from the last AIMessage to force text output.
            # The copy keeps every other field of the message.
            # NOTE(review): presumably the unchanged id makes the messages
            # reducer replace the original message rather than append — confirm.
            messages = state.get("messages", [])
            last_msg = messages[-1]
            stripped_msg = last_msg.model_copy(
                update={
                    "tool_calls": [],
                    "content": self._append_text(last_msg.content, _HARD_STOP_MSG),
                }
            )
            return {"messages": [stripped_msg]}

        if warning:
            # Inject a system message warning the model
            return {"messages": [SystemMessage(content=warning)]}

        return None

    @override
    def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Synchronous hook: check the latest model response for loops."""
        return self._apply(state, runtime)

    @override
    async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Async hook: same check as ``after_model``."""
        return self._apply(state, runtime)

    def reset(self, thread_id: str | None = None) -> None:
        """Clear tracking state. If thread_id given, clear only that thread."""
        with self._lock:
            if thread_id:
                self._history.pop(thread_id, None)
                self._warned.pop(thread_id, None)
            else:
                self._history.clear()
                self._warned.clear()