feat: add LoopDetectionMiddleware to break repetitive tool call loops (#1056)

* feat: add LoopDetectionMiddleware to break repetitive tool call loops

Adds a new AgentMiddleware that detects when the agent is stuck calling
the same tools with the same arguments repeatedly, which currently runs
until the recursion limit kills the run.

Detection: per-thread sliding window of tool call hashes (name + args).
- Warn threshold (default 3): injects a "wrap up" system message
- Hard limit (default 5): strips tool_calls, forcing final text output

Includes 13 unit tests covering hashing, thresholds, window sliding,
reset, and edge cases.

Closes #1055

* fix: address PR #1056 review feedback for LoopDetectionMiddleware

- Remove unused imports (Awaitable, Callable, ModelCallResult,
  ModelRequest, ModelResponse, AIMessage) from loop_detection_middleware
- Remove unused pytest import from test file
- Fix _hash_tool_calls sort key: sort by (name, serialized args) for
  deterministic hashing when multiple calls share the same tool name
- Revert subagent_enabled default to False in agent.py to match
  DeerFlowClient and channel defaults
- Remove unrelated SearxNG tools and Next.js rewrite changes from PR

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: address 2nd round review feedback on PR #1056

- Inject loop warning only once per thread (prevents context bloat)
- Add threading.Lock for thread-safe history mutations
- Use runtime.context thread_id instead of workspace_path
- Add LRU eviction for per-thread history (max 100 threads)
- Add 5 new tests covering warn-once, LRU eviction, thread isolation,
  fallback thread_id, and lock presence

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: resolve lint errors in loop detection middleware tests

Sort imports (I001) and remove unused _WARNING_MSG import (F401)
to fix ruff lint failures in CI.

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
virtaava
2026-03-14 16:17:54 +02:00
committed by GitHub
parent bbd87df6eb
commit d18a9ae5aa
3 changed files with 462 additions and 0 deletions

View File

@@ -6,6 +6,7 @@ from langchain_core.runnables import RunnableConfig
from src.agents.lead_agent.prompt import apply_prompt_template
from src.agents.middlewares.clarification_middleware import ClarificationMiddleware
from src.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware
from src.agents.middlewares.memory_middleware import MemoryMiddleware
from src.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
from src.agents.middlewares.title_middleware import TitleMiddleware
@@ -245,6 +246,9 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
max_concurrent_subagents = config.get("configurable", {}).get("max_concurrent_subagents", 3)
middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents))
# LoopDetectionMiddleware — detect and break repetitive tool call loops
middlewares.append(LoopDetectionMiddleware())
# ClarificationMiddleware should always be last
middlewares.append(ClarificationMiddleware())
return middlewares

View File

@@ -0,0 +1,227 @@
"""Middleware to detect and break repetitive tool call loops.
P0 safety: prevents the agent from calling the same tool with the same
arguments indefinitely until the recursion limit kills the run.
Detection strategy:
1. After each model response, hash the tool calls (name + args).
2. Track recent hashes in a sliding window.
3. If the same hash appears >= warn_threshold times, inject a
"you are repeating yourself — wrap up" system message (once per hash).
4. If it appears >= hard_limit times, strip all tool_calls from the
response so the agent is forced to produce a final text answer.
"""
import hashlib
import json
import logging
import threading
from collections import OrderedDict, defaultdict
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import SystemMessage
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
# Defaults — can be overridden via constructor
_DEFAULT_WARN_THRESHOLD = 3 # inject warning after 3 identical calls
_DEFAULT_HARD_LIMIT = 5 # force-stop after 5 identical calls
_DEFAULT_WINDOW_SIZE = 20 # track last N tool calls
_DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit
def _hash_tool_calls(tool_calls: list[dict]) -> str:
"""Deterministic hash of a set of tool calls (name + args).
This is intended to be order-independent: the same multiset of tool calls
should always produce the same hash, regardless of their input order.
"""
# First normalize each tool call to a minimal (name, args) structure.
normalized: list[dict] = []
for tc in tool_calls:
normalized.append(
{
"name": tc.get("name", ""),
"args": tc.get("args", {}),
}
)
# Sort by both name and a deterministic serialization of args so that
# permutations of the same multiset of calls yield the same ordering.
normalized.sort(
key=lambda tc: (
tc["name"],
json.dumps(tc["args"], sort_keys=True, default=str),
)
)
blob = json.dumps(normalized, sort_keys=True, default=str)
return hashlib.md5(blob.encode()).hexdigest()[:12]
_WARNING_MSG = (
"[LOOP DETECTED] You are repeating the same tool calls. "
"Stop calling tools and produce your final answer now. "
"If you cannot complete the task, summarize what you accomplished so far."
)
_HARD_STOP_MSG = (
"[FORCED STOP] Repeated tool calls exceeded the safety limit. "
"Producing final answer with results collected so far."
)
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
"""Detects and breaks repetitive tool call loops.
Args:
warn_threshold: Number of identical tool call sets before injecting
a warning message. Default: 3.
hard_limit: Number of identical tool call sets before stripping
tool_calls entirely. Default: 5.
window_size: Size of the sliding window for tracking calls.
Default: 20.
max_tracked_threads: Maximum number of threads to track before
evicting the least recently used. Default: 100.
"""
def __init__(
self,
warn_threshold: int = _DEFAULT_WARN_THRESHOLD,
hard_limit: int = _DEFAULT_HARD_LIMIT,
window_size: int = _DEFAULT_WINDOW_SIZE,
max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
):
super().__init__()
self.warn_threshold = warn_threshold
self.hard_limit = hard_limit
self.window_size = window_size
self.max_tracked_threads = max_tracked_threads
self._lock = threading.Lock()
# Per-thread tracking using OrderedDict for LRU eviction
self._history: OrderedDict[str, list[str]] = OrderedDict()
self._warned: dict[str, set[str]] = defaultdict(set)
def _get_thread_id(self, runtime: Runtime) -> str:
"""Extract thread_id from runtime context for per-thread tracking."""
thread_id = runtime.context.get("thread_id")
if thread_id:
return thread_id
return "default"
def _evict_if_needed(self) -> None:
"""Evict least recently used threads if over the limit.
Must be called while holding self._lock.
"""
while len(self._history) > self.max_tracked_threads:
evicted_id, _ = self._history.popitem(last=False)
self._warned.pop(evicted_id, None)
logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id)
def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]:
"""Track tool calls and check for loops.
Returns:
(warning_message_or_none, should_hard_stop)
"""
messages = state.get("messages", [])
if not messages:
return None, False
last_msg = messages[-1]
if getattr(last_msg, "type", None) != "ai":
return None, False
tool_calls = getattr(last_msg, "tool_calls", None)
if not tool_calls:
return None, False
thread_id = self._get_thread_id(runtime)
call_hash = _hash_tool_calls(tool_calls)
with self._lock:
# Touch / create entry (move to end for LRU)
if thread_id in self._history:
self._history.move_to_end(thread_id)
else:
self._history[thread_id] = []
self._evict_if_needed()
history = self._history[thread_id]
history.append(call_hash)
if len(history) > self.window_size:
history[:] = history[-self.window_size:]
count = history.count(call_hash)
tool_names = [tc.get("name", "?") for tc in tool_calls]
if count >= self.hard_limit:
logger.error(
"Loop hard limit reached — forcing stop",
extra={
"thread_id": thread_id,
"call_hash": call_hash,
"count": count,
"tools": tool_names,
},
)
return _HARD_STOP_MSG, True
if count >= self.warn_threshold:
warned = self._warned[thread_id]
if call_hash not in warned:
warned.add(call_hash)
logger.warning(
"Repetitive tool calls detected — injecting warning",
extra={
"thread_id": thread_id,
"call_hash": call_hash,
"count": count,
"tools": tool_names,
},
)
return _WARNING_MSG, False
# Warning already injected for this hash — suppress
return None, False
return None, False
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
warning, hard_stop = self._track_and_check(state, runtime)
if hard_stop:
# Strip tool_calls from the last AIMessage to force text output
messages = state.get("messages", [])
last_msg = messages[-1]
stripped_msg = last_msg.model_copy(update={
"tool_calls": [],
"content": (last_msg.content or "") + f"\n\n{_HARD_STOP_MSG}",
})
return {"messages": [stripped_msg]}
if warning:
# Inject a system message warning the model
return {"messages": [SystemMessage(content=warning)]}
return None
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state, runtime)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state, runtime)
def reset(self, thread_id: str | None = None) -> None:
"""Clear tracking state. If thread_id given, clear only that thread."""
with self._lock:
if thread_id:
self._history.pop(thread_id, None)
self._warned.pop(thread_id, None)
else:
self._history.clear()
self._warned.clear()

View File

@@ -0,0 +1,231 @@
"""Tests for LoopDetectionMiddleware."""
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage, SystemMessage
from src.agents.middlewares.loop_detection_middleware import (
_HARD_STOP_MSG,
LoopDetectionMiddleware,
_hash_tool_calls,
)
def _make_runtime(thread_id="test-thread"):
"""Build a minimal Runtime mock with context."""
runtime = MagicMock()
runtime.context = {"thread_id": thread_id}
return runtime
def _make_state(tool_calls=None, content=""):
"""Build a minimal AgentState dict with an AIMessage."""
msg = AIMessage(content=content, tool_calls=tool_calls or [])
return {"messages": [msg]}
def _bash_call(cmd="ls"):
return {"name": "bash", "id": f"call_{cmd}", "args": {"command": cmd}}
class TestHashToolCalls:
def test_same_calls_same_hash(self):
a = _hash_tool_calls([_bash_call("ls")])
b = _hash_tool_calls([_bash_call("ls")])
assert a == b
def test_different_calls_different_hash(self):
a = _hash_tool_calls([_bash_call("ls")])
b = _hash_tool_calls([_bash_call("pwd")])
assert a != b
def test_order_independent(self):
a = _hash_tool_calls([_bash_call("ls"), {"name": "read_file", "args": {"path": "/tmp"}}])
b = _hash_tool_calls([{"name": "read_file", "args": {"path": "/tmp"}}, _bash_call("ls")])
assert a == b
def test_empty_calls(self):
h = _hash_tool_calls([])
assert isinstance(h, str)
assert len(h) > 0
class TestLoopDetection:
def test_no_tool_calls_returns_none(self):
mw = LoopDetectionMiddleware()
runtime = _make_runtime()
state = {"messages": [AIMessage(content="hello")]}
result = mw._apply(state, runtime)
assert result is None
def test_below_threshold_returns_none(self):
mw = LoopDetectionMiddleware(warn_threshold=3)
runtime = _make_runtime()
call = [_bash_call("ls")]
# First two identical calls — no warning
for _ in range(2):
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is None
def test_warn_at_threshold(self):
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=5)
runtime = _make_runtime()
call = [_bash_call("ls")]
for _ in range(2):
mw._apply(_make_state(tool_calls=call), runtime)
# Third identical call triggers warning
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
msgs = result["messages"]
assert len(msgs) == 1
assert isinstance(msgs[0], SystemMessage)
assert "LOOP DETECTED" in msgs[0].content
def test_warn_only_injected_once(self):
"""Warning for the same hash should only be injected once per thread."""
mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
runtime = _make_runtime()
call = [_bash_call("ls")]
# First two — no warning
for _ in range(2):
mw._apply(_make_state(tool_calls=call), runtime)
# Third — warning injected
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
# Fourth — warning already injected, should return None
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is None
def test_hard_stop_at_limit(self):
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
runtime = _make_runtime()
call = [_bash_call("ls")]
for _ in range(3):
mw._apply(_make_state(tool_calls=call), runtime)
# Fourth call triggers hard stop
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
msgs = result["messages"]
assert len(msgs) == 1
# Hard stop strips tool_calls
assert isinstance(msgs[0], AIMessage)
assert msgs[0].tool_calls == []
assert _HARD_STOP_MSG in msgs[0].content
def test_different_calls_dont_trigger(self):
mw = LoopDetectionMiddleware(warn_threshold=2)
runtime = _make_runtime()
# Each call is different
for i in range(10):
result = mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime)
assert result is None
def test_window_sliding(self):
mw = LoopDetectionMiddleware(warn_threshold=3, window_size=5)
runtime = _make_runtime()
call = [_bash_call("ls")]
# Fill with 2 identical calls
mw._apply(_make_state(tool_calls=call), runtime)
mw._apply(_make_state(tool_calls=call), runtime)
# Push them out of the window with different calls
for i in range(5):
mw._apply(_make_state(tool_calls=[_bash_call(f"other_{i}")]), runtime)
# Now the original call should be fresh again — no warning
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is None
def test_reset_clears_state(self):
mw = LoopDetectionMiddleware(warn_threshold=2)
runtime = _make_runtime()
call = [_bash_call("ls")]
mw._apply(_make_state(tool_calls=call), runtime)
mw._apply(_make_state(tool_calls=call), runtime)
# Would trigger warning, but reset first
mw.reset()
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is None
def test_non_ai_message_ignored(self):
mw = LoopDetectionMiddleware()
runtime = _make_runtime()
state = {"messages": [SystemMessage(content="hello")]}
result = mw._apply(state, runtime)
assert result is None
def test_empty_messages_ignored(self):
mw = LoopDetectionMiddleware()
runtime = _make_runtime()
result = mw._apply({"messages": []}, runtime)
assert result is None
def test_thread_id_from_runtime_context(self):
"""Thread ID should come from runtime.context, not state."""
mw = LoopDetectionMiddleware(warn_threshold=2)
runtime_a = _make_runtime("thread-A")
runtime_b = _make_runtime("thread-B")
call = [_bash_call("ls")]
# One call on thread A
mw._apply(_make_state(tool_calls=call), runtime_a)
# One call on thread B
mw._apply(_make_state(tool_calls=call), runtime_b)
# Second call on thread A — triggers warning (2 >= warn_threshold)
result = mw._apply(_make_state(tool_calls=call), runtime_a)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
# Second call on thread B — also triggers (independent tracking)
result = mw._apply(_make_state(tool_calls=call), runtime_b)
assert result is not None
assert "LOOP DETECTED" in result["messages"][0].content
def test_lru_eviction(self):
"""Old threads should be evicted when max_tracked_threads is exceeded."""
mw = LoopDetectionMiddleware(warn_threshold=2, max_tracked_threads=3)
call = [_bash_call("ls")]
# Fill up 3 threads
for i in range(3):
runtime = _make_runtime(f"thread-{i}")
mw._apply(_make_state(tool_calls=call), runtime)
# Add a 4th thread — should evict thread-0
runtime_new = _make_runtime("thread-new")
mw._apply(_make_state(tool_calls=call), runtime_new)
assert "thread-0" not in mw._history
assert "thread-new" in mw._history
assert len(mw._history) == 3
def test_thread_safe_mutations(self):
"""Verify lock is used for mutations (basic structural test)."""
mw = LoopDetectionMiddleware()
# The middleware should have a lock attribute
assert hasattr(mw, "_lock")
assert isinstance(mw._lock, type(mw._lock))
def test_fallback_thread_id_when_missing(self):
"""When runtime context has no thread_id, should use 'default'."""
mw = LoopDetectionMiddleware(warn_threshold=2)
runtime = MagicMock()
runtime.context = {}
call = [_bash_call("ls")]
mw._apply(_make_state(tool_calls=call), runtime)
assert "default" in mw._history