diff --git a/backend/src/agents/lead_agent/agent.py b/backend/src/agents/lead_agent/agent.py index f407733..de61cba 100644 --- a/backend/src/agents/lead_agent/agent.py +++ b/backend/src/agents/lead_agent/agent.py @@ -6,6 +6,7 @@ from langchain_core.runnables import RunnableConfig from src.agents.lead_agent.prompt import apply_prompt_template from src.agents.middlewares.clarification_middleware import ClarificationMiddleware +from src.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware from src.agents.middlewares.memory_middleware import MemoryMiddleware from src.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware from src.agents.middlewares.title_middleware import TitleMiddleware @@ -245,6 +246,9 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam max_concurrent_subagents = config.get("configurable", {}).get("max_concurrent_subagents", 3) middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents)) + # LoopDetectionMiddleware — detect and break repetitive tool call loops + middlewares.append(LoopDetectionMiddleware()) + # ClarificationMiddleware should always be last middlewares.append(ClarificationMiddleware()) return middlewares diff --git a/backend/src/agents/middlewares/loop_detection_middleware.py b/backend/src/agents/middlewares/loop_detection_middleware.py new file mode 100644 index 0000000..f96373e --- /dev/null +++ b/backend/src/agents/middlewares/loop_detection_middleware.py @@ -0,0 +1,227 @@ +"""Middleware to detect and break repetitive tool call loops. + +P0 safety: prevents the agent from calling the same tool with the same +arguments indefinitely until the recursion limit kills the run. + +Detection strategy: + 1. After each model response, hash the tool calls (name + args). + 2. Track recent hashes in a sliding window. + 3. 
If the same hash appears >= warn_threshold times, inject a + "you are repeating yourself — wrap up" system message (once per hash). + 4. If it appears >= hard_limit times, strip all tool_calls from the + response so the agent is forced to produce a final text answer. +""" + +import hashlib +import json +import logging +import threading +from collections import OrderedDict, defaultdict +from typing import override + +from langchain.agents import AgentState +from langchain.agents.middleware import AgentMiddleware +from langchain_core.messages import SystemMessage +from langgraph.runtime import Runtime + +logger = logging.getLogger(__name__) + +# Defaults — can be overridden via constructor +_DEFAULT_WARN_THRESHOLD = 3 # inject warning after 3 identical calls +_DEFAULT_HARD_LIMIT = 5 # force-stop after 5 identical calls +_DEFAULT_WINDOW_SIZE = 20 # track last N tool calls +_DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit + + +def _hash_tool_calls(tool_calls: list[dict]) -> str: + """Deterministic hash of a set of tool calls (name + args). + + This is intended to be order-independent: the same multiset of tool calls + should always produce the same hash, regardless of their input order. + """ + # First normalize each tool call to a minimal (name, args) structure. + normalized: list[dict] = [] + for tc in tool_calls: + normalized.append( + { + "name": tc.get("name", ""), + "args": tc.get("args", {}), + } + ) + + # Sort by both name and a deterministic serialization of args so that + # permutations of the same multiset of calls yield the same ordering. + normalized.sort( + key=lambda tc: ( + tc["name"], + json.dumps(tc["args"], sort_keys=True, default=str), + ) + ) + blob = json.dumps(normalized, sort_keys=True, default=str) + return hashlib.md5(blob.encode()).hexdigest()[:12] + + +_WARNING_MSG = ( + "[LOOP DETECTED] You are repeating the same tool calls. " + "Stop calling tools and produce your final answer now. 
class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
    """Detects and breaks repetitive tool call loops.

    After every model response the tool calls of the last AI message are
    hashed and appended to a per-thread sliding window. When the same hash
    occurs ``warn_threshold`` times inside the window a warning message is
    injected (once per hash and thread); when it occurs ``hard_limit`` times
    the tool calls are stripped from the response.

    Args:
        warn_threshold: Number of identical tool call sets before injecting
            a warning message. Default: 3.
        hard_limit: Number of identical tool call sets before stripping
            tool_calls entirely. Default: 5.
        window_size: Size of the sliding window for tracking calls.
            Default: 20. Should be at least ``hard_limit``; otherwise the
            hard stop can never be reached.
        max_tracked_threads: Maximum number of threads to track before
            evicting the least recently used. Default: 100.
    """

    def __init__(
        self,
        warn_threshold: int = _DEFAULT_WARN_THRESHOLD,
        hard_limit: int = _DEFAULT_HARD_LIMIT,
        window_size: int = _DEFAULT_WINDOW_SIZE,
        max_tracked_threads: int = _DEFAULT_MAX_TRACKED_THREADS,
    ):
        super().__init__()
        self.warn_threshold = warn_threshold
        self.hard_limit = hard_limit
        self.window_size = window_size
        self.max_tracked_threads = max_tracked_threads
        self._lock = threading.Lock()
        # Per-thread tracking using OrderedDict for LRU eviction
        self._history: OrderedDict[str, list[str]] = OrderedDict()
        # Hashes that already produced a warning, per thread
        self._warned: dict[str, set[str]] = defaultdict(set)

        if window_size < hard_limit:
            # A hash can occur at most window_size times inside the window,
            # so the hard stop would silently never trigger. Surface the
            # misconfiguration instead of failing (callers may rely on the
            # constructor accepting any integers).
            logger.warning(
                "LoopDetectionMiddleware: window_size (%s) is smaller than "
                "hard_limit (%s); the hard stop can never be reached",
                window_size,
                hard_limit,
            )

    def _get_thread_id(self, runtime: Runtime) -> str:
        """Extract thread_id from runtime context for per-thread tracking.

        Falls back to ``"default"`` when the runtime has no context (it may
        be ``None`` or empty) or the context carries no ``thread_id``.
        """
        context = runtime.context
        # Guard against a missing context: calling .get on None would raise
        # AttributeError and abort the whole agent step.
        thread_id = context.get("thread_id") if context else None
        if thread_id:
            return thread_id
        return "default"

    def _evict_if_needed(self) -> None:
        """Evict least recently used threads if over the limit.

        Must be called while holding self._lock.
        """
        while len(self._history) > self.max_tracked_threads:
            evicted_id, _ = self._history.popitem(last=False)
            self._warned.pop(evicted_id, None)
            logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id)

    def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]:
        """Track tool calls and check for loops.

        Only acts when the last message in ``state`` is an AI message that
        carries tool calls; anything else is ignored and not recorded.

        Returns:
            (warning_message_or_none, should_hard_stop)
        """
        messages = state.get("messages", [])
        if not messages:
            return None, False

        last_msg = messages[-1]
        if getattr(last_msg, "type", None) != "ai":
            return None, False

        tool_calls = getattr(last_msg, "tool_calls", None)
        if not tool_calls:
            return None, False

        thread_id = self._get_thread_id(runtime)
        call_hash = _hash_tool_calls(tool_calls)

        with self._lock:
            # Touch / create entry (move to end for LRU)
            if thread_id in self._history:
                self._history.move_to_end(thread_id)
            else:
                self._history[thread_id] = []
                self._evict_if_needed()

            history = self._history[thread_id]
            history.append(call_hash)
            if len(history) > self.window_size:
                # Keep only the most recent window_size entries, in place.
                history[:] = history[-self.window_size:]

            count = history.count(call_hash)
            tool_names = [tc.get("name", "?") for tc in tool_calls]

            if count >= self.hard_limit:
                logger.error(
                    "Loop hard limit reached — forcing stop",
                    extra={
                        "thread_id": thread_id,
                        "call_hash": call_hash,
                        "count": count,
                        "tools": tool_names,
                    },
                )
                return _HARD_STOP_MSG, True

            if count >= self.warn_threshold:
                warned = self._warned[thread_id]
                if call_hash not in warned:
                    warned.add(call_hash)
                    logger.warning(
                        "Repetitive tool calls detected — injecting warning",
                        extra={
                            "thread_id": thread_id,
                            "call_hash": call_hash,
                            "count": count,
                            "tools": tool_names,
                        },
                    )
                    return _WARNING_MSG, False
                # Warning already injected for this hash — suppress
                return None, False

        return None, False

    def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
        """Run loop detection and build the state update, if any.

        Returns:
            ``None`` when nothing needs to change; otherwise a dict with a
            ``messages`` list holding either a copy of the last AI message
            with its tool calls removed (hard stop) or a ``SystemMessage``
            carrying the warning.
        """
        warning, hard_stop = self._track_and_check(state, runtime)

        if hard_stop:
            # Strip tool_calls from the last AIMessage to force text output
            messages = state.get("messages", [])
            last_msg = messages[-1]

            # Message content may be a plain string or a list of content
            # blocks; concatenating a str onto a list would raise TypeError.
            content = last_msg.content
            if isinstance(content, list):
                new_content = [*content, {"type": "text", "text": _HARD_STOP_MSG}]
            else:
                new_content = (content or "") + f"\n\n{_HARD_STOP_MSG}"

            # NOTE(review): some provider integrations appear to fall back to
            # the raw tool calls kept in additional_kwargs when tool_calls is
            # empty — drop them too so the stop cannot be bypassed. Confirm
            # against the chat model integrations in use.
            raw_kwargs = getattr(last_msg, "additional_kwargs", None) or {}
            cleaned_kwargs = {
                key: value
                for key, value in raw_kwargs.items()
                if key not in ("tool_calls", "function_call")
            }

            stripped_msg = last_msg.model_copy(
                update={
                    "tool_calls": [],
                    "content": new_content,
                    "additional_kwargs": cleaned_kwargs,
                }
            )
            return {"messages": [stripped_msg]}

        if warning:
            # Inject a system message warning the model
            return {"messages": [SystemMessage(content=warning)]}

        return None

    @override
    def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        return self._apply(state, runtime)

    @override
    async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
        return self._apply(state, runtime)

    def reset(self, thread_id: str | None = None) -> None:
        """Clear tracking state. If thread_id given, clear only that thread."""
        with self._lock:
            if thread_id:
                self._history.pop(thread_id, None)
                self._warned.pop(thread_id, None)
            else:
                self._history.clear()
                self._warned.clear()
class TestLoopDetection:
    """Behavioral tests for LoopDetectionMiddleware, driven through _apply."""

    def test_no_tool_calls_returns_none(self):
        mw = LoopDetectionMiddleware()
        runtime = _make_runtime()
        state = {"messages": [AIMessage(content="hello")]}
        result = mw._apply(state, runtime)
        assert result is None

    def test_below_threshold_returns_none(self):
        mw = LoopDetectionMiddleware(warn_threshold=3)
        runtime = _make_runtime()
        call = [_bash_call("ls")]

        # First two identical calls — no warning
        for _ in range(2):
            result = mw._apply(_make_state(tool_calls=call), runtime)
            assert result is None

    def test_warn_at_threshold(self):
        mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=5)
        runtime = _make_runtime()
        call = [_bash_call("ls")]

        for _ in range(2):
            mw._apply(_make_state(tool_calls=call), runtime)

        # Third identical call triggers warning
        result = mw._apply(_make_state(tool_calls=call), runtime)
        assert result is not None
        msgs = result["messages"]
        assert len(msgs) == 1
        assert isinstance(msgs[0], SystemMessage)
        assert "LOOP DETECTED" in msgs[0].content

    def test_warn_only_injected_once(self):
        """Warning for the same hash should only be injected once per thread."""
        mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10)
        runtime = _make_runtime()
        call = [_bash_call("ls")]

        # First two — no warning
        for _ in range(2):
            mw._apply(_make_state(tool_calls=call), runtime)

        # Third — warning injected
        result = mw._apply(_make_state(tool_calls=call), runtime)
        assert result is not None
        assert "LOOP DETECTED" in result["messages"][0].content

        # Fourth — warning already injected, should return None
        result = mw._apply(_make_state(tool_calls=call), runtime)
        assert result is None

    def test_hard_stop_at_limit(self):
        mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
        runtime = _make_runtime()
        call = [_bash_call("ls")]

        for _ in range(3):
            mw._apply(_make_state(tool_calls=call), runtime)

        # Fourth call triggers hard stop
        result = mw._apply(_make_state(tool_calls=call), runtime)
        assert result is not None
        msgs = result["messages"]
        assert len(msgs) == 1
        # Hard stop strips tool_calls
        assert isinstance(msgs[0], AIMessage)
        assert msgs[0].tool_calls == []
        assert _HARD_STOP_MSG in msgs[0].content

    def test_different_calls_dont_trigger(self):
        mw = LoopDetectionMiddleware(warn_threshold=2)
        runtime = _make_runtime()

        # Each call is different
        for i in range(10):
            result = mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime)
            assert result is None

    def test_window_sliding(self):
        mw = LoopDetectionMiddleware(warn_threshold=3, window_size=5)
        runtime = _make_runtime()
        call = [_bash_call("ls")]

        # Fill with 2 identical calls
        mw._apply(_make_state(tool_calls=call), runtime)
        mw._apply(_make_state(tool_calls=call), runtime)

        # Push them out of the window with different calls
        for i in range(5):
            mw._apply(_make_state(tool_calls=[_bash_call(f"other_{i}")]), runtime)

        # Now the original call should be fresh again — no warning
        result = mw._apply(_make_state(tool_calls=call), runtime)
        assert result is None

    def test_reset_clears_state(self):
        mw = LoopDetectionMiddleware(warn_threshold=2)
        runtime = _make_runtime()
        call = [_bash_call("ls")]

        mw._apply(_make_state(tool_calls=call), runtime)
        mw._apply(_make_state(tool_calls=call), runtime)

        # Would trigger warning, but reset first
        mw.reset()
        result = mw._apply(_make_state(tool_calls=call), runtime)
        assert result is None

    def test_non_ai_message_ignored(self):
        mw = LoopDetectionMiddleware()
        runtime = _make_runtime()
        state = {"messages": [SystemMessage(content="hello")]}
        result = mw._apply(state, runtime)
        assert result is None

    def test_empty_messages_ignored(self):
        mw = LoopDetectionMiddleware()
        runtime = _make_runtime()
        result = mw._apply({"messages": []}, runtime)
        assert result is None

    def test_thread_id_from_runtime_context(self):
        """Thread ID should come from runtime.context, not state."""
        mw = LoopDetectionMiddleware(warn_threshold=2)
        runtime_a = _make_runtime("thread-A")
        runtime_b = _make_runtime("thread-B")
        call = [_bash_call("ls")]

        # One call on thread A
        mw._apply(_make_state(tool_calls=call), runtime_a)
        # One call on thread B
        mw._apply(_make_state(tool_calls=call), runtime_b)

        # Second call on thread A — triggers warning (2 >= warn_threshold)
        result = mw._apply(_make_state(tool_calls=call), runtime_a)
        assert result is not None
        assert "LOOP DETECTED" in result["messages"][0].content

        # Second call on thread B — also triggers (independent tracking)
        result = mw._apply(_make_state(tool_calls=call), runtime_b)
        assert result is not None
        assert "LOOP DETECTED" in result["messages"][0].content

    def test_lru_eviction(self):
        """Old threads should be evicted when max_tracked_threads is exceeded."""
        mw = LoopDetectionMiddleware(warn_threshold=2, max_tracked_threads=3)
        call = [_bash_call("ls")]

        # Fill up 3 threads
        for i in range(3):
            runtime = _make_runtime(f"thread-{i}")
            mw._apply(_make_state(tool_calls=call), runtime)

        # Add a 4th thread — should evict thread-0
        runtime_new = _make_runtime("thread-new")
        mw._apply(_make_state(tool_calls=call), runtime_new)

        assert "thread-0" not in mw._history
        assert "thread-new" in mw._history
        assert len(mw._history) == 3

    def test_thread_safe_mutations(self):
        """Concurrent _apply calls must not raise or corrupt tracking state.

        The previous version only asserted ``isinstance(x, type(x))``, which
        is always true and verified nothing.
        """
        import threading

        worker_count = 8
        calls_per_worker = 25
        window_size = 10
        mw = LoopDetectionMiddleware(window_size=window_size)
        assert hasattr(mw, "_lock")

        errors: list[BaseException] = []

        def worker(idx: int) -> None:
            runtime = _make_runtime(f"thread-{idx}")
            try:
                # Every call is distinct, so no warning or hard stop fires.
                for j in range(calls_per_worker):
                    result = mw._apply(
                        _make_state(tool_calls=[_bash_call(f"cmd_{j}")]),
                        runtime,
                    )
                    assert result is None
            except Exception as exc:  # collected and asserted on below
                errors.append(exc)

        threads = [threading.Thread(target=worker, args=(i,)) for i in range(worker_count)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert errors == []
        # One history per conversation thread, each trimmed to the window.
        assert len(mw._history) == worker_count
        assert all(len(history) == window_size for history in mw._history.values())

    def test_fallback_thread_id_when_missing(self):
        """When runtime context has no thread_id, should use 'default'."""
        mw = LoopDetectionMiddleware(warn_threshold=2)
        runtime = MagicMock()
        runtime.context = {}
        call = [_bash_call("ls")]

        mw._apply(_make_state(tool_calls=call), runtime)
        assert "default" in mw._history