mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 14:22:13 +08:00
fix(middleware): degrade tool-call exceptions to error tool messages (#1110)
* fix(middleware): degrade tool-call exceptions to error tool messages * update script * fix(middleware): preserve LangGraph control-flow exceptions in tool error handling
This commit is contained in:
@@ -6,20 +6,17 @@ from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from src.agents.lead_agent.prompt import apply_prompt_template
|
||||
from src.agents.middlewares.clarification_middleware import ClarificationMiddleware
|
||||
from src.agents.middlewares.dangling_tool_call_middleware import DanglingToolCallMiddleware
|
||||
from src.agents.middlewares.memory_middleware import MemoryMiddleware
|
||||
from src.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware
|
||||
from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||
from src.agents.middlewares.title_middleware import TitleMiddleware
|
||||
from src.agents.middlewares.todo_middleware import TodoMiddleware
|
||||
from src.agents.middlewares.uploads_middleware import UploadsMiddleware
|
||||
from src.agents.middlewares.tool_error_handling_middleware import build_lead_runtime_middlewares
|
||||
from src.agents.middlewares.view_image_middleware import ViewImageMiddleware
|
||||
from src.agents.thread_state import ThreadState
|
||||
from src.config.agents_config import load_agent_config
|
||||
from src.config.app_config import get_app_config
|
||||
from src.config.summarization_config import get_summarization_config
|
||||
from src.models import create_chat_model
|
||||
from src.sandbox.middleware import SandboxMiddleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -204,6 +201,7 @@ Being proactive with task management demonstrates thoroughness and ensures all r
|
||||
# TitleMiddleware generates title after first exchange
|
||||
# MemoryMiddleware queues conversation for memory update (after TitleMiddleware)
|
||||
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
|
||||
# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
|
||||
# ClarificationMiddleware should be last to intercept clarification requests after model calls
|
||||
def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_name: str | None = None):
|
||||
"""Build middleware chain based on runtime configuration.
|
||||
@@ -215,7 +213,7 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
|
||||
Returns:
|
||||
List of middleware instances.
|
||||
"""
|
||||
middlewares = [ThreadDataMiddleware(), UploadsMiddleware(), SandboxMiddleware(), DanglingToolCallMiddleware()]
|
||||
middlewares = build_lead_runtime_middlewares(lazy_init=True)
|
||||
|
||||
# Add summarization middleware if enabled
|
||||
summarization_middleware = _create_summarization_middleware()
|
||||
|
||||
115
backend/src/agents/middlewares/tool_error_handling_middleware.py
Normal file
115
backend/src/agents/middlewares/tool_error_handling_middleware.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Tool error handling middleware and shared runtime middleware builders."""
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.errors import GraphBubbleUp
|
||||
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||||
from langgraph.types import Command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
|
||||
|
||||
|
||||
class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
"""Convert tool exceptions into error ToolMessages so the run can continue."""
|
||||
|
||||
def _build_error_message(self, request: ToolCallRequest, exc: Exception) -> ToolMessage:
|
||||
tool_name = str(request.tool_call.get("name") or "unknown_tool")
|
||||
tool_call_id = str(request.tool_call.get("id") or _MISSING_TOOL_CALL_ID)
|
||||
detail = str(exc).strip() or exc.__class__.__name__
|
||||
if len(detail) > 500:
|
||||
detail = detail[:497] + "..."
|
||||
|
||||
content = (
|
||||
f"Error: Tool '{tool_name}' failed with {exc.__class__.__name__}: {detail}. "
|
||||
"Continue with available context, or choose an alternative tool."
|
||||
)
|
||||
return ToolMessage(
|
||||
content=content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
status="error",
|
||||
)
|
||||
|
||||
@override
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
try:
|
||||
return handler(request)
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception("Tool execution failed (sync): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
|
||||
return self._build_error_message(request, exc)
|
||||
|
||||
@override
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||
) -> ToolMessage | Command:
|
||||
try:
|
||||
return await handler(request)
|
||||
except GraphBubbleUp:
|
||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception("Tool execution failed (async): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
|
||||
return self._build_error_message(request, exc)
|
||||
|
||||
|
||||
def _build_runtime_middlewares(
|
||||
*,
|
||||
include_uploads: bool,
|
||||
include_dangling_tool_call_patch: bool,
|
||||
lazy_init: bool = True,
|
||||
) -> list[AgentMiddleware]:
|
||||
"""Build shared base middlewares for agent execution."""
|
||||
from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||
from src.sandbox.middleware import SandboxMiddleware
|
||||
|
||||
middlewares: list[AgentMiddleware] = [
|
||||
ThreadDataMiddleware(lazy_init=lazy_init),
|
||||
SandboxMiddleware(lazy_init=lazy_init),
|
||||
]
|
||||
|
||||
if include_uploads:
|
||||
from src.agents.middlewares.uploads_middleware import UploadsMiddleware
|
||||
|
||||
middlewares.insert(1, UploadsMiddleware())
|
||||
|
||||
if include_dangling_tool_call_patch:
|
||||
from src.agents.middlewares.dangling_tool_call_middleware import DanglingToolCallMiddleware
|
||||
|
||||
middlewares.append(DanglingToolCallMiddleware())
|
||||
|
||||
middlewares.append(ToolErrorHandlingMiddleware())
|
||||
return middlewares
|
||||
|
||||
|
||||
def build_lead_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
|
||||
"""Middlewares shared by lead agent runtime before lead-only middlewares."""
|
||||
return _build_runtime_middlewares(
|
||||
include_uploads=True,
|
||||
include_dangling_tool_call_patch=True,
|
||||
lazy_init=lazy_init,
|
||||
)
|
||||
|
||||
|
||||
def build_subagent_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
|
||||
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
|
||||
return _build_runtime_middlewares(
|
||||
include_uploads=False,
|
||||
include_dangling_tool_call_patch=False,
|
||||
lazy_init=lazy_init,
|
||||
)
|
||||
@@ -166,15 +166,10 @@ class SubagentExecutor:
|
||||
model_name = _get_model_name(self.config, self.parent_model)
|
||||
model = create_chat_model(name=model_name, thinking_enabled=False)
|
||||
|
||||
# Subagents need minimal middlewares to ensure tools can access sandbox and thread_data
|
||||
# These middlewares will reuse the sandbox/thread_data from parent agent
|
||||
from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||
from src.sandbox.middleware import SandboxMiddleware
|
||||
from src.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares
|
||||
|
||||
middlewares = [
|
||||
ThreadDataMiddleware(lazy_init=True), # Compute thread paths
|
||||
SandboxMiddleware(lazy_init=True), # Reuse parent's sandbox (no re-acquisition)
|
||||
]
|
||||
# Reuse shared middleware composition with lead agent.
|
||||
middlewares = build_subagent_runtime_middlewares(lazy_init=True)
|
||||
|
||||
return create_agent(
|
||||
model=model,
|
||||
|
||||
96
backend/tests/test_tool_error_handling_middleware.py
Normal file
96
backend/tests/test_tool_error_handling_middleware.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
from src.agents.middlewares.tool_error_handling_middleware import ToolErrorHandlingMiddleware
|
||||
|
||||
|
||||
def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"):
|
||||
tool_call = {"name": name}
|
||||
if tool_call_id is not None:
|
||||
tool_call["id"] = tool_call_id
|
||||
return SimpleNamespace(tool_call=tool_call)
|
||||
|
||||
|
||||
def test_wrap_tool_call_passthrough_on_success():
|
||||
middleware = ToolErrorHandlingMiddleware()
|
||||
req = _request()
|
||||
expected = ToolMessage(content="ok", tool_call_id="tc-1", name="web_search")
|
||||
|
||||
result = middleware.wrap_tool_call(req, lambda _req: expected)
|
||||
|
||||
assert result is expected
|
||||
|
||||
|
||||
def test_wrap_tool_call_returns_error_tool_message_on_exception():
|
||||
middleware = ToolErrorHandlingMiddleware()
|
||||
req = _request(name="web_search", tool_call_id="tc-42")
|
||||
|
||||
def _boom(_req):
|
||||
raise RuntimeError("network down")
|
||||
|
||||
result = middleware.wrap_tool_call(req, _boom)
|
||||
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.tool_call_id == "tc-42"
|
||||
assert result.name == "web_search"
|
||||
assert result.status == "error"
|
||||
assert "Tool 'web_search' failed" in result.text
|
||||
assert "network down" in result.text
|
||||
|
||||
|
||||
def test_wrap_tool_call_uses_fallback_tool_call_id_when_missing():
|
||||
middleware = ToolErrorHandlingMiddleware()
|
||||
req = _request(name="mcp_tool", tool_call_id=None)
|
||||
|
||||
def _boom(_req):
|
||||
raise ValueError("bad request")
|
||||
|
||||
result = middleware.wrap_tool_call(req, _boom)
|
||||
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.tool_call_id == "missing_tool_call_id"
|
||||
assert result.name == "mcp_tool"
|
||||
assert result.status == "error"
|
||||
|
||||
|
||||
def test_wrap_tool_call_reraises_graph_interrupt():
|
||||
middleware = ToolErrorHandlingMiddleware()
|
||||
req = _request(name="ask_clarification", tool_call_id="tc-int")
|
||||
|
||||
def _interrupt(_req):
|
||||
raise GraphInterrupt(())
|
||||
|
||||
with pytest.raises(GraphInterrupt):
|
||||
middleware.wrap_tool_call(req, _interrupt)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_awrap_tool_call_returns_error_tool_message_on_exception():
|
||||
middleware = ToolErrorHandlingMiddleware()
|
||||
req = _request(name="mcp_tool", tool_call_id="tc-async")
|
||||
|
||||
async def _boom(_req):
|
||||
raise TimeoutError("request timed out")
|
||||
|
||||
result = await middleware.awrap_tool_call(req, _boom)
|
||||
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.tool_call_id == "tc-async"
|
||||
assert result.name == "mcp_tool"
|
||||
assert result.status == "error"
|
||||
assert "request timed out" in result.text
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_awrap_tool_call_reraises_graph_interrupt():
|
||||
middleware = ToolErrorHandlingMiddleware()
|
||||
req = _request(name="ask_clarification", tool_call_id="tc-int-async")
|
||||
|
||||
async def _interrupt(_req):
|
||||
raise GraphInterrupt(())
|
||||
|
||||
with pytest.raises(GraphInterrupt):
|
||||
await middleware.awrap_tool_call(req, _interrupt)
|
||||
218
scripts/tool-error-degradation-detection.sh
Executable file
218
scripts/tool-error-degradation-detection.sh
Executable file
@@ -0,0 +1,218 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Detect whether the current branch has working tool-failure downgrade:
|
||||
# - Lead agent middleware chain includes error-handling
|
||||
# - Subagent middleware chain includes error-handling
|
||||
# - Failing tool call does not abort the whole call sequence
|
||||
# - Subsequent successful tool call result is still preserved
|
||||
|
||||
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||
BACKEND_DIR="${ROOT_DIR}/backend"
|
||||
|
||||
if ! command -v uv >/dev/null 2>&1; then
|
||||
echo "[FAIL] uv is required but not found in PATH."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export UV_CACHE_DIR="${UV_CACHE_DIR:-/tmp/uv-cache}"
|
||||
|
||||
echo "[INFO] Root: ${ROOT_DIR}"
|
||||
echo "[INFO] Backend: ${BACKEND_DIR}"
|
||||
echo "[INFO] UV cache: ${UV_CACHE_DIR}"
|
||||
echo "[INFO] Running tool-failure downgrade detector..."
|
||||
|
||||
cd "${BACKEND_DIR}"
|
||||
|
||||
uv run python -u - <<'PY'
|
||||
import asyncio
|
||||
import logging
|
||||
import ssl
|
||||
from types import SimpleNamespace
|
||||
|
||||
from requests.exceptions import SSLError
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
from src.agents.lead_agent.agent import _build_middlewares
|
||||
from src.config import get_app_config
|
||||
from src.sandbox.middleware import SandboxMiddleware
|
||||
|
||||
from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||
|
||||
HANDSHAKE_ERROR = "[SSL: UNEXPECTED_EOF_WHILE_READING] EOF occurred in violation of protocol (_ssl.c:1000)"
|
||||
logging.getLogger("src.agents.middlewares.tool_error_handling_middleware").setLevel(logging.CRITICAL)
|
||||
|
||||
|
||||
def _make_ssl_error():
|
||||
return SSLError(ssl.SSLEOFError(8, HANDSHAKE_ERROR))
|
||||
|
||||
print("[STEP 1] Prepare simulated Tavily SSL handshake failure.")
|
||||
print(f"[INFO] Handshake error payload: {HANDSHAKE_ERROR}")
|
||||
|
||||
TOOL_CALLS = [
|
||||
{"name": "web_search", "id": "tc-fail", "args": {"query": "latest agent news"}},
|
||||
{"name": "web_fetch", "id": "tc-ok", "args": {"url": "https://example.com"}},
|
||||
]
|
||||
|
||||
|
||||
def _sync_handler(req):
|
||||
tool_name = req.tool_call.get("name", "unknown_tool")
|
||||
if tool_name == "web_search":
|
||||
raise _make_ssl_error()
|
||||
return ToolMessage(
|
||||
content=f"{tool_name} success",
|
||||
tool_call_id=req.tool_call.get("id", "missing-id"),
|
||||
name=tool_name,
|
||||
status="success",
|
||||
)
|
||||
|
||||
|
||||
async def _async_handler(req):
|
||||
tool_name = req.tool_call.get("name", "unknown_tool")
|
||||
if tool_name == "web_search":
|
||||
raise _make_ssl_error()
|
||||
return ToolMessage(
|
||||
content=f"{tool_name} success",
|
||||
tool_call_id=req.tool_call.get("id", "missing-id"),
|
||||
name=tool_name,
|
||||
status="success",
|
||||
)
|
||||
|
||||
|
||||
def _collect_sync_wrappers(middlewares):
|
||||
return [
|
||||
m.wrap_tool_call
|
||||
for m in middlewares
|
||||
if m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
|
||||
or m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
|
||||
]
|
||||
|
||||
|
||||
def _collect_async_wrappers(middlewares):
|
||||
return [
|
||||
m.awrap_tool_call
|
||||
for m in middlewares
|
||||
if m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
|
||||
or m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
|
||||
]
|
||||
|
||||
|
||||
def _compose_sync(wrappers):
|
||||
def execute(req):
|
||||
return _sync_handler(req)
|
||||
|
||||
for wrapper in reversed(wrappers):
|
||||
previous = execute
|
||||
|
||||
def execute(req, wrapper=wrapper, previous=previous):
|
||||
return wrapper(req, previous)
|
||||
|
||||
return execute
|
||||
|
||||
|
||||
def _compose_async(wrappers):
|
||||
async def execute(req):
|
||||
return await _async_handler(req)
|
||||
|
||||
for wrapper in reversed(wrappers):
|
||||
previous = execute
|
||||
|
||||
async def execute(req, wrapper=wrapper, previous=previous):
|
||||
return await wrapper(req, previous)
|
||||
|
||||
return execute
|
||||
|
||||
|
||||
def _validate_outputs(label, outputs):
|
||||
if len(outputs) != 2:
|
||||
print(f"[FAIL] {label}: expected 2 tool outputs, got {len(outputs)}")
|
||||
raise SystemExit(2)
|
||||
first, second = outputs
|
||||
if not isinstance(first, ToolMessage) or not isinstance(second, ToolMessage):
|
||||
print(f"[FAIL] {label}: outputs are not ToolMessage instances")
|
||||
raise SystemExit(3)
|
||||
if first.status != "error":
|
||||
print(f"[FAIL] {label}: first tool should be status=error, got {first.status}")
|
||||
raise SystemExit(4)
|
||||
if second.status != "success":
|
||||
print(f"[FAIL] {label}: second tool should be status=success, got {second.status}")
|
||||
raise SystemExit(5)
|
||||
if "Error: Tool 'web_search' failed" not in first.text:
|
||||
print(f"[FAIL] {label}: first tool error text missing")
|
||||
raise SystemExit(6)
|
||||
if "web_fetch success" not in second.text:
|
||||
print(f"[FAIL] {label}: second tool success text missing")
|
||||
raise SystemExit(7)
|
||||
print(f"[INFO] {label}: no crash, outputs preserved (error + success).")
|
||||
|
||||
|
||||
def _build_sub_middlewares():
|
||||
try:
|
||||
from src.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares
|
||||
except Exception:
|
||||
return [
|
||||
ThreadDataMiddleware(lazy_init=True),
|
||||
SandboxMiddleware(lazy_init=True),
|
||||
]
|
||||
return build_subagent_runtime_middlewares()
|
||||
|
||||
|
||||
def _run_sync_sequence(executor):
|
||||
outputs = []
|
||||
try:
|
||||
for call in TOOL_CALLS:
|
||||
req = SimpleNamespace(tool_call=call)
|
||||
outputs.append(executor(req))
|
||||
except Exception as exc:
|
||||
return outputs, exc
|
||||
return outputs, None
|
||||
|
||||
|
||||
async def _run_async_sequence(executor):
|
||||
outputs = []
|
||||
try:
|
||||
for call in TOOL_CALLS:
|
||||
req = SimpleNamespace(tool_call=call)
|
||||
outputs.append(await executor(req))
|
||||
except Exception as exc:
|
||||
return outputs, exc
|
||||
return outputs, None
|
||||
|
||||
|
||||
print("[STEP 2] Load current branch middleware chains.")
|
||||
app_cfg = get_app_config()
|
||||
model_name = app_cfg.models[0].name if app_cfg.models else None
|
||||
if not model_name:
|
||||
print("[FAIL] No model configured; cannot evaluate lead middleware chain.")
|
||||
raise SystemExit(8)
|
||||
|
||||
lead_middlewares = _build_middlewares({"configurable": {}}, model_name=model_name)
|
||||
sub_middlewares = _build_sub_middlewares()
|
||||
|
||||
print("[STEP 3] Simulate two sequential tool calls and check whether conversation flow aborts.")
|
||||
any_crash = False
|
||||
for label, middlewares in [("lead", lead_middlewares), ("subagent", sub_middlewares)]:
|
||||
sync_exec = _compose_sync(_collect_sync_wrappers(middlewares))
|
||||
sync_outputs, sync_exc = _run_sync_sequence(sync_exec)
|
||||
if sync_exc is not None:
|
||||
any_crash = True
|
||||
print(f"[INFO] {label}/sync: conversation aborted after tool error ({sync_exc.__class__.__name__}: {sync_exc}).")
|
||||
else:
|
||||
_validate_outputs(f"{label}/sync", sync_outputs)
|
||||
|
||||
async_exec = _compose_async(_collect_async_wrappers(middlewares))
|
||||
async_outputs, async_exc = asyncio.run(_run_async_sequence(async_exec))
|
||||
if async_exc is not None:
|
||||
any_crash = True
|
||||
print(f"[INFO] {label}/async: conversation aborted after tool error ({async_exc.__class__.__name__}: {async_exc}).")
|
||||
else:
|
||||
_validate_outputs(f"{label}/async", async_outputs)
|
||||
|
||||
if any_crash:
|
||||
print("[FAIL] Tool exception caused conversation flow to abort (no effective downgrade).")
|
||||
raise SystemExit(9)
|
||||
|
||||
print("[PASS] Tool exceptions were downgraded; conversation flow continued with remaining tool results.")
|
||||
PY
|
||||
Reference in New Issue
Block a user