From ec46ae075d8f91fbff9b0d227f76180e3ab64e49 Mon Sep 17 00:00:00 2001 From: Andrew Barnes Date: Tue, 24 Mar 2026 22:20:16 -0400 Subject: [PATCH] test: add unit tests for SubagentLimitMiddleware (#1306) * test: add unit tests for SubagentLimitMiddleware Cover subagent limit enforcement: - _clamp_subagent_limit boundary clamping - Task call truncation when exceeding limit - Non-task tool calls preserved during truncation - after_model/aafter_model delegation * Update test_subagent_limit_middleware.py * Fix import statement for MAX_CONCURRENT_SUBAGENTS --------- Co-authored-by: Willem Jiang --- .../tests/test_subagent_limit_middleware.py | 140 ++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 backend/tests/test_subagent_limit_middleware.py diff --git a/backend/tests/test_subagent_limit_middleware.py b/backend/tests/test_subagent_limit_middleware.py new file mode 100644 index 0000000..c331c3a --- /dev/null +++ b/backend/tests/test_subagent_limit_middleware.py @@ -0,0 +1,140 @@ +"""Tests for SubagentLimitMiddleware.""" + +from unittest.mock import MagicMock + +from langchain_core.messages import AIMessage, HumanMessage + +from deerflow.agents.middlewares.subagent_limit_middleware import ( + MAX_CONCURRENT_SUBAGENTS, + MAX_SUBAGENT_LIMIT, + MIN_SUBAGENT_LIMIT, + SubagentLimitMiddleware, + _clamp_subagent_limit, +) + + +def _make_runtime(): + runtime = MagicMock() + runtime.context = {"thread_id": "test-thread"} + return runtime + + +def _task_call(task_id="call_1"): + return {"name": "task", "id": task_id, "args": {"prompt": "do something"}} + + +def _other_call(name="bash", call_id="call_other"): + return {"name": name, "id": call_id, "args": {}} + + +class TestClampSubagentLimit: + def test_below_min_clamped_to_min(self): + assert _clamp_subagent_limit(0) == MIN_SUBAGENT_LIMIT + assert _clamp_subagent_limit(1) == MIN_SUBAGENT_LIMIT + + def test_above_max_clamped_to_max(self): + assert _clamp_subagent_limit(10) == MAX_SUBAGENT_LIMIT + assert _clamp_subagent_limit(100) == MAX_SUBAGENT_LIMIT + + def test_within_range_unchanged(self): + assert _clamp_subagent_limit(2) == 2 + assert _clamp_subagent_limit(3) == 3 + assert _clamp_subagent_limit(4) == 4 + + +class TestSubagentLimitMiddlewareInit: + def test_default_max_concurrent(self): + mw = SubagentLimitMiddleware() + assert mw.max_concurrent == MAX_CONCURRENT_SUBAGENTS + + def test_custom_max_concurrent_clamped(self): + mw = SubagentLimitMiddleware(max_concurrent=1) + assert mw.max_concurrent == MIN_SUBAGENT_LIMIT + + mw = SubagentLimitMiddleware(max_concurrent=10) + assert mw.max_concurrent == MAX_SUBAGENT_LIMIT + + +class TestTruncateTaskCalls: + def test_no_messages_returns_none(self): + mw = SubagentLimitMiddleware() + assert mw._truncate_task_calls({"messages": []}) is None + + def test_missing_messages_returns_none(self): + mw = SubagentLimitMiddleware() + assert mw._truncate_task_calls({}) is None + + def test_last_message_not_ai_returns_none(self): + mw = SubagentLimitMiddleware() + state = {"messages": [HumanMessage(content="hello")]} + assert mw._truncate_task_calls(state) is None + + def test_ai_no_tool_calls_returns_none(self): + mw = SubagentLimitMiddleware() + state = {"messages": [AIMessage(content="thinking...")]} + assert mw._truncate_task_calls(state) is None + + def test_task_calls_within_limit_returns_none(self): + mw = SubagentLimitMiddleware(max_concurrent=3) + msg = AIMessage( + content="", + tool_calls=[_task_call("t1"), _task_call("t2"), _task_call("t3")], + ) + assert mw._truncate_task_calls({"messages": [msg]}) is None + + def test_task_calls_exceeding_limit_truncated(self): + mw = SubagentLimitMiddleware(max_concurrent=2) + msg = AIMessage( + content="", + tool_calls=[_task_call("t1"), _task_call("t2"), _task_call("t3"), _task_call("t4")], + ) + result = mw._truncate_task_calls({"messages": [msg]}) + assert result is not None + updated_msg = result["messages"][0] + task_calls = [tc for tc in updated_msg.tool_calls if tc["name"] == "task"] + assert len(task_calls) == 2 + assert task_calls[0]["id"] == "t1" + assert task_calls[1]["id"] == "t2" + + def test_non_task_calls_preserved(self): + mw = SubagentLimitMiddleware(max_concurrent=2) + msg = AIMessage( + content="", + tool_calls=[ + _other_call("bash", "b1"), + _task_call("t1"), + _task_call("t2"), + _task_call("t3"), + _other_call("read", "r1"), + ], + ) + result = mw._truncate_task_calls({"messages": [msg]}) + assert result is not None + updated_msg = result["messages"][0] + names = [tc["name"] for tc in updated_msg.tool_calls] + assert "bash" in names + assert "read" in names + task_calls = [tc for tc in updated_msg.tool_calls if tc["name"] == "task"] + assert len(task_calls) == 2 + + def test_only_non_task_calls_returns_none(self): + mw = SubagentLimitMiddleware() + msg = AIMessage( + content="", + tool_calls=[_other_call("bash", "b1"), _other_call("read", "r1")], + ) + assert mw._truncate_task_calls({"messages": [msg]}) is None + + +class TestAfterModel: + def test_delegates_to_truncate(self): + mw = SubagentLimitMiddleware(max_concurrent=2) + runtime = _make_runtime() + msg = AIMessage( + content="", + tool_calls=[_task_call("t1"), _task_call("t2"), _task_call("t3")], + ) + result = mw.after_model({"messages": [msg]}, runtime) + assert result is not None + task_calls = [tc for tc in result["messages"][0].tool_calls if tc["name"] == "task"] + assert len(task_calls) == 2