"""Tests for DanglingToolCallMiddleware.""" from unittest.mock import AsyncMock, MagicMock import pytest from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from deerflow.agents.middlewares.dangling_tool_call_middleware import ( DanglingToolCallMiddleware, ) def _ai_with_tool_calls(tool_calls): return AIMessage(content="", tool_calls=tool_calls) def _tool_msg(tool_call_id, name="test_tool"): return ToolMessage(content="result", tool_call_id=tool_call_id, name=name) def _tc(name="bash", tc_id="call_1"): return {"name": name, "id": tc_id, "args": {}} class TestBuildPatchedMessagesNoPatch: def test_empty_messages(self): mw = DanglingToolCallMiddleware() assert mw._build_patched_messages([]) is None def test_no_ai_messages(self): mw = DanglingToolCallMiddleware() msgs = [HumanMessage(content="hello")] assert mw._build_patched_messages(msgs) is None def test_ai_without_tool_calls(self): mw = DanglingToolCallMiddleware() msgs = [AIMessage(content="hello")] assert mw._build_patched_messages(msgs) is None def test_all_tool_calls_responded(self): mw = DanglingToolCallMiddleware() msgs = [ _ai_with_tool_calls([_tc("bash", "call_1")]), _tool_msg("call_1", "bash"), ] assert mw._build_patched_messages(msgs) is None class TestBuildPatchedMessagesPatching: def test_single_dangling_call(self): mw = DanglingToolCallMiddleware() msgs = [_ai_with_tool_calls([_tc("bash", "call_1")])] patched = mw._build_patched_messages(msgs) assert patched is not None assert len(patched) == 2 assert isinstance(patched[1], ToolMessage) assert patched[1].tool_call_id == "call_1" assert patched[1].status == "error" def test_multiple_dangling_calls_same_message(self): mw = DanglingToolCallMiddleware() msgs = [ _ai_with_tool_calls([_tc("bash", "call_1"), _tc("read", "call_2")]), ] patched = mw._build_patched_messages(msgs) assert patched is not None # Original AI + 2 synthetic ToolMessages assert len(patched) == 3 tool_msgs = [m for m in patched if isinstance(m, ToolMessage)] assert len(tool_msgs) == 2 assert {tm.tool_call_id for tm in tool_msgs} == {"call_1", "call_2"} def test_patch_inserted_after_offending_ai_message(self): mw = DanglingToolCallMiddleware() msgs = [ HumanMessage(content="hi"), _ai_with_tool_calls([_tc("bash", "call_1")]), HumanMessage(content="still here"), ] patched = mw._build_patched_messages(msgs) assert patched is not None # HumanMessage, AIMessage, synthetic ToolMessage, HumanMessage assert len(patched) == 4 assert isinstance(patched[0], HumanMessage) assert isinstance(patched[1], AIMessage) assert isinstance(patched[2], ToolMessage) assert patched[2].tool_call_id == "call_1" assert isinstance(patched[3], HumanMessage) def test_mixed_responded_and_dangling(self): mw = DanglingToolCallMiddleware() msgs = [ _ai_with_tool_calls([_tc("bash", "call_1"), _tc("read", "call_2")]), _tool_msg("call_1", "bash"), ] patched = mw._build_patched_messages(msgs) assert patched is not None synthetic = [m for m in patched if isinstance(m, ToolMessage) and m.status == "error"] assert len(synthetic) == 1 assert synthetic[0].tool_call_id == "call_2" def test_multiple_ai_messages_each_patched(self): mw = DanglingToolCallMiddleware() msgs = [ _ai_with_tool_calls([_tc("bash", "call_1")]), HumanMessage(content="next turn"), _ai_with_tool_calls([_tc("read", "call_2")]), ] patched = mw._build_patched_messages(msgs) assert patched is not None synthetic = [m for m in patched if isinstance(m, ToolMessage)] assert len(synthetic) == 2 def test_synthetic_message_content(self): mw = DanglingToolCallMiddleware() msgs = [_ai_with_tool_calls([_tc("bash", "call_1")])] patched = mw._build_patched_messages(msgs) tool_msg = patched[1] assert "interrupted" in tool_msg.content.lower() assert tool_msg.name == "bash" class TestWrapModelCall: def test_no_patch_passthrough(self): mw = DanglingToolCallMiddleware() request = MagicMock() request.messages = [AIMessage(content="hello")] handler = MagicMock(return_value="response") result = mw.wrap_model_call(request, handler) handler.assert_called_once_with(request) assert result == "response" def test_patched_request_forwarded(self): mw = DanglingToolCallMiddleware() request = MagicMock() request.messages = [_ai_with_tool_calls([_tc("bash", "call_1")])] patched_request = MagicMock() request.override.return_value = patched_request handler = MagicMock(return_value="response") result = mw.wrap_model_call(request, handler) # Verify override was called with the patched messages request.override.assert_called_once() call_kwargs = request.override.call_args passed_messages = call_kwargs.kwargs["messages"] assert len(passed_messages) == 2 assert isinstance(passed_messages[1], ToolMessage) assert passed_messages[1].tool_call_id == "call_1" handler.assert_called_once_with(patched_request) assert result == "response" class TestAwrapModelCall: @pytest.mark.anyio async def test_async_no_patch(self): mw = DanglingToolCallMiddleware() request = MagicMock() request.messages = [AIMessage(content="hello")] handler = AsyncMock(return_value="response") result = await mw.awrap_model_call(request, handler) handler.assert_called_once_with(request) assert result == "response" @pytest.mark.anyio async def test_async_patched(self): mw = DanglingToolCallMiddleware() request = MagicMock() request.messages = [_ai_with_tool_calls([_tc("bash", "call_1")])] patched_request = MagicMock() request.override.return_value = patched_request handler = AsyncMock(return_value="response") result = await mw.awrap_model_call(request, handler) # Verify override was called with the patched messages request.override.assert_called_once() call_kwargs = request.override.call_args passed_messages = call_kwargs.kwargs["messages"] assert len(passed_messages) == 2 assert isinstance(passed_messages[1], ToolMessage) assert passed_messages[1].tool_call_id == "call_1" handler.assert_called_once_with(patched_request) assert result == "response"