mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-04 06:32:13 +08:00
* test: add unit tests for DanglingToolCallMiddleware Cover message patching logic for dangling tool calls: - No-op when all tool calls have responses - Synthetic ToolMessage insertion at correct positions - Mixed responded/dangling scenarios - wrap_model_call and awrap_model_call integration * test: fix async tests and strengthen override assertions - Use @pytest.mark.anyio + async def instead of deprecated asyncio.get_event_loop().run_until_complete() (fixes Py3.12 CI failure) - Assert that override() receives the correct patched messages kwarg in both wrap_model_call and awrap_model_call tests --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
191 lines
6.9 KiB
Python
191 lines
6.9 KiB
Python
"""Tests for DanglingToolCallMiddleware."""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
|
|
|
from deerflow.agents.middlewares.dangling_tool_call_middleware import (
|
|
DanglingToolCallMiddleware,
|
|
)
|
|
|
|
|
|
def _ai_with_tool_calls(tool_calls):
    """Build an ``AIMessage`` with empty content that carries ``tool_calls``."""
    message = AIMessage(tool_calls=tool_calls, content="")
    return message
|
|
|
|
|
|
def _tool_msg(tool_call_id, name="test_tool"):
    """Build a ``ToolMessage`` with content ``"result"`` for ``tool_call_id``."""
    reply = ToolMessage(
        name=name,
        tool_call_id=tool_call_id,
        content="result",
    )
    return reply
|
|
|
|
|
|
def _tc(name="bash", tc_id="call_1"):
|
|
return {"name": name, "id": tc_id, "args": {}}
|
|
|
|
|
|
class TestBuildPatchedMessagesNoPatch:
    """Histories for which ``_build_patched_messages`` is expected to return ``None``."""

    def test_empty_messages(self):
        """An empty history needs no patch."""
        outcome = DanglingToolCallMiddleware()._build_patched_messages([])
        assert outcome is None

    def test_no_ai_messages(self):
        """A history holding only a human message needs no patch."""
        middleware = DanglingToolCallMiddleware()
        history = [HumanMessage(content="hello")]
        outcome = middleware._build_patched_messages(history)
        assert outcome is None

    def test_ai_without_tool_calls(self):
        """An AI message that carries no tool calls needs no patch."""
        middleware = DanglingToolCallMiddleware()
        history = [AIMessage(content="hello")]
        outcome = middleware._build_patched_messages(history)
        assert outcome is None

    def test_all_tool_calls_responded(self):
        """A tool call followed by its matching ToolMessage needs no patch."""
        middleware = DanglingToolCallMiddleware()
        request_msg = _ai_with_tool_calls([_tc("bash", "call_1")])
        response_msg = _tool_msg("call_1", "bash")
        outcome = middleware._build_patched_messages([request_msg, response_msg])
        assert outcome is None
|
|
|
|
|
|
class TestBuildPatchedMessagesPatching:
    """Histories with unanswered tool calls, which must yield a patched list."""

    def test_single_dangling_call(self):
        """One unanswered call gets one error-status ToolMessage appended."""
        mw = DanglingToolCallMiddleware()
        msgs = [_ai_with_tool_calls([_tc("bash", "call_1")])]
        patched = mw._build_patched_messages(msgs)
        assert patched is not None
        assert len(patched) == 2
        assert isinstance(patched[1], ToolMessage)
        assert patched[1].tool_call_id == "call_1"
        assert patched[1].status == "error"

    def test_multiple_dangling_calls_same_message(self):
        """Two unanswered calls on one AI message each get a ToolMessage."""
        mw = DanglingToolCallMiddleware()
        msgs = [
            _ai_with_tool_calls([_tc("bash", "call_1"), _tc("read", "call_2")]),
        ]
        patched = mw._build_patched_messages(msgs)
        assert patched is not None
        # Original AI + 2 synthetic ToolMessages
        assert len(patched) == 3
        tool_msgs = [m for m in patched if isinstance(m, ToolMessage)]
        assert len(tool_msgs) == 2
        assert {tm.tool_call_id for tm in tool_msgs} == {"call_1", "call_2"}

    def test_patch_inserted_after_offending_ai_message(self):
        """The synthetic ToolMessage sits directly after its AI message."""
        mw = DanglingToolCallMiddleware()
        msgs = [
            HumanMessage(content="hi"),
            _ai_with_tool_calls([_tc("bash", "call_1")]),
            HumanMessage(content="still here"),
        ]
        patched = mw._build_patched_messages(msgs)
        assert patched is not None
        # HumanMessage, AIMessage, synthetic ToolMessage, HumanMessage
        assert len(patched) == 4
        assert isinstance(patched[0], HumanMessage)
        assert isinstance(patched[1], AIMessage)
        assert isinstance(patched[2], ToolMessage)
        assert patched[2].tool_call_id == "call_1"
        assert isinstance(patched[3], HumanMessage)

    def test_mixed_responded_and_dangling(self):
        """Only the unanswered call of a partly-answered message is patched."""
        mw = DanglingToolCallMiddleware()
        msgs = [
            _ai_with_tool_calls([_tc("bash", "call_1"), _tc("read", "call_2")]),
            _tool_msg("call_1", "bash"),
        ]
        patched = mw._build_patched_messages(msgs)
        assert patched is not None
        # The real response built by _tool_msg does not set status="error",
        # so filtering on that status isolates the synthetic messages.
        synthetic = [m for m in patched if isinstance(m, ToolMessage) and m.status == "error"]
        assert len(synthetic) == 1
        assert synthetic[0].tool_call_id == "call_2"

    def test_multiple_ai_messages_each_patched(self):
        """Dangling calls on separate AI messages are each patched."""
        mw = DanglingToolCallMiddleware()
        msgs = [
            _ai_with_tool_calls([_tc("bash", "call_1")]),
            HumanMessage(content="next turn"),
            _ai_with_tool_calls([_tc("read", "call_2")]),
        ]
        patched = mw._build_patched_messages(msgs)
        assert patched is not None
        synthetic = [m for m in patched if isinstance(m, ToolMessage)]
        assert len(synthetic) == 2
        # A bare count would also pass if one call were patched twice; pin
        # that each dangling call received its own response.
        assert {tm.tool_call_id for tm in synthetic} == {"call_1", "call_2"}

    def test_synthetic_message_content(self):
        """The synthetic message mentions the interruption and names the tool."""
        mw = DanglingToolCallMiddleware()
        msgs = [_ai_with_tool_calls([_tc("bash", "call_1")])]
        patched = mw._build_patched_messages(msgs)
        # Guard first so a missing patch fails as an assertion rather than a
        # TypeError from subscripting None.
        assert patched is not None
        tool_msg = patched[1]
        assert "interrupted" in tool_msg.content.lower()
        assert tool_msg.name == "bash"
|
|
|
|
|
|
class TestWrapModelCall:
    """Synchronous ``wrap_model_call`` behaviour, exercised with mocks."""

    def test_no_patch_passthrough(self):
        """With nothing to patch, the handler receives the original request."""
        middleware = DanglingToolCallMiddleware()
        incoming = MagicMock()
        incoming.messages = [AIMessage(content="hello")]
        downstream = MagicMock(return_value="response")

        outcome = middleware.wrap_model_call(incoming, downstream)

        downstream.assert_called_once_with(incoming)
        assert outcome == "response"

    def test_patched_request_forwarded(self):
        """With a dangling call, the handler receives the overridden request."""
        middleware = DanglingToolCallMiddleware()
        replacement = MagicMock()
        incoming = MagicMock()
        incoming.messages = [_ai_with_tool_calls([_tc("bash", "call_1")])]
        incoming.override.return_value = replacement
        downstream = MagicMock(return_value="response")

        outcome = middleware.wrap_model_call(incoming, downstream)

        # override() must have been given the patched history via ``messages=``.
        incoming.override.assert_called_once()
        forwarded = incoming.override.call_args.kwargs["messages"]
        assert len(forwarded) == 2
        synthetic = forwarded[1]
        assert isinstance(synthetic, ToolMessage)
        assert synthetic.tool_call_id == "call_1"

        downstream.assert_called_once_with(replacement)
        assert outcome == "response"
|
|
|
|
|
|
class TestAwrapModelCall:
    """Asynchronous ``awrap_model_call`` behaviour, exercised with mocks."""

    @pytest.mark.anyio
    async def test_async_no_patch(self):
        """With nothing to patch, the async handler gets the original request."""
        middleware = DanglingToolCallMiddleware()
        incoming = MagicMock()
        incoming.messages = [AIMessage(content="hello")]
        downstream = AsyncMock(return_value="response")

        outcome = await middleware.awrap_model_call(incoming, downstream)

        downstream.assert_called_once_with(incoming)
        assert outcome == "response"

    @pytest.mark.anyio
    async def test_async_patched(self):
        """With a dangling call, the async handler gets the overridden request."""
        middleware = DanglingToolCallMiddleware()
        replacement = MagicMock()
        incoming = MagicMock()
        incoming.messages = [_ai_with_tool_calls([_tc("bash", "call_1")])]
        incoming.override.return_value = replacement
        downstream = AsyncMock(return_value="response")

        outcome = await middleware.awrap_model_call(incoming, downstream)

        # override() must have been given the patched history via ``messages=``.
        incoming.override.assert_called_once()
        forwarded = incoming.override.call_args.kwargs["messages"]
        assert len(forwarded) == 2
        synthetic = forwarded[1]
        assert isinstance(synthetic, ToolMessage)
        assert synthetic.tool_call_id == "call_1"

        downstream.assert_called_once_with(replacement)
        assert outcome == "response"
|