mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 14:22:13 +08:00
feat: add LoopDetectionMiddleware to break repetitive tool call loops (#1056)
* feat: add LoopDetectionMiddleware to break repetitive tool call loops Adds a new AgentMiddleware that detects when the agent is stuck calling the same tools with the same arguments repeatedly, which currently runs until the recursion limit kills the run. Detection: per-thread sliding window of tool call hashes (name + args). - Warn threshold (default 3): injects a "wrap up" system message - Hard limit (default 5): strips tool_calls, forcing final text output Includes 13 unit tests covering hashing, thresholds, window sliding, reset, and edge cases. Closes #1055 * fix: address PR #1056 review feedback for LoopDetectionMiddleware - Remove unused imports (Awaitable, Callable, ModelCallResult, ModelRequest, ModelResponse, AIMessage) from loop_detection_middleware - Remove unused pytest import from test file - Fix _hash_tool_calls sort key: sort by (name, serialized args) for deterministic hashing when multiple calls share the same tool name - Revert subagent_enabled default to False in agent.py to match DeerFlowClient and channel defaults - Remove unrelated SearxNG tools and Next.js rewrite changes from PR Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: address 2nd round review feedback on PR #1056 - Inject loop warning only once per thread (prevents context bloat) - Add threading.Lock for thread-safe history mutations - Use runtime.context thread_id instead of workspace_path - Add LRU eviction for per-thread history (max 100 threads) - Add 5 new tests covering warn-once, LRU eviction, thread isolation, fallback thread_id, and lock presence Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: resolve lint errors in loop detection middleware tests Sort imports (I001) and remove unused _WARNING_MSG import (F401) to fix ruff lint failures in CI. * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Willem Jiang <willem.jiang@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -6,6 +6,7 @@ from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from src.agents.lead_agent.prompt import apply_prompt_template
|
||||
from src.agents.middlewares.clarification_middleware import ClarificationMiddleware
|
||||
from src.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
|
||||
from src.agents.middlewares.memory_middleware import MemoryMiddleware
|
||||
from src.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
|
||||
from src.agents.middlewares.title_middleware import TitleMiddleware
|
||||
@@ -245,6 +246,9 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
|
||||
max_concurrent_subagents = config.get("configurable", {}).get("max_concurrent_subagents", 3)
|
||||
middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents))
|
||||
|
||||
# LoopDetectionMiddleware — detect and break repetitive tool call loops
|
||||
middlewares.append(LoopDetectionMiddleware())
|
||||
|
||||
# ClarificationMiddleware should always be last
|
||||
middlewares.append(ClarificationMiddleware())
|
||||
return middlewares
|
||||
|
||||
227
backend/src/agents/middlewares/loop_detection_middleware.py
Normal file
227
backend/src/agents/middlewares/loop_detection_middleware.py
Normal file
@@ -0,0 +1,227 @@
|
||||
"""Middleware to detect and break repetitive tool call loops.
|
||||
|
||||
P0 safety: prevents the agent from calling the same tool with the same
|
||||
arguments indefinitely until the recursion limit kills the run.
|
||||
|
||||
Detection strategy:
|
||||
1. After each model response, hash the tool calls (name + args).
|
||||
2. Track recent hashes in a sliding window.
|
||||
3. If the same hash appears >= warn_threshold times, inject a
|
||||
"you are repeating yourself — wrap up" system message (once per hash).
|
||||
4. If it appears >= hard_limit times, strip all tool_calls from the
|
||||
response so the agent is forced to produce a final text answer.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict, defaultdict
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Defaults — can be overridden via constructor
|
||||
_DEFAULT_WARN_THRESHOLD = 3 # inject warning after 3 identical calls
|
||||
_DEFAULT_HARD_LIMIT = 5 # force-stop after 5 identical calls
|
||||
_DEFAULT_WINDOW_SIZE = 20 # track last N tool calls
|
||||
_DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit
|
||||
|
||||
|
||||
def _hash_tool_calls(tool_calls: list[dict]) -> str:
|
||||
"""Deterministic hash of a set of tool calls (name + args).
|
||||
|
||||
This is intended to be order-independent: the same multiset of tool calls
|
||||
should always produce the same hash, regardless of their input order.
|
||||
"""
|
||||
# First normalize each tool call to a minimal (name, args) structure.
|
||||
normalized: list[dict] = []
|
||||
for tc in tool_calls:
|
||||
normalized.append(
|
||||
{
|
||||
"name": tc.get("name", ""),
|
||||
"args": tc.get("args", {}),
|
||||
}
|
||||
)
|
||||
|
||||
# Sort by both name and a deterministic serialization of args so that
|
||||
# permutations of the same multiset of calls yield the same ordering.
|
||||
normalized.sort(
|
||||
key=lambda tc: (
|
||||
tc["name"],
|
||||
json.dumps(tc["args"], sort_keys=True, default=str),
|
||||
)
|
||||
)
|
||||
blob = json.dumps(normalized, sort_keys=True, default=str)
|
||||
return hashlib.md5(blob.encode()).hexdigest()[:12]
|
||||
|
||||
|
||||
_WARNING_MSG = (
|
||||
"[LOOP DETECTED] You are repeating the same tool calls. "
|
||||
"Stop calling tools and produce your final answer now. "
|
||||
"If you cannot complete the task, summarize what you accomplished so far."
|
||||
)
|
||||
|
||||
_HARD_STOP_MSG = (
|
||||
"[FORCED STOP] Repeated tool calls exceeded the safety limit. "
|
||||
"Producing final answer with results collected so far."
|
||||
)
|
||||
|
||||
|
||||
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Detects and breaks repetitive tool call loops.
|
||||
|
||||
Args:
|
||||
warn_threshold: Number of identical tool call sets before injecting
|
||||
a warning message. Default: 3.
|
||||
hard_limit: Number of identical tool call sets before stripping
|
||||
tool_calls entirely. Default: 5.
|
||||
window_size: Size of the sliding window for tracking calls.
|
||||
Default: 20.
|
||||
max_tracked_threads: Maximum number of threads to track before
|
||||
evicting the least recently used. Default: 100.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
warn_threshold: int = _DEFAULT_WARN_THRESHOLD,
|
||||
hard_limit: int = _DEFAULT_HARD_LIMIT,
|
||||
window_size: int = _DEFAULT_WINDOW_SIZE,
|
||||
max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
|
||||
):
|
||||
super().__init__()
|
||||
self.warn_threshold = warn_threshold
|
||||
self.hard_limit = hard_limit
|
||||
self.window_size = window_size
|
||||
self.max_tracked_threads = max_tracked_threads
|
||||
self._lock = threading.Lock()
|
||||
# Per-thread tracking using OrderedDict for LRU eviction
|
||||
self._history: OrderedDict[str, list[str]] = OrderedDict()
|
||||
self._warned: dict[str, set[str]] = defaultdict(set)
|
||||
|
||||
def _get_thread_id(self, runtime: Runtime) -> str:
|
||||
"""Extract thread_id from runtime context for per-thread tracking."""
|
||||
thread_id = runtime.context.get("thread_id")
|
||||
if thread_id:
|
||||
return thread_id
|
||||
return "default"
|
||||
|
||||
def _evict_if_needed(self) -> None:
|
||||
"""Evict least recently used threads if over the limit.
|
||||
|
||||
Must be called while holding self._lock.
|
||||
"""
|
||||
while len(self._history) > self.max_tracked_threads:
|
||||
evicted_id, _ = self._history.popitem(last=False)
|
||||
self._warned.pop(evicted_id, None)
|
||||
logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id)
|
||||
|
||||
def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]:
|
||||
"""Track tool calls and check for loops.
|
||||
|
||||
Returns:
|
||||
(warning_message_or_none, should_hard_stop)
|
||||
"""
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return None, False
|
||||
|
||||
last_msg = messages[-1]
|
||||
if getattr(last_msg, "type", None) != "ai":
|
||||
return None, False
|
||||
|
||||
tool_calls = getattr(last_msg, "tool_calls", None)
|
||||
if not tool_calls:
|
||||
return None, False
|
||||
|
||||
thread_id = self._get_thread_id(runtime)
|
||||
call_hash = _hash_tool_calls(tool_calls)
|
||||
|
||||
with self._lock:
|
||||
# Touch / create entry (move to end for LRU)
|
||||
if thread_id in self._history:
|
||||
self._history.move_to_end(thread_id)
|
||||
else:
|
||||
self._history[thread_id] = []
|
||||
self._evict_if_needed()
|
||||
|
||||
history = self._history[thread_id]
|
||||
history.append(call_hash)
|
||||
if len(history) > self.window_size:
|
||||
history[:] = history[-self.window_size:]
|
||||
|
||||
count = history.count(call_hash)
|
||||
tool_names = [tc.get("name", "?") for tc in tool_calls]
|
||||
|
||||
if count >= self.hard_limit:
|
||||
logger.error(
|
||||
"Loop hard limit reached — forcing stop",
|
||||
extra={
|
||||
"thread_id": thread_id,
|
||||
"call_hash": call_hash,
|
||||
"count": count,
|
||||
"tools": tool_names,
|
||||
},
|
||||
)
|
||||
return _HARD_STOP_MSG, True
|
||||
|
||||
if count >= self.warn_threshold:
|
||||
warned = self._warned[thread_id]
|
||||
if call_hash not in warned:
|
||||
warned.add(call_hash)
|
||||
logger.warning(
|
||||
"Repetitive tool calls detected — injecting warning",
|
||||
extra={
|
||||
"thread_id": thread_id,
|
||||
"call_hash": call_hash,
|
||||
"count": count,
|
||||
"tools": tool_names,
|
||||
},
|
||||
)
|
||||
return _WARNING_MSG, False
|
||||
# Warning already injected for this hash — suppress
|
||||
return None, False
|
||||
|
||||
return None, False
|
||||
|
||||
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
warning, hard_stop = self._track_and_check(state, runtime)
|
||||
|
||||
if hard_stop:
|
||||
# Strip tool_calls from the last AIMessage to force text output
|
||||
messages = state.get("messages", [])
|
||||
last_msg = messages[-1]
|
||||
stripped_msg = last_msg.model_copy(update={
|
||||
"tool_calls": [],
|
||||
"content": (last_msg.content or "") + f"\n\n{_HARD_STOP_MSG}",
|
||||
})
|
||||
return {"messages": [stripped_msg]}
|
||||
|
||||
if warning:
|
||||
# Inject a system message warning the model
|
||||
return {"messages": [SystemMessage(content=warning)]}
|
||||
|
||||
return None
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._apply(state, runtime)
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._apply(state, runtime)
|
||||
|
||||
def reset(self, thread_id: str | None = None) -> None:
|
||||
"""Clear tracking state. If thread_id given, clear only that thread."""
|
||||
with self._lock:
|
||||
if thread_id:
|
||||
self._history.pop(thread_id, None)
|
||||
self._warned.pop(thread_id, None)
|
||||
else:
|
||||
self._history.clear()
|
||||
self._warned.clear()
|
||||
231
backend/tests/test_loop_detection_middleware.py
Normal file
231
backend/tests/test_loop_detection_middleware.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""Tests for LoopDetectionMiddleware."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain_core.messages import AIMessage, SystemMessage
|
||||
|
||||
from src.agents.middlewares.loop_detection_middleware import (
|
||||
_HARD_STOP_MSG,
|
||||
LoopDetectionMiddleware,
|
||||
_hash_tool_calls,
|
||||
)
|
||||
|
||||
|
||||
def _make_runtime(thread_id="test-thread"):
|
||||
"""Build a minimal Runtime mock with context."""
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": thread_id}
|
||||
return runtime
|
||||
|
||||
|
||||
def _make_state(tool_calls=None, content=""):
|
||||
"""Build a minimal AgentState dict with an AIMessage."""
|
||||
msg = AIMessage(content=content, tool_calls=tool_calls or [])
|
||||
return {"messages": [msg]}
|
||||
|
||||
|
||||
def _bash_call(cmd="ls"):
|
||||
return {"name": "bash", "id": f"call_{cmd}", "args": {"command": cmd}}
|
||||
|
||||
|
||||
class TestHashToolCalls:
|
||||
def test_same_calls_same_hash(self):
|
||||
a = _hash_tool_calls([_bash_call("ls")])
|
||||
b = _hash_tool_calls([_bash_call("ls")])
|
||||
assert a == b
|
||||
|
||||
def test_different_calls_different_hash(self):
|
||||
a = _hash_tool_calls([_bash_call("ls")])
|
||||
b = _hash_tool_calls([_bash_call("pwd")])
|
||||
assert a != b
|
||||
|
||||
def test_order_independent(self):
|
||||
a = _hash_tool_calls([_bash_call("ls"), {"name": "read_file", "args": {"path": "/tmp"}}])
|
||||
b = _hash_tool_calls([{"name": "read_file", "args": {"path": "/tmp"}}, _bash_call("ls")])
|
||||
assert a == b
|
||||
|
||||
def test_empty_calls(self):
|
||||
h = _hash_tool_calls([])
|
||||
assert isinstance(h, str)
|
||||
assert len(h) > 0
|
||||
|
||||
|
||||
class TestLoopDetection:
|
||||
def test_no_tool_calls_returns_none(self):
|
||||
mw = LoopDetectionMiddleware()
|
||||
runtime = _make_runtime()
|
||||
state = {"messages": [AIMessage(content="hello")]}
|
||||
result = mw._apply(state, runtime)
|
||||
assert result is None
|
||||
|
||||
def test_below_threshold_returns_none(self):
|
||||
mw = LoopDetectionMiddleware(warn_threshold=3)
|
||||
runtime = _make_runtime()
|
||||
call = [_bash_call("ls")]
|
||||
|
||||
# First two identical calls — no warning
|
||||
for _ in range(2):
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is None
|
||||
|
||||
def test_warn_at_threshold(self):
|
||||
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=5)
|
||||
runtime = _make_runtime()
|
||||
call = [_bash_call("ls")]
|
||||
|
||||
for _ in range(2):
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
|
||||
# Third identical call triggers warning
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is not None
|
||||
msgs = result["messages"]
|
||||
assert len(msgs) == 1
|
||||
assert isinstance(msgs[0], SystemMessage)
|
||||
assert "LOOP DETECTED" in msgs[0].content
|
||||
|
||||
def test_warn_only_injected_once(self):
|
||||
"""Warning for the same hash should only be injected once per thread."""
|
||||
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
|
||||
runtime = _make_runtime()
|
||||
call = [_bash_call("ls")]
|
||||
|
||||
# First two — no warning
|
||||
for _ in range(2):
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
|
||||
# Third — warning injected
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is not None
|
||||
assert "LOOP DETECTED" in result["messages"][0].content
|
||||
|
||||
# Fourth — warning already injected, should return None
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is None
|
||||
|
||||
def test_hard_stop_at_limit(self):
|
||||
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
|
||||
runtime = _make_runtime()
|
||||
call = [_bash_call("ls")]
|
||||
|
||||
for _ in range(3):
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
|
||||
# Fourth call triggers hard stop
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is not None
|
||||
msgs = result["messages"]
|
||||
assert len(msgs) == 1
|
||||
# Hard stop strips tool_calls
|
||||
assert isinstance(msgs[0], AIMessage)
|
||||
assert msgs[0].tool_calls == []
|
||||
assert _HARD_STOP_MSG in msgs[0].content
|
||||
|
||||
def test_different_calls_dont_trigger(self):
|
||||
mw = LoopDetectionMiddleware(warn_threshold=2)
|
||||
runtime = _make_runtime()
|
||||
|
||||
# Each call is different
|
||||
for i in range(10):
|
||||
result = mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime)
|
||||
assert result is None
|
||||
|
||||
def test_window_sliding(self):
|
||||
mw = LoopDetectionMiddleware(warn_threshold=3, window_size=5)
|
||||
runtime = _make_runtime()
|
||||
call = [_bash_call("ls")]
|
||||
|
||||
# Fill with 2 identical calls
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
|
||||
# Push them out of the window with different calls
|
||||
for i in range(5):
|
||||
mw._apply(_make_state(tool_calls=[_bash_call(f"other_{i}")]), runtime)
|
||||
|
||||
# Now the original call should be fresh again — no warning
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is None
|
||||
|
||||
def test_reset_clears_state(self):
|
||||
mw = LoopDetectionMiddleware(warn_threshold=2)
|
||||
runtime = _make_runtime()
|
||||
call = [_bash_call("ls")]
|
||||
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
|
||||
# Would trigger warning, but reset first
|
||||
mw.reset()
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is None
|
||||
|
||||
def test_non_ai_message_ignored(self):
|
||||
mw = LoopDetectionMiddleware()
|
||||
runtime = _make_runtime()
|
||||
state = {"messages": [SystemMessage(content="hello")]}
|
||||
result = mw._apply(state, runtime)
|
||||
assert result is None
|
||||
|
||||
def test_empty_messages_ignored(self):
|
||||
mw = LoopDetectionMiddleware()
|
||||
runtime = _make_runtime()
|
||||
result = mw._apply({"messages": []}, runtime)
|
||||
assert result is None
|
||||
|
||||
def test_thread_id_from_runtime_context(self):
|
||||
"""Thread ID should come from runtime.context, not state."""
|
||||
mw = LoopDetectionMiddleware(warn_threshold=2)
|
||||
runtime_a = _make_runtime("thread-A")
|
||||
runtime_b = _make_runtime("thread-B")
|
||||
call = [_bash_call("ls")]
|
||||
|
||||
# One call on thread A
|
||||
mw._apply(_make_state(tool_calls=call), runtime_a)
|
||||
# One call on thread B
|
||||
mw._apply(_make_state(tool_calls=call), runtime_b)
|
||||
|
||||
# Second call on thread A — triggers warning (2 >= warn_threshold)
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime_a)
|
||||
assert result is not None
|
||||
assert "LOOP DETECTED" in result["messages"][0].content
|
||||
|
||||
# Second call on thread B — also triggers (independent tracking)
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime_b)
|
||||
assert result is not None
|
||||
assert "LOOP DETECTED" in result["messages"][0].content
|
||||
|
||||
def test_lru_eviction(self):
|
||||
"""Old threads should be evicted when max_tracked_threads is exceeded."""
|
||||
mw = LoopDetectionMiddleware(warn_threshold=2, max_tracked_threads=3)
|
||||
call = [_bash_call("ls")]
|
||||
|
||||
# Fill up 3 threads
|
||||
for i in range(3):
|
||||
runtime = _make_runtime(f"thread-{i}")
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
|
||||
# Add a 4th thread — should evict thread-0
|
||||
runtime_new = _make_runtime("thread-new")
|
||||
mw._apply(_make_state(tool_calls=call), runtime_new)
|
||||
|
||||
assert "thread-0" not in mw._history
|
||||
assert "thread-new" in mw._history
|
||||
assert len(mw._history) == 3
|
||||
|
||||
def test_thread_safe_mutations(self):
|
||||
"""Verify lock is used for mutations (basic structural test)."""
|
||||
mw = LoopDetectionMiddleware()
|
||||
# The middleware should have a lock attribute
|
||||
assert hasattr(mw, "_lock")
|
||||
assert isinstance(mw._lock, type(mw._lock))
|
||||
|
||||
def test_fallback_thread_id_when_missing(self):
|
||||
"""When runtime context has no thread_id, should use 'default'."""
|
||||
mw = LoopDetectionMiddleware(warn_threshold=2)
|
||||
runtime = MagicMock()
|
||||
runtime.context = {}
|
||||
call = [_bash_call("ls")]
|
||||
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert "default" in mw._history
|
||||
Reference in New Issue
Block a user