mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
fix(middleware): degrade tool-call exceptions to error tool messages (#1110)
* fix(middleware): degrade tool-call exceptions to error tool messages * update script * fix(middleware): preserve LangGraph control-flow exceptions in tool error handling
This commit is contained in:
96
backend/tests/test_tool_error_handling_middleware.py
Normal file
96
backend/tests/test_tool_error_handling_middleware.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
from src.agents.middlewares.tool_error_handling_middleware import ToolErrorHandlingMiddleware
|
||||
|
||||
|
||||
def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"):
|
||||
tool_call = {"name": name}
|
||||
if tool_call_id is not None:
|
||||
tool_call["id"] = tool_call_id
|
||||
return SimpleNamespace(tool_call=tool_call)
|
||||
|
||||
|
||||
def test_wrap_tool_call_passthrough_on_success():
|
||||
middleware = ToolErrorHandlingMiddleware()
|
||||
req = _request()
|
||||
expected = ToolMessage(content="ok", tool_call_id="tc-1", name="web_search")
|
||||
|
||||
result = middleware.wrap_tool_call(req, lambda _req: expected)
|
||||
|
||||
assert result is expected
|
||||
|
||||
|
||||
def test_wrap_tool_call_returns_error_tool_message_on_exception():
|
||||
middleware = ToolErrorHandlingMiddleware()
|
||||
req = _request(name="web_search", tool_call_id="tc-42")
|
||||
|
||||
def _boom(_req):
|
||||
raise RuntimeError("network down")
|
||||
|
||||
result = middleware.wrap_tool_call(req, _boom)
|
||||
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.tool_call_id == "tc-42"
|
||||
assert result.name == "web_search"
|
||||
assert result.status == "error"
|
||||
assert "Tool 'web_search' failed" in result.text
|
||||
assert "network down" in result.text
|
||||
|
||||
|
||||
def test_wrap_tool_call_uses_fallback_tool_call_id_when_missing():
|
||||
middleware = ToolErrorHandlingMiddleware()
|
||||
req = _request(name="mcp_tool", tool_call_id=None)
|
||||
|
||||
def _boom(_req):
|
||||
raise ValueError("bad request")
|
||||
|
||||
result = middleware.wrap_tool_call(req, _boom)
|
||||
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.tool_call_id == "missing_tool_call_id"
|
||||
assert result.name == "mcp_tool"
|
||||
assert result.status == "error"
|
||||
|
||||
|
||||
def test_wrap_tool_call_reraises_graph_interrupt():
|
||||
middleware = ToolErrorHandlingMiddleware()
|
||||
req = _request(name="ask_clarification", tool_call_id="tc-int")
|
||||
|
||||
def _interrupt(_req):
|
||||
raise GraphInterrupt(())
|
||||
|
||||
with pytest.raises(GraphInterrupt):
|
||||
middleware.wrap_tool_call(req, _interrupt)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_awrap_tool_call_returns_error_tool_message_on_exception():
|
||||
middleware = ToolErrorHandlingMiddleware()
|
||||
req = _request(name="mcp_tool", tool_call_id="tc-async")
|
||||
|
||||
async def _boom(_req):
|
||||
raise TimeoutError("request timed out")
|
||||
|
||||
result = await middleware.awrap_tool_call(req, _boom)
|
||||
|
||||
assert isinstance(result, ToolMessage)
|
||||
assert result.tool_call_id == "tc-async"
|
||||
assert result.name == "mcp_tool"
|
||||
assert result.status == "error"
|
||||
assert "request timed out" in result.text
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_awrap_tool_call_reraises_graph_interrupt():
|
||||
middleware = ToolErrorHandlingMiddleware()
|
||||
req = _request(name="ask_clarification", tool_call_id="tc-int-async")
|
||||
|
||||
async def _interrupt(_req):
|
||||
raise GraphInterrupt(())
|
||||
|
||||
with pytest.raises(GraphInterrupt):
|
||||
await middleware.awrap_tool_call(req, _interrupt)
|
||||
Reference in New Issue
Block a user