mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-02 22:02:13 +08:00
test: add unit tests for SubagentLimitMiddleware (#1306)
* test: add unit tests for SubagentLimitMiddleware Cover subagent limit enforcement: - _clamp_subagent_limit boundary clamping - Task call truncation when exceeding limit - Non-task tool calls preserved during truncation - after_model/aafter_model delegation * Update test_subagent_limit_middleware.py * Fix import statement for MAX_CONCURRENT_SUBAGENTS --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
140
backend/tests/test_subagent_limit_middleware.py
Normal file
140
backend/tests/test_subagent_limit_middleware.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""Tests for SubagentLimitMiddleware."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from deerflow.agents.middlewares.subagent_limit_middleware import (
|
||||
MAX_CONCURRENT_SUBAGENTS,
|
||||
MAX_SUBAGENT_LIMIT,
|
||||
MIN_SUBAGENT_LIMIT,
|
||||
SubagentLimitMiddleware,
|
||||
_clamp_subagent_limit,
|
||||
)
|
||||
|
||||
|
||||
def _make_runtime():
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": "test-thread"}
|
||||
return runtime
|
||||
|
||||
|
||||
def _task_call(task_id="call_1"):
|
||||
return {"name": "task", "id": task_id, "args": {"prompt": "do something"}}
|
||||
|
||||
|
||||
def _other_call(name="bash", call_id="call_other"):
|
||||
return {"name": name, "id": call_id, "args": {}}
|
||||
|
||||
|
||||
class TestClampSubagentLimit:
|
||||
def test_below_min_clamped_to_min(self):
|
||||
assert _clamp_subagent_limit(0) == MIN_SUBAGENT_LIMIT
|
||||
assert _clamp_subagent_limit(1) == MIN_SUBAGENT_LIMIT
|
||||
|
||||
def test_above_max_clamped_to_max(self):
|
||||
assert _clamp_subagent_limit(10) == MAX_SUBAGENT_LIMIT
|
||||
assert _clamp_subagent_limit(100) == MAX_SUBAGENT_LIMIT
|
||||
|
||||
def test_within_range_unchanged(self):
|
||||
assert _clamp_subagent_limit(2) == 2
|
||||
assert _clamp_subagent_limit(3) == 3
|
||||
assert _clamp_subagent_limit(4) == 4
|
||||
|
||||
|
||||
class TestSubagentLimitMiddlewareInit:
|
||||
def test_default_max_concurrent(self):
|
||||
mw = SubagentLimitMiddleware()
|
||||
assert mw.max_concurrent == MAX_CONCURRENT_SUBAGENTS
|
||||
|
||||
def test_custom_max_concurrent_clamped(self):
|
||||
mw = SubagentLimitMiddleware(max_concurrent=1)
|
||||
assert mw.max_concurrent == MIN_SUBAGENT_LIMIT
|
||||
|
||||
mw = SubagentLimitMiddleware(max_concurrent=10)
|
||||
assert mw.max_concurrent == MAX_SUBAGENT_LIMIT
|
||||
|
||||
|
||||
class TestTruncateTaskCalls:
|
||||
def test_no_messages_returns_none(self):
|
||||
mw = SubagentLimitMiddleware()
|
||||
assert mw._truncate_task_calls({"messages": []}) is None
|
||||
|
||||
def test_missing_messages_returns_none(self):
|
||||
mw = SubagentLimitMiddleware()
|
||||
assert mw._truncate_task_calls({}) is None
|
||||
|
||||
def test_last_message_not_ai_returns_none(self):
|
||||
mw = SubagentLimitMiddleware()
|
||||
state = {"messages": [HumanMessage(content="hello")]}
|
||||
assert mw._truncate_task_calls(state) is None
|
||||
|
||||
def test_ai_no_tool_calls_returns_none(self):
|
||||
mw = SubagentLimitMiddleware()
|
||||
state = {"messages": [AIMessage(content="thinking...")]}
|
||||
assert mw._truncate_task_calls(state) is None
|
||||
|
||||
def test_task_calls_within_limit_returns_none(self):
|
||||
mw = SubagentLimitMiddleware(max_concurrent=3)
|
||||
msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[_task_call("t1"), _task_call("t2"), _task_call("t3")],
|
||||
)
|
||||
assert mw._truncate_task_calls({"messages": [msg]}) is None
|
||||
|
||||
def test_task_calls_exceeding_limit_truncated(self):
|
||||
mw = SubagentLimitMiddleware(max_concurrent=2)
|
||||
msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[_task_call("t1"), _task_call("t2"), _task_call("t3"), _task_call("t4")],
|
||||
)
|
||||
result = mw._truncate_task_calls({"messages": [msg]})
|
||||
assert result is not None
|
||||
updated_msg = result["messages"][0]
|
||||
task_calls = [tc for tc in updated_msg.tool_calls if tc["name"] == "task"]
|
||||
assert len(task_calls) == 2
|
||||
assert task_calls[0]["id"] == "t1"
|
||||
assert task_calls[1]["id"] == "t2"
|
||||
|
||||
def test_non_task_calls_preserved(self):
|
||||
mw = SubagentLimitMiddleware(max_concurrent=2)
|
||||
msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
_other_call("bash", "b1"),
|
||||
_task_call("t1"),
|
||||
_task_call("t2"),
|
||||
_task_call("t3"),
|
||||
_other_call("read", "r1"),
|
||||
],
|
||||
)
|
||||
result = mw._truncate_task_calls({"messages": [msg]})
|
||||
assert result is not None
|
||||
updated_msg = result["messages"][0]
|
||||
names = [tc["name"] for tc in updated_msg.tool_calls]
|
||||
assert "bash" in names
|
||||
assert "read" in names
|
||||
task_calls = [tc for tc in updated_msg.tool_calls if tc["name"] == "task"]
|
||||
assert len(task_calls) == 2
|
||||
|
||||
def test_only_non_task_calls_returns_none(self):
|
||||
mw = SubagentLimitMiddleware()
|
||||
msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[_other_call("bash", "b1"), _other_call("read", "r1")],
|
||||
)
|
||||
assert mw._truncate_task_calls({"messages": [msg]}) is None
|
||||
|
||||
|
||||
class TestAfterModel:
|
||||
def test_delegates_to_truncate(self):
|
||||
mw = SubagentLimitMiddleware(max_concurrent=2)
|
||||
runtime = _make_runtime()
|
||||
msg = AIMessage(
|
||||
content="",
|
||||
tool_calls=[_task_call("t1"), _task_call("t2"), _task_call("t3")],
|
||||
)
|
||||
result = mw.after_model({"messages": [msg]}, runtime)
|
||||
assert result is not None
|
||||
task_calls = [tc for tc in result["messages"][0].tool_calls if tc["name"] == "task"]
|
||||
assert len(task_calls) == 2
|
||||
Reference in New Issue
Block a user