mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-19 12:24:46 +08:00
fix(subagent): support async MCP tools in subagent executor (#917)
* fix(subagent): support async MCP tools in subagent executor
SubagentExecutor.execute() was synchronous and could not handle async-only tools like MCP tools. This caused failures when trying to use MCP tools within subagents.
Changes:
- Add _aexecute() async method using agent.astream() for async execution
- Refactor execute() to use asyncio.run() wrapping _aexecute()
- This allows subagents to use async tools (MCP) within ThreadPoolExecutor
* test(subagent): add unit tests for executor async/sync paths
Add comprehensive tests covering:
- Async _aexecute() with success/error cases
- Sync execute() wrapper using asyncio.run()
- Async tool (MCP) support verification
- Thread pool execution safety
* fix(subagent): subagent-test-circular-depend
- Use session-scoped fixture with delayed import to handle circular dependencies
without affecting other test modules
---------
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
"""Subagent execution engine."""
|
"""Subagent execution engine."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
@@ -204,8 +205,8 @@ class SubagentExecutor:
|
|||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def execute(self, task: str, result_holder: SubagentResult | None = None) -> SubagentResult:
|
async def _aexecute(self, task: str, result_holder: SubagentResult | None = None) -> SubagentResult:
|
||||||
"""Execute a task synchronously.
|
"""Execute a task asynchronously.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
task: The task description for the subagent.
|
task: The task description for the subagent.
|
||||||
@@ -240,12 +241,12 @@ class SubagentExecutor:
|
|||||||
run_config["configurable"] = {"thread_id": self.thread_id}
|
run_config["configurable"] = {"thread_id": self.thread_id}
|
||||||
context["thread_id"] = self.thread_id
|
context["thread_id"] = self.thread_id
|
||||||
|
|
||||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} starting execution with max_turns={self.config.max_turns}")
|
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} starting async execution with max_turns={self.config.max_turns}")
|
||||||
|
|
||||||
# Use stream instead of invoke to get real-time updates
|
# Use stream instead of invoke to get real-time updates
|
||||||
# This allows us to collect AI messages as they are generated
|
# This allows us to collect AI messages as they are generated
|
||||||
final_state = None
|
final_state = None
|
||||||
for chunk in agent.stream(state, config=run_config, context=context, stream_mode="values"): # type: ignore[arg-type]
|
async for chunk in agent.astream(state, config=run_config, context=context, stream_mode="values"): # type: ignore[arg-type]
|
||||||
final_state = chunk
|
final_state = chunk
|
||||||
|
|
||||||
# Extract AI messages from the current state
|
# Extract AI messages from the current state
|
||||||
@@ -269,7 +270,7 @@ class SubagentExecutor:
|
|||||||
result.ai_messages.append(message_dict)
|
result.ai_messages.append(message_dict)
|
||||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(result.ai_messages)}")
|
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(result.ai_messages)}")
|
||||||
|
|
||||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} completed execution")
|
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} completed async execution")
|
||||||
|
|
||||||
if final_state is None:
|
if final_state is None:
|
||||||
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no final state")
|
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no final state")
|
||||||
@@ -315,13 +316,53 @@ class SubagentExecutor:
|
|||||||
result.completed_at = datetime.now()
|
result.completed_at = datetime.now()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"[trace={self.trace_id}] Subagent {self.config.name} execution failed")
|
logger.exception(f"[trace={self.trace_id}] Subagent {self.config.name} async execution failed")
|
||||||
result.status = SubagentStatus.FAILED
|
result.status = SubagentStatus.FAILED
|
||||||
result.error = str(e)
|
result.error = str(e)
|
||||||
result.completed_at = datetime.now()
|
result.completed_at = datetime.now()
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def execute(self, task: str, result_holder: SubagentResult | None = None) -> SubagentResult:
|
||||||
|
"""Execute a task synchronously (wrapper around async execution).
|
||||||
|
|
||||||
|
This method runs the async execution in a new event loop, allowing
|
||||||
|
asynchronous tools (like MCP tools) to be used within the thread pool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The task description for the subagent.
|
||||||
|
result_holder: Optional pre-created result object to update during execution.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SubagentResult with the execution result.
|
||||||
|
"""
|
||||||
|
# Run the async execution in a new event loop
|
||||||
|
# This is necessary because:
|
||||||
|
# 1. We may have async-only tools (like MCP tools)
|
||||||
|
# 2. We're running inside a ThreadPoolExecutor which doesn't have an event loop
|
||||||
|
#
|
||||||
|
# Note: _aexecute() catches all exceptions internally, so this outer
|
||||||
|
# try-except only handles asyncio.run() failures (e.g., if called from
|
||||||
|
# an async context where an event loop already exists). Subagent execution
|
||||||
|
# errors are handled within _aexecute() and returned as FAILED status.
|
||||||
|
try:
|
||||||
|
return asyncio.run(self._aexecute(task, result_holder))
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"[trace={self.trace_id}] Subagent {self.config.name} execution failed")
|
||||||
|
# Create a result with error if we don't have one
|
||||||
|
if result_holder is not None:
|
||||||
|
result = result_holder
|
||||||
|
else:
|
||||||
|
result = SubagentResult(
|
||||||
|
task_id=str(uuid.uuid4())[:8],
|
||||||
|
trace_id=self.trace_id,
|
||||||
|
status=SubagentStatus.FAILED,
|
||||||
|
)
|
||||||
|
result.status = SubagentStatus.FAILED
|
||||||
|
result.error = str(e)
|
||||||
|
result.completed_at = datetime.now()
|
||||||
|
return result
|
||||||
|
|
||||||
def execute_async(self, task: str, task_id: str | None = None) -> str:
|
def execute_async(self, task: str, task_id: str | None = None) -> str:
|
||||||
"""Start a task execution in the background.
|
"""Start a task execution in the background.
|
||||||
|
|
||||||
|
|||||||
627
backend/tests/test_subagent_executor.py
Normal file
627
backend/tests/test_subagent_executor.py
Normal file
@@ -0,0 +1,627 @@
|
|||||||
|
"""Tests for subagent executor async/sync execution paths.
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- SubagentExecutor.execute() synchronous execution path
|
||||||
|
- SubagentExecutor._aexecute() asynchronous execution path
|
||||||
|
- asyncio.run() properly executes async workflow within thread pool context
|
||||||
|
- Error handling in both sync and async paths
|
||||||
|
- Async tool support (MCP tools)
|
||||||
|
|
||||||
|
Note: Due to circular import issues in the main codebase, conftest.py mocks
|
||||||
|
src.subagents.executor. This test file uses delayed import via fixture to test
|
||||||
|
the real implementation in isolation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Module names that need to be mocked to break circular imports
|
||||||
|
_MOCKED_MODULE_NAMES = [
|
||||||
|
"src.agents",
|
||||||
|
"src.agents.thread_state",
|
||||||
|
"src.agents.middlewares",
|
||||||
|
"src.agents.middlewares.thread_data_middleware",
|
||||||
|
"src.sandbox",
|
||||||
|
"src.sandbox.middleware",
|
||||||
|
"src.models",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def _setup_executor_classes():
|
||||||
|
"""Set up mocked modules and import real executor classes.
|
||||||
|
|
||||||
|
This fixture runs once per session and yields the executor classes.
|
||||||
|
It handles module cleanup to avoid affecting other test files.
|
||||||
|
"""
|
||||||
|
# Save original modules
|
||||||
|
original_modules = {name: sys.modules.get(name) for name in _MOCKED_MODULE_NAMES}
|
||||||
|
original_executor = sys.modules.get("src.subagents.executor")
|
||||||
|
|
||||||
|
# Remove mocked executor if exists (from conftest.py)
|
||||||
|
if "src.subagents.executor" in sys.modules:
|
||||||
|
del sys.modules["src.subagents.executor"]
|
||||||
|
|
||||||
|
# Set up mocks
|
||||||
|
for name in _MOCKED_MODULE_NAMES:
|
||||||
|
sys.modules[name] = MagicMock()
|
||||||
|
|
||||||
|
# Import real classes inside fixture
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
|
|
||||||
|
from src.subagents.config import SubagentConfig
|
||||||
|
from src.subagents.executor import (
|
||||||
|
SubagentExecutor,
|
||||||
|
SubagentResult,
|
||||||
|
SubagentStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store classes in a dict to yield
|
||||||
|
classes = {
|
||||||
|
"AIMessage": AIMessage,
|
||||||
|
"HumanMessage": HumanMessage,
|
||||||
|
"SubagentConfig": SubagentConfig,
|
||||||
|
"SubagentExecutor": SubagentExecutor,
|
||||||
|
"SubagentResult": SubagentResult,
|
||||||
|
"SubagentStatus": SubagentStatus,
|
||||||
|
}
|
||||||
|
|
||||||
|
yield classes
|
||||||
|
|
||||||
|
# Cleanup: Restore original modules
|
||||||
|
for name in _MOCKED_MODULE_NAMES:
|
||||||
|
if original_modules[name] is not None:
|
||||||
|
sys.modules[name] = original_modules[name]
|
||||||
|
elif name in sys.modules:
|
||||||
|
del sys.modules[name]
|
||||||
|
|
||||||
|
# Restore executor module (conftest.py mock)
|
||||||
|
if original_executor is not None:
|
||||||
|
sys.modules["src.subagents.executor"] = original_executor
|
||||||
|
elif "src.subagents.executor" in sys.modules:
|
||||||
|
del sys.modules["src.subagents.executor"]
|
||||||
|
|
||||||
|
|
||||||
|
# Helper classes that wrap real classes for testing
|
||||||
|
class MockHumanMessage:
|
||||||
|
"""Mock HumanMessage for testing - wraps real class from fixture."""
|
||||||
|
|
||||||
|
def __init__(self, content, _classes=None):
|
||||||
|
self._content = content
|
||||||
|
self._classes = _classes
|
||||||
|
|
||||||
|
def _get_real(self):
|
||||||
|
return self._classes["HumanMessage"](content=self._content)
|
||||||
|
|
||||||
|
|
||||||
|
class MockAIMessage:
|
||||||
|
"""Mock AIMessage for testing - wraps real class from fixture."""
|
||||||
|
|
||||||
|
def __init__(self, content, msg_id=None, _classes=None):
|
||||||
|
self._content = content
|
||||||
|
self._msg_id = msg_id
|
||||||
|
self._classes = _classes
|
||||||
|
|
||||||
|
def _get_real(self):
|
||||||
|
msg = self._classes["AIMessage"](content=self._content)
|
||||||
|
if self._msg_id:
|
||||||
|
msg.id = self._msg_id
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
async def async_iterator(items):
|
||||||
|
"""Helper to create an async iterator from a list."""
|
||||||
|
for item in items:
|
||||||
|
yield item
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def classes(_setup_executor_classes):
|
||||||
|
"""Provide access to executor classes."""
|
||||||
|
return _setup_executor_classes
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def base_config(classes):
|
||||||
|
"""Return a basic subagent config for testing."""
|
||||||
|
return classes["SubagentConfig"](
|
||||||
|
name="test-agent",
|
||||||
|
description="Test agent",
|
||||||
|
system_prompt="You are a test agent.",
|
||||||
|
max_turns=10,
|
||||||
|
timeout_seconds=60,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_agent():
|
||||||
|
"""Return a properly configured mock agent with async stream."""
|
||||||
|
agent = MagicMock()
|
||||||
|
agent.astream = MagicMock()
|
||||||
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
# Helper to create real message objects
|
||||||
|
class _MsgHelper:
|
||||||
|
"""Helper to create real message objects from fixture classes."""
|
||||||
|
|
||||||
|
def __init__(self, classes):
|
||||||
|
self.classes = classes
|
||||||
|
|
||||||
|
def human(self, content):
|
||||||
|
return self.classes["HumanMessage"](content=content)
|
||||||
|
|
||||||
|
def ai(self, content, msg_id=None):
|
||||||
|
msg = self.classes["AIMessage"](content=content)
|
||||||
|
if msg_id:
|
||||||
|
msg.id = msg_id
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def msg(classes):
|
||||||
|
"""Provide message factory."""
|
||||||
|
return _MsgHelper(classes)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Async Execution Path Tests
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncExecutionPath:
|
||||||
|
"""Test _aexecute() async execution path."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_aexecute_success(self, classes, base_config, mock_agent, msg):
|
||||||
|
"""Test successful async execution returns completed result."""
|
||||||
|
SubagentExecutor = classes["SubagentExecutor"]
|
||||||
|
SubagentStatus = classes["SubagentStatus"]
|
||||||
|
|
||||||
|
final_message = msg.ai("Task completed successfully", "msg-1")
|
||||||
|
final_state = {
|
||||||
|
"messages": [
|
||||||
|
msg.human("Do something"),
|
||||||
|
final_message,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
|
||||||
|
|
||||||
|
executor = SubagentExecutor(
|
||||||
|
config=base_config,
|
||||||
|
tools=[],
|
||||||
|
thread_id="test-thread",
|
||||||
|
trace_id="test-trace",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||||
|
result = await executor._aexecute("Do something")
|
||||||
|
|
||||||
|
assert result.status == SubagentStatus.COMPLETED
|
||||||
|
assert result.result == "Task completed successfully"
|
||||||
|
assert result.error is None
|
||||||
|
assert result.started_at is not None
|
||||||
|
assert result.completed_at is not None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_aexecute_collects_ai_messages(self, classes, base_config, mock_agent, msg):
|
||||||
|
"""Test that AI messages are collected during streaming."""
|
||||||
|
SubagentExecutor = classes["SubagentExecutor"]
|
||||||
|
SubagentStatus = classes["SubagentStatus"]
|
||||||
|
|
||||||
|
msg1 = msg.ai("First response", "msg-1")
|
||||||
|
msg2 = msg.ai("Second response", "msg-2")
|
||||||
|
|
||||||
|
chunk1 = {"messages": [msg.human("Task"), msg1]}
|
||||||
|
chunk2 = {"messages": [msg.human("Task"), msg1, msg2]}
|
||||||
|
|
||||||
|
mock_agent.astream = lambda *args, **kwargs: async_iterator([chunk1, chunk2])
|
||||||
|
|
||||||
|
executor = SubagentExecutor(
|
||||||
|
config=base_config,
|
||||||
|
tools=[],
|
||||||
|
thread_id="test-thread",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||||
|
result = await executor._aexecute("Task")
|
||||||
|
|
||||||
|
assert result.status == SubagentStatus.COMPLETED
|
||||||
|
assert len(result.ai_messages) == 2
|
||||||
|
assert result.ai_messages[0]["id"] == "msg-1"
|
||||||
|
assert result.ai_messages[1]["id"] == "msg-2"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_aexecute_handles_duplicate_messages(self, classes, base_config, mock_agent, msg):
|
||||||
|
"""Test that duplicate AI messages are not added."""
|
||||||
|
SubagentExecutor = classes["SubagentExecutor"]
|
||||||
|
|
||||||
|
msg1 = msg.ai("Response", "msg-1")
|
||||||
|
|
||||||
|
# Same message appears in multiple chunks
|
||||||
|
chunk1 = {"messages": [msg.human("Task"), msg1]}
|
||||||
|
chunk2 = {"messages": [msg.human("Task"), msg1]}
|
||||||
|
|
||||||
|
mock_agent.astream = lambda *args, **kwargs: async_iterator([chunk1, chunk2])
|
||||||
|
|
||||||
|
executor = SubagentExecutor(
|
||||||
|
config=base_config,
|
||||||
|
tools=[],
|
||||||
|
thread_id="test-thread",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||||
|
result = await executor._aexecute("Task")
|
||||||
|
|
||||||
|
assert len(result.ai_messages) == 1
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_aexecute_handles_list_content(self, classes, base_config, mock_agent, msg):
|
||||||
|
"""Test handling of list-type content in AIMessage."""
|
||||||
|
SubagentExecutor = classes["SubagentExecutor"]
|
||||||
|
SubagentStatus = classes["SubagentStatus"]
|
||||||
|
|
||||||
|
final_message = msg.ai([{"text": "Part 1"}, {"text": "Part 2"}])
|
||||||
|
final_state = {
|
||||||
|
"messages": [
|
||||||
|
msg.human("Task"),
|
||||||
|
final_message,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
|
||||||
|
|
||||||
|
executor = SubagentExecutor(
|
||||||
|
config=base_config,
|
||||||
|
tools=[],
|
||||||
|
thread_id="test-thread",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||||
|
result = await executor._aexecute("Task")
|
||||||
|
|
||||||
|
assert result.status == SubagentStatus.COMPLETED
|
||||||
|
assert "Part 1" in result.result
|
||||||
|
assert "Part 2" in result.result
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_aexecute_handles_agent_exception(self, classes, base_config, mock_agent):
|
||||||
|
"""Test that exceptions during execution are caught and returned as FAILED."""
|
||||||
|
SubagentExecutor = classes["SubagentExecutor"]
|
||||||
|
SubagentStatus = classes["SubagentStatus"]
|
||||||
|
|
||||||
|
mock_agent.astream.side_effect = Exception("Agent error")
|
||||||
|
|
||||||
|
executor = SubagentExecutor(
|
||||||
|
config=base_config,
|
||||||
|
tools=[],
|
||||||
|
thread_id="test-thread",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||||
|
result = await executor._aexecute("Task")
|
||||||
|
|
||||||
|
assert result.status == SubagentStatus.FAILED
|
||||||
|
assert "Agent error" in result.error
|
||||||
|
assert result.completed_at is not None
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_aexecute_no_final_state(self, classes, base_config, mock_agent):
|
||||||
|
"""Test handling when no final state is returned."""
|
||||||
|
SubagentExecutor = classes["SubagentExecutor"]
|
||||||
|
SubagentStatus = classes["SubagentStatus"]
|
||||||
|
|
||||||
|
mock_agent.astream = lambda *args, **kwargs: async_iterator([])
|
||||||
|
|
||||||
|
executor = SubagentExecutor(
|
||||||
|
config=base_config,
|
||||||
|
tools=[],
|
||||||
|
thread_id="test-thread",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||||
|
result = await executor._aexecute("Task")
|
||||||
|
|
||||||
|
assert result.status == SubagentStatus.COMPLETED
|
||||||
|
assert result.result == "No response generated"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_aexecute_no_ai_message_in_state(self, classes, base_config, mock_agent, msg):
|
||||||
|
"""Test fallback when no AIMessage found in final state."""
|
||||||
|
SubagentExecutor = classes["SubagentExecutor"]
|
||||||
|
SubagentStatus = classes["SubagentStatus"]
|
||||||
|
|
||||||
|
final_state = {"messages": [msg.human("Task")]}
|
||||||
|
mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
|
||||||
|
|
||||||
|
executor = SubagentExecutor(
|
||||||
|
config=base_config,
|
||||||
|
tools=[],
|
||||||
|
thread_id="test-thread",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||||
|
result = await executor._aexecute("Task")
|
||||||
|
|
||||||
|
# Should fallback to string representation of last message
|
||||||
|
assert result.status == SubagentStatus.COMPLETED
|
||||||
|
assert "Task" in result.result
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Sync Execution Path Tests
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncExecutionPath:
|
||||||
|
"""Test execute() synchronous execution path with asyncio.run()."""
|
||||||
|
|
||||||
|
def test_execute_runs_async_in_event_loop(self, classes, base_config, mock_agent, msg):
|
||||||
|
"""Test that execute() runs _aexecute() in a new event loop via asyncio.run()."""
|
||||||
|
SubagentExecutor = classes["SubagentExecutor"]
|
||||||
|
SubagentStatus = classes["SubagentStatus"]
|
||||||
|
|
||||||
|
final_message = msg.ai("Sync result", "msg-1")
|
||||||
|
final_state = {
|
||||||
|
"messages": [
|
||||||
|
msg.human("Task"),
|
||||||
|
final_message,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
|
||||||
|
|
||||||
|
executor = SubagentExecutor(
|
||||||
|
config=base_config,
|
||||||
|
tools=[],
|
||||||
|
thread_id="test-thread",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||||
|
result = executor.execute("Task")
|
||||||
|
|
||||||
|
assert result.status == SubagentStatus.COMPLETED
|
||||||
|
assert result.result == "Sync result"
|
||||||
|
|
||||||
|
def test_execute_in_thread_pool_context(self, classes, base_config, msg):
|
||||||
|
"""Test that execute() works correctly when called from a thread pool.
|
||||||
|
|
||||||
|
This simulates the real-world usage where execute() is called from
|
||||||
|
_execution_pool in execute_async().
|
||||||
|
"""
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
SubagentExecutor = classes["SubagentExecutor"]
|
||||||
|
SubagentStatus = classes["SubagentStatus"]
|
||||||
|
|
||||||
|
final_message = msg.ai("Thread pool result", "msg-1")
|
||||||
|
final_state = {
|
||||||
|
"messages": [
|
||||||
|
msg.human("Task"),
|
||||||
|
final_message,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
def run_in_thread():
|
||||||
|
mock_agent = MagicMock()
|
||||||
|
mock_agent.astream = lambda *args, **kwargs: async_iterator([final_state])
|
||||||
|
|
||||||
|
executor = SubagentExecutor(
|
||||||
|
config=base_config,
|
||||||
|
tools=[],
|
||||||
|
thread_id="test-thread",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||||
|
return executor.execute("Task")
|
||||||
|
|
||||||
|
# Execute in thread pool (simulating _execution_pool usage)
|
||||||
|
with ThreadPoolExecutor(max_workers=1) as pool:
|
||||||
|
future = pool.submit(run_in_thread)
|
||||||
|
result = future.result(timeout=5)
|
||||||
|
|
||||||
|
assert result.status == SubagentStatus.COMPLETED
|
||||||
|
assert result.result == "Thread pool result"
|
||||||
|
|
||||||
|
def test_execute_handles_asyncio_run_failure(self, classes, base_config):
|
||||||
|
"""Test handling when asyncio.run() itself fails."""
|
||||||
|
SubagentExecutor = classes["SubagentExecutor"]
|
||||||
|
SubagentStatus = classes["SubagentStatus"]
|
||||||
|
|
||||||
|
executor = SubagentExecutor(
|
||||||
|
config=base_config,
|
||||||
|
tools=[],
|
||||||
|
thread_id="test-thread",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(executor, "_aexecute") as mock_aexecute:
|
||||||
|
mock_aexecute.side_effect = Exception("Asyncio run error")
|
||||||
|
|
||||||
|
result = executor.execute("Task")
|
||||||
|
|
||||||
|
assert result.status == SubagentStatus.FAILED
|
||||||
|
assert "Asyncio run error" in result.error
|
||||||
|
assert result.completed_at is not None
|
||||||
|
|
||||||
|
def test_execute_with_result_holder(self, classes, base_config, mock_agent, msg):
|
||||||
|
"""Test execute() updates provided result_holder in real-time."""
|
||||||
|
SubagentExecutor = classes["SubagentExecutor"]
|
||||||
|
SubagentResult = classes["SubagentResult"]
|
||||||
|
SubagentStatus = classes["SubagentStatus"]
|
||||||
|
|
||||||
|
msg1 = msg.ai("Step 1", "msg-1")
|
||||||
|
chunk1 = {"messages": [msg.human("Task"), msg1]}
|
||||||
|
|
||||||
|
mock_agent.astream = lambda *args, **kwargs: async_iterator([chunk1])
|
||||||
|
|
||||||
|
# Pre-create result holder (as done in execute_async)
|
||||||
|
result_holder = SubagentResult(
|
||||||
|
task_id="predefined-id",
|
||||||
|
trace_id="test-trace",
|
||||||
|
status=SubagentStatus.RUNNING,
|
||||||
|
started_at=datetime.now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
executor = SubagentExecutor(
|
||||||
|
config=base_config,
|
||||||
|
tools=[],
|
||||||
|
thread_id="test-thread",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||||
|
result = executor.execute("Task", result_holder=result_holder)
|
||||||
|
|
||||||
|
# Should be the same object
|
||||||
|
assert result is result_holder
|
||||||
|
assert result.task_id == "predefined-id"
|
||||||
|
assert result.status == SubagentStatus.COMPLETED
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Async Tool Support Tests (MCP Tools)
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncToolSupport:
|
||||||
|
"""Test that async-only tools (like MCP tools) work correctly."""
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_async_tool_called_in_astream(self, classes, base_config, msg):
|
||||||
|
"""Test that async tools are properly awaited in astream.
|
||||||
|
|
||||||
|
This verifies the fix for: async MCP tools not being executed properly
|
||||||
|
because they were being called synchronously.
|
||||||
|
"""
|
||||||
|
SubagentExecutor = classes["SubagentExecutor"]
|
||||||
|
SubagentStatus = classes["SubagentStatus"]
|
||||||
|
|
||||||
|
async_tool_calls = []
|
||||||
|
|
||||||
|
async def mock_async_tool(*args, **kwargs):
|
||||||
|
async_tool_calls.append("called")
|
||||||
|
await asyncio.sleep(0.01) # Simulate async work
|
||||||
|
return {"result": "async tool result"}
|
||||||
|
|
||||||
|
mock_agent = MagicMock()
|
||||||
|
|
||||||
|
# Simulate agent that calls async tools during streaming
|
||||||
|
async def mock_astream(*args, **kwargs):
|
||||||
|
await mock_async_tool()
|
||||||
|
yield {
|
||||||
|
"messages": [
|
||||||
|
msg.human("Task"),
|
||||||
|
msg.ai("Done", "msg-1"),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_agent.astream = mock_astream
|
||||||
|
|
||||||
|
executor = SubagentExecutor(
|
||||||
|
config=base_config,
|
||||||
|
tools=[],
|
||||||
|
thread_id="test-thread",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||||
|
result = await executor._aexecute("Task")
|
||||||
|
|
||||||
|
assert len(async_tool_calls) == 1
|
||||||
|
assert result.status == SubagentStatus.COMPLETED
|
||||||
|
|
||||||
|
def test_sync_execute_with_async_tools(self, classes, base_config, msg):
|
||||||
|
"""Test that sync execute() properly runs async tools via asyncio.run()."""
|
||||||
|
SubagentExecutor = classes["SubagentExecutor"]
|
||||||
|
SubagentStatus = classes["SubagentStatus"]
|
||||||
|
|
||||||
|
async_tool_calls = []
|
||||||
|
|
||||||
|
async def mock_async_tool():
|
||||||
|
async_tool_calls.append("called")
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
return {"result": "async result"}
|
||||||
|
|
||||||
|
mock_agent = MagicMock()
|
||||||
|
|
||||||
|
async def mock_astream(*args, **kwargs):
|
||||||
|
await mock_async_tool()
|
||||||
|
yield {
|
||||||
|
"messages": [
|
||||||
|
msg.human("Task"),
|
||||||
|
msg.ai("Done", "msg-1"),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_agent.astream = mock_astream
|
||||||
|
|
||||||
|
executor = SubagentExecutor(
|
||||||
|
config=base_config,
|
||||||
|
tools=[],
|
||||||
|
thread_id="test-thread",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||||
|
result = executor.execute("Task")
|
||||||
|
|
||||||
|
assert len(async_tool_calls) == 1
|
||||||
|
assert result.status == SubagentStatus.COMPLETED
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Thread Safety Tests
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestThreadSafety:
|
||||||
|
"""Test thread safety of executor operations."""
|
||||||
|
|
||||||
|
def test_multiple_executors_in_parallel(self, classes, base_config, msg):
|
||||||
|
"""Test multiple executors running in parallel via thread pool."""
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
|
||||||
|
SubagentExecutor = classes["SubagentExecutor"]
|
||||||
|
SubagentStatus = classes["SubagentStatus"]
|
||||||
|
|
||||||
|
results = []
|
||||||
|
|
||||||
|
def execute_task(task_id: int):
|
||||||
|
def make_astream(*args, **kwargs):
|
||||||
|
return async_iterator(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
msg.human(f"Task {task_id}"),
|
||||||
|
msg.ai(f"Result {task_id}", f"msg-{task_id}"),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_agent = MagicMock()
|
||||||
|
mock_agent.astream = make_astream
|
||||||
|
|
||||||
|
executor = SubagentExecutor(
|
||||||
|
config=base_config,
|
||||||
|
tools=[],
|
||||||
|
thread_id=f"thread-{task_id}",
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(executor, "_create_agent", return_value=mock_agent):
|
||||||
|
return executor.execute(f"Task {task_id}")
|
||||||
|
|
||||||
|
# Execute multiple tasks in parallel
|
||||||
|
with ThreadPoolExecutor(max_workers=3) as pool:
|
||||||
|
futures = [pool.submit(execute_task, i) for i in range(5)]
|
||||||
|
for future in as_completed(futures):
|
||||||
|
results.append(future.result())
|
||||||
|
|
||||||
|
assert len(results) == 5
|
||||||
|
for result in results:
|
||||||
|
assert result.status == SubagentStatus.COMPLETED
|
||||||
|
assert "Result" in result.result
|
||||||
Reference in New Issue
Block a user