mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-02 22:02:13 +08:00
test: add unit tests for DanglingToolCallMiddleware (#1305)
* test: add unit tests for DanglingToolCallMiddleware Cover message patching logic for dangling tool calls: - No-op when all tool calls have responses - Synthetic ToolMessage insertion at correct positions - Mixed responded/dangling scenarios - wrap_model_call and awrap_model_call integration * test: fix async tests and strengthen override assertions - Use @pytest.mark.anyio + async def instead of deprecated asyncio.get_event_loop().run_until_complete() (fixes Py3.12 CI failure) - Assert that override() receives the correct patched messages kwarg in both wrap_model_call and awrap_model_call tests --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
190
backend/tests/test_dangling_tool_call_middleware.py
Normal file
190
backend/tests/test_dangling_tool_call_middleware.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""Tests for DanglingToolCallMiddleware."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
|
||||
from deerflow.agents.middlewares.dangling_tool_call_middleware import (
|
||||
DanglingToolCallMiddleware,
|
||||
)
|
||||
|
||||
|
||||
def _ai_with_tool_calls(tool_calls):
|
||||
return AIMessage(content="", tool_calls=tool_calls)
|
||||
|
||||
|
||||
def _tool_msg(tool_call_id, name="test_tool"):
|
||||
return ToolMessage(content="result", tool_call_id=tool_call_id, name=name)
|
||||
|
||||
|
||||
def _tc(name="bash", tc_id="call_1"):
|
||||
return {"name": name, "id": tc_id, "args": {}}
|
||||
|
||||
|
||||
class TestBuildPatchedMessagesNoPatch:
|
||||
def test_empty_messages(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
assert mw._build_patched_messages([]) is None
|
||||
|
||||
def test_no_ai_messages(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [HumanMessage(content="hello")]
|
||||
assert mw._build_patched_messages(msgs) is None
|
||||
|
||||
def test_ai_without_tool_calls(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [AIMessage(content="hello")]
|
||||
assert mw._build_patched_messages(msgs) is None
|
||||
|
||||
def test_all_tool_calls_responded(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [
|
||||
_ai_with_tool_calls([_tc("bash", "call_1")]),
|
||||
_tool_msg("call_1", "bash"),
|
||||
]
|
||||
assert mw._build_patched_messages(msgs) is None
|
||||
|
||||
|
||||
class TestBuildPatchedMessagesPatching:
|
||||
def test_single_dangling_call(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [_ai_with_tool_calls([_tc("bash", "call_1")])]
|
||||
patched = mw._build_patched_messages(msgs)
|
||||
assert patched is not None
|
||||
assert len(patched) == 2
|
||||
assert isinstance(patched[1], ToolMessage)
|
||||
assert patched[1].tool_call_id == "call_1"
|
||||
assert patched[1].status == "error"
|
||||
|
||||
def test_multiple_dangling_calls_same_message(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [
|
||||
_ai_with_tool_calls([_tc("bash", "call_1"), _tc("read", "call_2")]),
|
||||
]
|
||||
patched = mw._build_patched_messages(msgs)
|
||||
assert patched is not None
|
||||
# Original AI + 2 synthetic ToolMessages
|
||||
assert len(patched) == 3
|
||||
tool_msgs = [m for m in patched if isinstance(m, ToolMessage)]
|
||||
assert len(tool_msgs) == 2
|
||||
assert {tm.tool_call_id for tm in tool_msgs} == {"call_1", "call_2"}
|
||||
|
||||
def test_patch_inserted_after_offending_ai_message(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [
|
||||
HumanMessage(content="hi"),
|
||||
_ai_with_tool_calls([_tc("bash", "call_1")]),
|
||||
HumanMessage(content="still here"),
|
||||
]
|
||||
patched = mw._build_patched_messages(msgs)
|
||||
assert patched is not None
|
||||
# HumanMessage, AIMessage, synthetic ToolMessage, HumanMessage
|
||||
assert len(patched) == 4
|
||||
assert isinstance(patched[0], HumanMessage)
|
||||
assert isinstance(patched[1], AIMessage)
|
||||
assert isinstance(patched[2], ToolMessage)
|
||||
assert patched[2].tool_call_id == "call_1"
|
||||
assert isinstance(patched[3], HumanMessage)
|
||||
|
||||
def test_mixed_responded_and_dangling(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [
|
||||
_ai_with_tool_calls([_tc("bash", "call_1"), _tc("read", "call_2")]),
|
||||
_tool_msg("call_1", "bash"),
|
||||
]
|
||||
patched = mw._build_patched_messages(msgs)
|
||||
assert patched is not None
|
||||
synthetic = [m for m in patched if isinstance(m, ToolMessage) and m.status == "error"]
|
||||
assert len(synthetic) == 1
|
||||
assert synthetic[0].tool_call_id == "call_2"
|
||||
|
||||
def test_multiple_ai_messages_each_patched(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [
|
||||
_ai_with_tool_calls([_tc("bash", "call_1")]),
|
||||
HumanMessage(content="next turn"),
|
||||
_ai_with_tool_calls([_tc("read", "call_2")]),
|
||||
]
|
||||
patched = mw._build_patched_messages(msgs)
|
||||
assert patched is not None
|
||||
synthetic = [m for m in patched if isinstance(m, ToolMessage)]
|
||||
assert len(synthetic) == 2
|
||||
|
||||
def test_synthetic_message_content(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [_ai_with_tool_calls([_tc("bash", "call_1")])]
|
||||
patched = mw._build_patched_messages(msgs)
|
||||
tool_msg = patched[1]
|
||||
assert "interrupted" in tool_msg.content.lower()
|
||||
assert tool_msg.name == "bash"
|
||||
|
||||
|
||||
class TestWrapModelCall:
|
||||
def test_no_patch_passthrough(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
request = MagicMock()
|
||||
request.messages = [AIMessage(content="hello")]
|
||||
handler = MagicMock(return_value="response")
|
||||
|
||||
result = mw.wrap_model_call(request, handler)
|
||||
|
||||
handler.assert_called_once_with(request)
|
||||
assert result == "response"
|
||||
|
||||
def test_patched_request_forwarded(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
request = MagicMock()
|
||||
request.messages = [_ai_with_tool_calls([_tc("bash", "call_1")])]
|
||||
patched_request = MagicMock()
|
||||
request.override.return_value = patched_request
|
||||
handler = MagicMock(return_value="response")
|
||||
|
||||
result = mw.wrap_model_call(request, handler)
|
||||
|
||||
# Verify override was called with the patched messages
|
||||
request.override.assert_called_once()
|
||||
call_kwargs = request.override.call_args
|
||||
passed_messages = call_kwargs.kwargs["messages"]
|
||||
assert len(passed_messages) == 2
|
||||
assert isinstance(passed_messages[1], ToolMessage)
|
||||
assert passed_messages[1].tool_call_id == "call_1"
|
||||
|
||||
handler.assert_called_once_with(patched_request)
|
||||
assert result == "response"
|
||||
|
||||
|
||||
class TestAwrapModelCall:
|
||||
@pytest.mark.anyio
|
||||
async def test_async_no_patch(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
request = MagicMock()
|
||||
request.messages = [AIMessage(content="hello")]
|
||||
handler = AsyncMock(return_value="response")
|
||||
|
||||
result = await mw.awrap_model_call(request, handler)
|
||||
|
||||
handler.assert_called_once_with(request)
|
||||
assert result == "response"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_async_patched(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
request = MagicMock()
|
||||
request.messages = [_ai_with_tool_calls([_tc("bash", "call_1")])]
|
||||
patched_request = MagicMock()
|
||||
request.override.return_value = patched_request
|
||||
handler = AsyncMock(return_value="response")
|
||||
|
||||
result = await mw.awrap_model_call(request, handler)
|
||||
|
||||
# Verify override was called with the patched messages
|
||||
request.override.assert_called_once()
|
||||
call_kwargs = request.override.call_args
|
||||
passed_messages = call_kwargs.kwargs["messages"]
|
||||
assert len(passed_messages) == 2
|
||||
assert isinstance(passed_messages[1], ToolMessage)
|
||||
assert passed_messages[1].tool_call_id == "call_1"
|
||||
|
||||
handler.assert_called_once_with(patched_request)
|
||||
assert result == "response"
|
||||
Reference in New Issue
Block a user