fix(middleware): degrade tool-call exceptions to error tool messages (#1110)

* fix(middleware): degrade tool-call exceptions to error tool messages

* update script

* fix(middleware): preserve LangGraph control-flow exceptions in tool error handling
This commit is contained in:
Liu Jice
2026-03-13 09:41:59 +08:00
committed by GitHub
parent 08ea9d3038
commit 3521cc2668
5 changed files with 435 additions and 13 deletions

View File

@@ -6,20 +6,17 @@ from langchain_core.runnables import RunnableConfig
from src.agents.lead_agent.prompt import apply_prompt_template
from src.agents.middlewares.clarification_middleware import ClarificationMiddleware
from src.agents.middlewares.dangling_tool_call_middleware import DanglingToolCallMiddleware
from src.agents.middlewares.memory_middleware import MemoryMiddleware
from src.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
from src.agents.middlewares.title_middleware import TitleMiddleware
from src.agents.middlewares.todo_middleware import TodoMiddleware
from src.agents.middlewares.uploads_middleware import UploadsMiddleware
from src.agents.middlewares.tool_error_handling_middleware import build_lead_runtime_middlewares
from src.agents.middlewares.view_image_middleware import ViewImageMiddleware
from src.agents.thread_state import ThreadState
from src.config.agents_config import load_agent_config
from src.config.app_config import get_app_config
from src.config.summarization_config import get_summarization_config
from src.models import create_chat_model
from src.sandbox.middleware import SandboxMiddleware
logger = logging.getLogger(__name__)
@@ -204,6 +201,7 @@ Being proactive with task management demonstrates thoroughness and ensures all r
# TitleMiddleware generates title after first exchange
# MemoryMiddleware queues conversation for memory update (after TitleMiddleware)
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
# ClarificationMiddleware should be last to intercept clarification requests after model calls
def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_name: str | None = None):
"""Build middleware chain based on runtime configuration.
@@ -215,7 +213,7 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
Returns:
List of middleware instances.
"""
middlewares = [ThreadDataMiddleware(), UploadsMiddleware(), SandboxMiddleware(), DanglingToolCallMiddleware()]
middlewares = build_lead_runtime_middlewares(lazy_init=True)
# Add summarization middleware if enabled
summarization_middleware = _create_summarization_middleware()

View File

@@ -0,0 +1,115 @@
"""Tool error handling middleware and shared runtime middleware builders."""
import logging
from collections.abc import Awaitable, Callable
from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import ToolMessage
from langgraph.errors import GraphBubbleUp
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
logger = logging.getLogger(__name__)
_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
"""Convert tool exceptions into error ToolMessages so the run can continue."""
def _build_error_message(self, request: ToolCallRequest, exc: Exception) -> ToolMessage:
tool_name = str(request.tool_call.get("name") or "unknown_tool")
tool_call_id = str(request.tool_call.get("id") or _MISSING_TOOL_CALL_ID)
detail = str(exc).strip() or exc.__class__.__name__
if len(detail) > 500:
detail = detail[:497] + "..."
content = (
f"Error: Tool '{tool_name}' failed with {exc.__class__.__name__}: {detail}. "
"Continue with available context, or choose an alternative tool."
)
return ToolMessage(
content=content,
tool_call_id=tool_call_id,
name=tool_name,
status="error",
)
@override
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
try:
return handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception as exc:
logger.exception("Tool execution failed (sync): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
return self._build_error_message(request, exc)
@override
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
try:
return await handler(request)
except GraphBubbleUp:
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
raise
except Exception as exc:
logger.exception("Tool execution failed (async): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
return self._build_error_message(request, exc)
def _build_runtime_middlewares(
*,
include_uploads: bool,
include_dangling_tool_call_patch: bool,
lazy_init: bool = True,
) -> list[AgentMiddleware]:
"""Build shared base middlewares for agent execution."""
from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
from src.sandbox.middleware import SandboxMiddleware
middlewares: list[AgentMiddleware] = [
ThreadDataMiddleware(lazy_init=lazy_init),
SandboxMiddleware(lazy_init=lazy_init),
]
if include_uploads:
from src.agents.middlewares.uploads_middleware import UploadsMiddleware
middlewares.insert(1, UploadsMiddleware())
if include_dangling_tool_call_patch:
from src.agents.middlewares.dangling_tool_call_middleware import DanglingToolCallMiddleware
middlewares.append(DanglingToolCallMiddleware())
middlewares.append(ToolErrorHandlingMiddleware())
return middlewares
def build_lead_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
"""Middlewares shared by lead agent runtime before lead-only middlewares."""
return _build_runtime_middlewares(
include_uploads=True,
include_dangling_tool_call_patch=True,
lazy_init=lazy_init,
)
def build_subagent_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
return _build_runtime_middlewares(
include_uploads=False,
include_dangling_tool_call_patch=False,
lazy_init=lazy_init,
)

View File

@@ -166,15 +166,10 @@ class SubagentExecutor:
model_name = _get_model_name(self.config, self.parent_model)
model = create_chat_model(name=model_name, thinking_enabled=False)
# Subagents need minimal middlewares to ensure tools can access sandbox and thread_data
# These middlewares will reuse the sandbox/thread_data from parent agent
from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
from src.sandbox.middleware import SandboxMiddleware
from src.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares
middlewares = [
ThreadDataMiddleware(lazy_init=True), # Compute thread paths
SandboxMiddleware(lazy_init=True), # Reuse parent's sandbox (no re-acquisition)
]
# Reuse shared middleware composition with lead agent.
middlewares = build_subagent_runtime_middlewares(lazy_init=True)
return create_agent(
model=model,

View File

@@ -0,0 +1,96 @@
from types import SimpleNamespace
import pytest
from langchain_core.messages import ToolMessage
from langgraph.errors import GraphInterrupt
from src.agents.middlewares.tool_error_handling_middleware import ToolErrorHandlingMiddleware
def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"):
tool_call = {"name": name}
if tool_call_id is not None:
tool_call["id"] = tool_call_id
return SimpleNamespace(tool_call=tool_call)
def test_wrap_tool_call_passthrough_on_success():
middleware = ToolErrorHandlingMiddleware()
req = _request()
expected = ToolMessage(content="ok", tool_call_id="tc-1", name="web_search")
result = middleware.wrap_tool_call(req, lambda _req: expected)
assert result is expected
def test_wrap_tool_call_returns_error_tool_message_on_exception():
middleware = ToolErrorHandlingMiddleware()
req = _request(name="web_search", tool_call_id="tc-42")
def _boom(_req):
raise RuntimeError("network down")
result = middleware.wrap_tool_call(req, _boom)
assert isinstance(result, ToolMessage)
assert result.tool_call_id == "tc-42"
assert result.name == "web_search"
assert result.status == "error"
assert "Tool 'web_search' failed" in result.text
assert "network down" in result.text
def test_wrap_tool_call_uses_fallback_tool_call_id_when_missing():
middleware = ToolErrorHandlingMiddleware()
req = _request(name="mcp_tool", tool_call_id=None)
def _boom(_req):
raise ValueError("bad request")
result = middleware.wrap_tool_call(req, _boom)
assert isinstance(result, ToolMessage)
assert result.tool_call_id == "missing_tool_call_id"
assert result.name == "mcp_tool"
assert result.status == "error"
def test_wrap_tool_call_reraises_graph_interrupt():
middleware = ToolErrorHandlingMiddleware()
req = _request(name="ask_clarification", tool_call_id="tc-int")
def _interrupt(_req):
raise GraphInterrupt(())
with pytest.raises(GraphInterrupt):
middleware.wrap_tool_call(req, _interrupt)
@pytest.mark.anyio
async def test_awrap_tool_call_returns_error_tool_message_on_exception():
middleware = ToolErrorHandlingMiddleware()
req = _request(name="mcp_tool", tool_call_id="tc-async")
async def _boom(_req):
raise TimeoutError("request timed out")
result = await middleware.awrap_tool_call(req, _boom)
assert isinstance(result, ToolMessage)
assert result.tool_call_id == "tc-async"
assert result.name == "mcp_tool"
assert result.status == "error"
assert "request timed out" in result.text
@pytest.mark.anyio
async def test_awrap_tool_call_reraises_graph_interrupt():
middleware = ToolErrorHandlingMiddleware()
req = _request(name="ask_clarification", tool_call_id="tc-int-async")
async def _interrupt(_req):
raise GraphInterrupt(())
with pytest.raises(GraphInterrupt):
await middleware.awrap_tool_call(req, _interrupt)

View File

@@ -0,0 +1,218 @@
#!/usr/bin/env bash
set -euo pipefail
# Detect whether the current branch has working tool-failure downgrade:
# - Lead agent middleware chain includes error-handling
# - Subagent middleware chain includes error-handling
# - Failing tool call does not abort the whole call sequence
# - Subsequent successful tool call result is still preserved
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
BACKEND_DIR="${ROOT_DIR}/backend"
if ! command -v uv >/dev/null 2>&1; then
echo "[FAIL] uv is required but not found in PATH."
exit 1
fi
export UV_CACHE_DIR="${UV_CACHE_DIR:-/tmp/uv-cache}"
echo "[INFO] Root: ${ROOT_DIR}"
echo "[INFO] Backend: ${BACKEND_DIR}"
echo "[INFO] UV cache: ${UV_CACHE_DIR}"
echo "[INFO] Running tool-failure downgrade detector..."
cd "${BACKEND_DIR}"
uv run python -u - <<'PY'
import asyncio
import logging
import ssl
from types import SimpleNamespace
from requests.exceptions import SSLError
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import ToolMessage
from src.agents.lead_agent.agent import _build_middlewares
from src.config import get_app_config
from src.sandbox.middleware import SandboxMiddleware
from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
HANDSHAKE_ERROR = "[SSL: UNEXPECTED_EOF_WHILE_READING] EOF occurred in violation of protocol (_ssl.c:1000)"
logging.getLogger("src.agents.middlewares.tool_error_handling_middleware").setLevel(logging.CRITICAL)
def _make_ssl_error():
return SSLError(ssl.SSLEOFError(8, HANDSHAKE_ERROR))
print("[STEP 1] Prepare simulated Tavily SSL handshake failure.")
print(f"[INFO] Handshake error payload: {HANDSHAKE_ERROR}")
TOOL_CALLS = [
{"name": "web_search", "id": "tc-fail", "args": {"query": "latest agent news"}},
{"name": "web_fetch", "id": "tc-ok", "args": {"url": "https://example.com"}},
]
def _sync_handler(req):
tool_name = req.tool_call.get("name", "unknown_tool")
if tool_name == "web_search":
raise _make_ssl_error()
return ToolMessage(
content=f"{tool_name} success",
tool_call_id=req.tool_call.get("id", "missing-id"),
name=tool_name,
status="success",
)
async def _async_handler(req):
tool_name = req.tool_call.get("name", "unknown_tool")
if tool_name == "web_search":
raise _make_ssl_error()
return ToolMessage(
content=f"{tool_name} success",
tool_call_id=req.tool_call.get("id", "missing-id"),
name=tool_name,
status="success",
)
def _collect_sync_wrappers(middlewares):
return [
m.wrap_tool_call
for m in middlewares
if m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
or m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
]
def _collect_async_wrappers(middlewares):
return [
m.awrap_tool_call
for m in middlewares
if m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
or m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
]
def _compose_sync(wrappers):
def execute(req):
return _sync_handler(req)
for wrapper in reversed(wrappers):
previous = execute
def execute(req, wrapper=wrapper, previous=previous):
return wrapper(req, previous)
return execute
def _compose_async(wrappers):
async def execute(req):
return await _async_handler(req)
for wrapper in reversed(wrappers):
previous = execute
async def execute(req, wrapper=wrapper, previous=previous):
return await wrapper(req, previous)
return execute
def _validate_outputs(label, outputs):
if len(outputs) != 2:
print(f"[FAIL] {label}: expected 2 tool outputs, got {len(outputs)}")
raise SystemExit(2)
first, second = outputs
if not isinstance(first, ToolMessage) or not isinstance(second, ToolMessage):
print(f"[FAIL] {label}: outputs are not ToolMessage instances")
raise SystemExit(3)
if first.status != "error":
print(f"[FAIL] {label}: first tool should be status=error, got {first.status}")
raise SystemExit(4)
if second.status != "success":
print(f"[FAIL] {label}: second tool should be status=success, got {second.status}")
raise SystemExit(5)
if "Error: Tool 'web_search' failed" not in first.text:
print(f"[FAIL] {label}: first tool error text missing")
raise SystemExit(6)
if "web_fetch success" not in second.text:
print(f"[FAIL] {label}: second tool success text missing")
raise SystemExit(7)
print(f"[INFO] {label}: no crash, outputs preserved (error + success).")
def _build_sub_middlewares():
try:
from src.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares
except Exception:
return [
ThreadDataMiddleware(lazy_init=True),
SandboxMiddleware(lazy_init=True),
]
return build_subagent_runtime_middlewares()
def _run_sync_sequence(executor):
outputs = []
try:
for call in TOOL_CALLS:
req = SimpleNamespace(tool_call=call)
outputs.append(executor(req))
except Exception as exc:
return outputs, exc
return outputs, None
async def _run_async_sequence(executor):
outputs = []
try:
for call in TOOL_CALLS:
req = SimpleNamespace(tool_call=call)
outputs.append(await executor(req))
except Exception as exc:
return outputs, exc
return outputs, None
print("[STEP 2] Load current branch middleware chains.")
app_cfg = get_app_config()
model_name = app_cfg.models[0].name if app_cfg.models else None
if not model_name:
print("[FAIL] No model configured; cannot evaluate lead middleware chain.")
raise SystemExit(8)
lead_middlewares = _build_middlewares({"configurable": {}}, model_name=model_name)
sub_middlewares = _build_sub_middlewares()
print("[STEP 3] Simulate two sequential tool calls and check whether conversation flow aborts.")
any_crash = False
for label, middlewares in [("lead", lead_middlewares), ("subagent", sub_middlewares)]:
sync_exec = _compose_sync(_collect_sync_wrappers(middlewares))
sync_outputs, sync_exc = _run_sync_sequence(sync_exec)
if sync_exc is not None:
any_crash = True
print(f"[INFO] {label}/sync: conversation aborted after tool error ({sync_exc.__class__.__name__}: {sync_exc}).")
else:
_validate_outputs(f"{label}/sync", sync_outputs)
async_exec = _compose_async(_collect_async_wrappers(middlewares))
async_outputs, async_exc = asyncio.run(_run_async_sequence(async_exec))
if async_exc is not None:
any_crash = True
print(f"[INFO] {label}/async: conversation aborted after tool error ({async_exc.__class__.__name__}: {async_exc}).")
else:
_validate_outputs(f"{label}/async", async_outputs)
if any_crash:
print("[FAIL] Tool exception caused conversation flow to abort (no effective downgrade).")
raise SystemExit(9)
print("[PASS] Tool exceptions were downgraded; conversation flow continued with remaining tool results.")
PY