diff --git a/backend/src/agents/lead_agent/agent.py b/backend/src/agents/lead_agent/agent.py
index 29ca905..f407733 100644
--- a/backend/src/agents/lead_agent/agent.py
+++ b/backend/src/agents/lead_agent/agent.py
@@ -6,20 +6,17 @@ from langchain_core.runnables import RunnableConfig
 
 from src.agents.lead_agent.prompt import apply_prompt_template
 from src.agents.middlewares.clarification_middleware import ClarificationMiddleware
-from src.agents.middlewares.dangling_tool_call_middleware import DanglingToolCallMiddleware
 from src.agents.middlewares.memory_middleware import MemoryMiddleware
 from src.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
-from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
 from src.agents.middlewares.title_middleware import TitleMiddleware
 from src.agents.middlewares.todo_middleware import TodoMiddleware
-from src.agents.middlewares.uploads_middleware import UploadsMiddleware
+from src.agents.middlewares.tool_error_handling_middleware import build_lead_runtime_middlewares
 from src.agents.middlewares.view_image_middleware import ViewImageMiddleware
 from src.agents.thread_state import ThreadState
 from src.config.agents_config import load_agent_config
 from src.config.app_config import get_app_config
 from src.config.summarization_config import get_summarization_config
 from src.models import create_chat_model
-from src.sandbox.middleware import SandboxMiddleware
 
 logger = logging.getLogger(__name__)
 
@@ -204,6 +201,7 @@ Being proactive with task management demonstrates thoroughness and ensures all r
 # TitleMiddleware generates title after first exchange
 # MemoryMiddleware queues conversation for memory update (after TitleMiddleware)
 # ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
+# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
 # ClarificationMiddleware should be last to intercept clarification requests after model calls
 def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_name: str | None = None):
     """Build middleware chain based on runtime configuration.
@@ -215,7 +213,7 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
     Returns:
         List of middleware instances.
     """
-    middlewares = [ThreadDataMiddleware(), UploadsMiddleware(), SandboxMiddleware(), DanglingToolCallMiddleware()]
+    middlewares = build_lead_runtime_middlewares(lazy_init=True)
 
     # Add summarization middleware if enabled
     summarization_middleware = _create_summarization_middleware()
diff --git a/backend/src/agents/middlewares/tool_error_handling_middleware.py b/backend/src/agents/middlewares/tool_error_handling_middleware.py
new file mode 100644
index 0000000..7a22fbc
--- /dev/null
+++ b/backend/src/agents/middlewares/tool_error_handling_middleware.py
@@ -0,0 +1,150 @@
+"""Tool error handling middleware and shared runtime middleware builders."""
+
+import logging
+from collections.abc import Awaitable, Callable
+from typing import override
+
+from langchain.agents import AgentState
+from langchain.agents.middleware import AgentMiddleware
+from langchain_core.messages import ToolMessage
+from langgraph.errors import GraphBubbleUp
+from langgraph.prebuilt.tool_node import ToolCallRequest
+from langgraph.types import Command
+
+logger = logging.getLogger(__name__)
+
+# Placeholder id used when a failing tool call carries no usable "id".
+_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
+
+
+class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
+    """Convert tool exceptions into error ToolMessages so the run can continue.
+
+    Both the sync and async hooks re-raise ``GraphBubbleUp`` untouched and turn
+    any other ``Exception`` raised by the wrapped handler into a ``ToolMessage``
+    with ``status="error"``.
+    """
+
+    def _build_error_message(self, request: ToolCallRequest, exc: Exception) -> ToolMessage:
+        """Build the error ToolMessage returned in place of a failed tool result.
+
+        The tool name and call id are read from ``request.tool_call``, falling
+        back to ``"unknown_tool"`` and ``_MISSING_TOOL_CALL_ID`` when missing or
+        falsy. The exception text is stripped and capped at 500 characters
+        (497 plus ``"..."``); blank text falls back to the exception class name.
+        """
+        tool_name = str(request.tool_call.get("name") or "unknown_tool")
+        tool_call_id = str(request.tool_call.get("id") or _MISSING_TOOL_CALL_ID)
+        detail = str(exc).strip() or exc.__class__.__name__
+        if len(detail) > 500:
+            detail = detail[:497] + "..."
+
+        content = (
+            f"Error: Tool '{tool_name}' failed with {exc.__class__.__name__}: {detail}. "
+            "Continue with available context, or choose an alternative tool."
+        )
+        return ToolMessage(
+            content=content,
+            tool_call_id=tool_call_id,
+            name=tool_name,
+            status="error",
+        )
+
+    @override
+    def wrap_tool_call(
+        self,
+        request: ToolCallRequest,
+        handler: Callable[[ToolCallRequest], ToolMessage | Command],
+    ) -> ToolMessage | Command:
+        """Run ``handler`` synchronously; downgrade failures to an error ToolMessage."""
+        try:
+            return handler(request)
+        except GraphBubbleUp:
+            # Preserve LangGraph control-flow signals (interrupt/pause/resume).
+            raise
+        except Exception as exc:
+            logger.exception("Tool execution failed (sync): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
+            return self._build_error_message(request, exc)
+
+    @override
+    async def awrap_tool_call(
+        self,
+        request: ToolCallRequest,
+        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
+    ) -> ToolMessage | Command:
+        """Await ``handler``; downgrade failures to an error ToolMessage."""
+        try:
+            return await handler(request)
+        except GraphBubbleUp:
+            # Preserve LangGraph control-flow signals (interrupt/pause/resume).
+            raise
+        except Exception as exc:
+            logger.exception("Tool execution failed (async): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
+            return self._build_error_message(request, exc)
+
+
+def _build_runtime_middlewares(
+    *,
+    include_uploads: bool,
+    include_dangling_tool_call_patch: bool,
+    lazy_init: bool = True,
+) -> list[AgentMiddleware]:
+    """Build shared base middlewares for agent execution.
+
+    Resulting order: ThreadDataMiddleware, UploadsMiddleware (optional),
+    SandboxMiddleware, DanglingToolCallMiddleware (optional), and finally
+    ToolErrorHandlingMiddleware.
+
+    Args:
+        include_uploads: Insert UploadsMiddleware right after ThreadDataMiddleware.
+        include_dangling_tool_call_patch: Append DanglingToolCallMiddleware.
+        lazy_init: Forwarded to ThreadDataMiddleware and SandboxMiddleware.
+
+    Returns:
+        The middleware list in the order described above.
+    """
+    from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
+    from src.sandbox.middleware import SandboxMiddleware
+
+    middlewares: list[AgentMiddleware] = [
+        ThreadDataMiddleware(lazy_init=lazy_init),
+        SandboxMiddleware(lazy_init=lazy_init),
+    ]
+
+    if include_uploads:
+        from src.agents.middlewares.uploads_middleware import UploadsMiddleware
+
+        # Index 1 places uploads between thread data and sandbox.
+        middlewares.insert(1, UploadsMiddleware())
+
+    if include_dangling_tool_call_patch:
+        from src.agents.middlewares.dangling_tool_call_middleware import DanglingToolCallMiddleware
+
+        middlewares.append(DanglingToolCallMiddleware())
+
+    middlewares.append(ToolErrorHandlingMiddleware())
+    return middlewares
+
+
+def build_lead_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
+    """Middlewares shared by lead agent runtime before lead-only middlewares.
+
+    Enables both the uploads middleware and the dangling-tool-call patch.
+    """
+    return _build_runtime_middlewares(
+        include_uploads=True,
+        include_dangling_tool_call_patch=True,
+        lazy_init=lazy_init,
+    )
+
+
+def build_subagent_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
+    """Middlewares shared by subagent runtime before subagent-only middlewares.
+
+    Omits the uploads middleware and the dangling-tool-call patch.
+    """
+    return _build_runtime_middlewares(
+        include_uploads=False,
+        include_dangling_tool_call_patch=False,
+        lazy_init=lazy_init,
+    )
diff --git a/backend/src/subagents/executor.py b/backend/src/subagents/executor.py
index 1493133..b269dab 100644
--- a/backend/src/subagents/executor.py
+++ b/backend/src/subagents/executor.py
@@ -166,15 +166,10 @@ class SubagentExecutor:
         model_name = _get_model_name(self.config, self.parent_model)
         model = create_chat_model(name=model_name, thinking_enabled=False)
 
-        # Subagents need minimal middlewares to ensure tools can access sandbox and thread_data
-        # These middlewares will reuse the sandbox/thread_data from parent agent
-        from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
-        from src.sandbox.middleware import SandboxMiddleware
+        from src.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares
 
-        middlewares = [
-            ThreadDataMiddleware(lazy_init=True),  # Compute thread paths
-            SandboxMiddleware(lazy_init=True),  # Reuse parent's sandbox (no re-acquisition)
-        ]
+        # Reuse shared middleware composition with lead agent.
+        middlewares = build_subagent_runtime_middlewares(lazy_init=True)
 
         return create_agent(
             model=model,
diff --git a/backend/tests/test_tool_error_handling_middleware.py b/backend/tests/test_tool_error_handling_middleware.py
new file mode 100644
index 0000000..60e8981
--- /dev/null
+++ b/backend/tests/test_tool_error_handling_middleware.py
@@ -0,0 +1,103 @@
+from types import SimpleNamespace
+
+import pytest
+from langchain_core.messages import ToolMessage
+from langgraph.errors import GraphInterrupt
+
+from src.agents.middlewares.tool_error_handling_middleware import ToolErrorHandlingMiddleware
+
+
+def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"):
+    """Build a minimal stand-in request exposing only a ``tool_call`` dict."""
+    tool_call = {"name": name}
+    if tool_call_id is not None:
+        tool_call["id"] = tool_call_id
+    return SimpleNamespace(tool_call=tool_call)
+
+
+def test_wrap_tool_call_passthrough_on_success():
+    """A successful handler result is returned unchanged."""
+    middleware = ToolErrorHandlingMiddleware()
+    req = _request()
+    expected = ToolMessage(content="ok", tool_call_id="tc-1", name="web_search")
+
+    result = middleware.wrap_tool_call(req, lambda _req: expected)
+
+    assert result is expected
+
+
+def test_wrap_tool_call_returns_error_tool_message_on_exception():
+    """A handler exception becomes an error ToolMessage carrying the original id and name."""
+    middleware = ToolErrorHandlingMiddleware()
+    req = _request(name="web_search", tool_call_id="tc-42")
+
+    def _boom(_req):
+        raise RuntimeError("network down")
+
+    result = middleware.wrap_tool_call(req, _boom)
+
+    assert isinstance(result, ToolMessage)
+    assert result.tool_call_id == "tc-42"
+    assert result.name == "web_search"
+    assert result.status == "error"
+    assert "Tool 'web_search' failed" in result.text
+    assert "network down" in result.text
+
+
+def test_wrap_tool_call_uses_fallback_tool_call_id_when_missing():
+    """A tool call without an id gets the fallback tool_call_id."""
+    middleware = ToolErrorHandlingMiddleware()
+    req = _request(name="mcp_tool", tool_call_id=None)
+
+    def _boom(_req):
+        raise ValueError("bad request")
+
+    result = middleware.wrap_tool_call(req, _boom)
+
+    assert isinstance(result, ToolMessage)
+    assert result.tool_call_id == "missing_tool_call_id"
+    assert result.name == "mcp_tool"
+    assert result.status == "error"
+
+
+def test_wrap_tool_call_reraises_graph_interrupt():
+    """GraphInterrupt must propagate instead of being converted."""
+    middleware = ToolErrorHandlingMiddleware()
+    req = _request(name="ask_clarification", tool_call_id="tc-int")
+
+    def _interrupt(_req):
+        raise GraphInterrupt(())
+
+    with pytest.raises(GraphInterrupt):
+        middleware.wrap_tool_call(req, _interrupt)
+
+
+@pytest.mark.anyio
+async def test_awrap_tool_call_returns_error_tool_message_on_exception():
+    """Async variant: a handler exception becomes an error ToolMessage."""
+    middleware = ToolErrorHandlingMiddleware()
+    req = _request(name="mcp_tool", tool_call_id="tc-async")
+
+    async def _boom(_req):
+        raise TimeoutError("request timed out")
+
+    result = await middleware.awrap_tool_call(req, _boom)
+
+    assert isinstance(result, ToolMessage)
+    assert result.tool_call_id == "tc-async"
+    assert result.name == "mcp_tool"
+    assert result.status == "error"
+    assert "request timed out" in result.text
+
+
+@pytest.mark.anyio
+async def test_awrap_tool_call_reraises_graph_interrupt():
+    """Async variant: GraphInterrupt must propagate."""
+    middleware = ToolErrorHandlingMiddleware()
+    req = _request(name="ask_clarification", tool_call_id="tc-int-async")
+
+    async def _interrupt(_req):
+        raise GraphInterrupt(())
+
+    with pytest.raises(GraphInterrupt):
+        await middleware.awrap_tool_call(req, _interrupt)
diff --git a/scripts/tool-error-degradation-detection.sh b/scripts/tool-error-degradation-detection.sh
new file mode 100755
index 0000000..3bc8c9a
--- /dev/null
+++ b/scripts/tool-error-degradation-detection.sh
@@ -0,0 +1,231 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# Detect whether the current branch has working tool-failure downgrade:
+# - Lead agent middleware chain includes error-handling
+# - Subagent middleware chain includes error-handling
+# - Failing tool call does not abort the whole call sequence
+# - Subsequent successful tool call result is still preserved
+
+ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+BACKEND_DIR="${ROOT_DIR}/backend"
+
+if ! command -v uv >/dev/null 2>&1; then
+  echo "[FAIL] uv is required but not found in PATH."
+  exit 1
+fi
+
+export UV_CACHE_DIR="${UV_CACHE_DIR:-/tmp/uv-cache}"
+
+echo "[INFO] Root: ${ROOT_DIR}"
+echo "[INFO] Backend: ${BACKEND_DIR}"
+echo "[INFO] UV cache: ${UV_CACHE_DIR}"
+echo "[INFO] Running tool-failure downgrade detector..."
+
+cd "${BACKEND_DIR}"
+
+uv run python -u - <<'PY'
+import asyncio
+import logging
+import ssl
+from types import SimpleNamespace
+
+from requests.exceptions import SSLError
+
+from langchain.agents.middleware import AgentMiddleware
+from langchain_core.messages import ToolMessage
+
+from src.agents.lead_agent.agent import _build_middlewares
+from src.config import get_app_config
+from src.sandbox.middleware import SandboxMiddleware
+
+from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
+
+HANDSHAKE_ERROR = "[SSL: UNEXPECTED_EOF_WHILE_READING] EOF occurred in violation of protocol (_ssl.c:1000)"
+logging.getLogger("src.agents.middlewares.tool_error_handling_middleware").setLevel(logging.CRITICAL)
+
+
+def _make_ssl_error():
+    """Create the simulated requests SSLError wrapping an ssl.SSLEOFError."""
+    return SSLError(ssl.SSLEOFError(8, HANDSHAKE_ERROR))
+
+print("[STEP 1] Prepare simulated Tavily SSL handshake failure.")
+print(f"[INFO] Handshake error payload: {HANDSHAKE_ERROR}")
+
+TOOL_CALLS = [
+    {"name": "web_search", "id": "tc-fail", "args": {"query": "latest agent news"}},
+    {"name": "web_fetch", "id": "tc-ok", "args": {"url": "https://example.com"}},
+]
+
+
+def _sync_handler(req):
+    """Innermost sync tool handler: web_search raises, any other tool succeeds."""
+    tool_name = req.tool_call.get("name", "unknown_tool")
+    if tool_name == "web_search":
+        raise _make_ssl_error()
+    return ToolMessage(
+        content=f"{tool_name} success",
+        tool_call_id=req.tool_call.get("id", "missing-id"),
+        name=tool_name,
+        status="success",
+    )
+
+
+async def _async_handler(req):
+    """Async counterpart of _sync_handler with the same fail/succeed rule."""
+    tool_name = req.tool_call.get("name", "unknown_tool")
+    if tool_name == "web_search":
+        raise _make_ssl_error()
+    return ToolMessage(
+        content=f"{tool_name} success",
+        tool_call_id=req.tool_call.get("id", "missing-id"),
+        name=tool_name,
+        status="success",
+    )
+
+
+def _collect_sync_wrappers(middlewares):
+    """Return wrap_tool_call of every middleware overriding either tool-call hook."""
+    return [
+        m.wrap_tool_call
+        for m in middlewares
+        if m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
+        or m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
+    ]
+
+
+def _collect_async_wrappers(middlewares):
+    """Return awrap_tool_call of every middleware overriding either tool-call hook."""
+    return [
+        m.awrap_tool_call
+        for m in middlewares
+        if m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
+        or m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
+    ]
+
+
+def _compose_sync(wrappers):
+    """Chain the wrappers around _sync_handler; the first wrapper ends up outermost."""
+    def execute(req):
+        return _sync_handler(req)
+
+    for wrapper in reversed(wrappers):
+        previous = execute
+
+        # Defaults pin this iteration's wrapper/previous (avoids late binding).
+        def execute(req, wrapper=wrapper, previous=previous):
+            return wrapper(req, previous)
+
+    return execute
+
+
+def _compose_async(wrappers):
+    """Chain the wrappers around _async_handler; the first wrapper ends up outermost."""
+    async def execute(req):
+        return await _async_handler(req)
+
+    for wrapper in reversed(wrappers):
+        previous = execute
+
+        # Defaults pin this iteration's wrapper/previous (avoids late binding).
+        async def execute(req, wrapper=wrapper, previous=previous):
+            return await wrapper(req, previous)
+
+    return execute
+
+
+def _validate_outputs(label, outputs):
+    """Exit with a distinct code unless outputs are [error web_search, success web_fetch]."""
+    if len(outputs) != 2:
+        print(f"[FAIL] {label}: expected 2 tool outputs, got {len(outputs)}")
+        raise SystemExit(2)
+    first, second = outputs
+    if not isinstance(first, ToolMessage) or not isinstance(second, ToolMessage):
+        print(f"[FAIL] {label}: outputs are not ToolMessage instances")
+        raise SystemExit(3)
+    if first.status != "error":
+        print(f"[FAIL] {label}: first tool should be status=error, got {first.status}")
+        raise SystemExit(4)
+    if second.status != "success":
+        print(f"[FAIL] {label}: second tool should be status=success, got {second.status}")
+        raise SystemExit(5)
+    if "Error: Tool 'web_search' failed" not in first.text:
+        print(f"[FAIL] {label}: first tool error text missing")
+        raise SystemExit(6)
+    if "web_fetch success" not in second.text:
+        print(f"[FAIL] {label}: second tool success text missing")
+        raise SystemExit(7)
+    print(f"[INFO] {label}: no crash, outputs preserved (error + success).")
+
+
+def _build_sub_middlewares():
+    """Use the shared subagent builder, or the two-middleware list if its import fails."""
+    try:
+        from src.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares
+    except Exception:
+        return [
+            ThreadDataMiddleware(lazy_init=True),
+            SandboxMiddleware(lazy_init=True),
+        ]
+    return build_subagent_runtime_middlewares()
+
+
+def _run_sync_sequence(executor):
+    """Run every TOOL_CALLS entry; return (outputs so far, exception or None)."""
+    outputs = []
+    try:
+        for call in TOOL_CALLS:
+            req = SimpleNamespace(tool_call=call)
+            outputs.append(executor(req))
+    except Exception as exc:
+        return outputs, exc
+    return outputs, None
+
+
+async def _run_async_sequence(executor):
+    """Async variant of _run_sync_sequence."""
+    outputs = []
+    try:
+        for call in TOOL_CALLS:
+            req = SimpleNamespace(tool_call=call)
+            outputs.append(await executor(req))
+    except Exception as exc:
+        return outputs, exc
+    return outputs, None
+
+
+print("[STEP 2] Load current branch middleware chains.")
+app_cfg = get_app_config()
+model_name = app_cfg.models[0].name if app_cfg.models else None
+if not model_name:
+    print("[FAIL] No model configured; cannot evaluate lead middleware chain.")
+    raise SystemExit(8)
+
+lead_middlewares = _build_middlewares({"configurable": {}}, model_name=model_name)
+sub_middlewares = _build_sub_middlewares()
+
+print("[STEP 3] Simulate two sequential tool calls and check whether conversation flow aborts.")
+any_crash = False
+for label, middlewares in [("lead", lead_middlewares), ("subagent", sub_middlewares)]:
+    sync_exec = _compose_sync(_collect_sync_wrappers(middlewares))
+    sync_outputs, sync_exc = _run_sync_sequence(sync_exec)
+    if sync_exc is not None:
+        any_crash = True
+        print(f"[INFO] {label}/sync: conversation aborted after tool error ({sync_exc.__class__.__name__}: {sync_exc}).")
+    else:
+        _validate_outputs(f"{label}/sync", sync_outputs)
+
+    async_exec = _compose_async(_collect_async_wrappers(middlewares))
+    async_outputs, async_exc = asyncio.run(_run_async_sequence(async_exec))
+    if async_exc is not None:
+        any_crash = True
+        print(f"[INFO] {label}/async: conversation aborted after tool error ({async_exc.__class__.__name__}: {async_exc}).")
+    else:
+        _validate_outputs(f"{label}/async", async_outputs)
+
+if any_crash:
+    print("[FAIL] Tool exception caused conversation flow to abort (no effective downgrade).")
+    raise SystemExit(9)
+
+print("[PASS] Tool exceptions were downgraded; conversation flow continued with remaining tool results.")
+PY