mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-19 12:24:46 +08:00
fix(mcp): implement sync invocation wrapper for async MCP tools (#1287)
* fix(mcp): implement sync invocation wrapper for async MCP tools Since DeerFlowClient streams synchronously, invoking async-only MCP tools (loaded via langchain-mcp-adapters) resulted in a NotImplementedError. This commit bridges the sync/async gap by dynamically injecting a `func` wrapper into `StructuredTool` instances that only have a `coroutine`. Key changes: - Added `sync_wrapper` in `get_mcp_tools` to execute async tool calls. - Handled nested event loops by delegating to a global `ThreadPoolExecutor` when an event loop is already running, avoiding `RuntimeError`. - Added detailed error logging within the wrapper for better transparency. - Added comprehensive test coverage in `test_mcp_sync_wrapper.py` verifying tool patching, event loop behavior, and exception propagation. * refactor(mcp): extract sync wrapper to module level and fix test mocks Addressed PR review comments: - Extracted _make_sync_tool_wrapper to module level to avoid nested func definitions. - Refactored tests to use the actual production helper instead of duplicating logic. - Fixed AsyncMock patching for awaited dependencies in tests. - Added atexit hook for graceful thread pool shutdown. - Fixed PEP8 blank line formatting in tests. --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -1,6 +1,11 @@
|
|||||||
"""Load MCP tools using langchain-mcp-adapters."""
|
"""Load MCP tools using langchain-mcp-adapters."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import atexit
|
||||||
|
import concurrent.futures
|
||||||
import logging
|
import logging
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
@@ -10,6 +15,43 @@ from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_h
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Global thread pool for sync tool invocation in async environments
|
||||||
|
_SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10, thread_name_prefix="mcp-sync-tool")
|
||||||
|
|
||||||
|
# Register shutdown hook for the global executor
|
||||||
|
atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False))
|
||||||
|
|
||||||
|
|
||||||
|
def _make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]:
|
||||||
|
"""Build a synchronous wrapper for an asynchronous tool coroutine.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
coro: The tool's asynchronous coroutine.
|
||||||
|
tool_name: Name of the tool (for logging).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A synchronous function that correctly handles nested event loops.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
loop = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
if loop is not None and loop.is_running():
|
||||||
|
# Use global executor to avoid nested loop issues and improve performance
|
||||||
|
future = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs))
|
||||||
|
return future.result()
|
||||||
|
else:
|
||||||
|
return asyncio.run(coro(*args, **kwargs))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error invoking MCP tool '{tool_name}' via sync wrapper: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
return sync_wrapper
|
||||||
|
|
||||||
|
|
||||||
async def get_mcp_tools() -> list[BaseTool]:
|
async def get_mcp_tools() -> list[BaseTool]:
|
||||||
"""Get all tools from enabled MCP servers.
|
"""Get all tools from enabled MCP servers.
|
||||||
@@ -58,6 +100,11 @@ async def get_mcp_tools() -> list[BaseTool]:
|
|||||||
# Get all tools from all servers
|
# Get all tools from all servers
|
||||||
tools = await client.get_tools()
|
tools = await client.get_tools()
|
||||||
logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")
|
logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")
|
||||||
|
|
||||||
|
# Patch tools to support sync invocation, as deerflow client streams synchronously
|
||||||
|
for tool in tools:
|
||||||
|
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
|
||||||
|
tool.func = _make_sync_tool_wrapper(tool.coroutine, tool.name)
|
||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
|||||||
81
backend/tests/test_mcp_sync_wrapper.py
Normal file
81
backend/tests/test_mcp_sync_wrapper.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.tools import StructuredTool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from deerflow.mcp.tools import _make_sync_tool_wrapper, get_mcp_tools
|
||||||
|
|
||||||
|
|
||||||
|
class MockArgs(BaseModel):
|
||||||
|
x: int = Field(..., description="test param")
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_tool_sync_wrapper_generation():
|
||||||
|
"""Test that get_mcp_tools correctly adds a sync func to async-only tools."""
|
||||||
|
async def mock_coro(x: int):
|
||||||
|
return f"result: {x}"
|
||||||
|
|
||||||
|
mock_tool = StructuredTool(
|
||||||
|
name="test_tool",
|
||||||
|
description="test description",
|
||||||
|
args_schema=MockArgs,
|
||||||
|
func=None, # Sync func is missing
|
||||||
|
coroutine=mock_coro
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client_instance = MagicMock()
|
||||||
|
# Use AsyncMock for get_tools as it's awaited (Fix for Comment 5)
|
||||||
|
mock_client_instance.get_tools = AsyncMock(return_value=[mock_tool])
|
||||||
|
|
||||||
|
with patch("langchain_mcp_adapters.client.MultiServerMCPClient", return_value=mock_client_instance), \
|
||||||
|
patch("deerflow.config.extensions_config.ExtensionsConfig.from_file"), \
|
||||||
|
patch("deerflow.mcp.tools.build_servers_config", return_value={"test-server": {}}), \
|
||||||
|
patch("deerflow.mcp.tools.get_initial_oauth_headers", new_callable=AsyncMock, return_value={}):
|
||||||
|
|
||||||
|
# Run the async function manually with asyncio.run
|
||||||
|
tools = asyncio.run(get_mcp_tools())
|
||||||
|
|
||||||
|
assert len(tools) == 1
|
||||||
|
patched_tool = tools[0]
|
||||||
|
|
||||||
|
# Verify func is now populated
|
||||||
|
assert patched_tool.func is not None
|
||||||
|
|
||||||
|
# Verify it works (sync call)
|
||||||
|
result = patched_tool.func(x=42)
|
||||||
|
assert result == "result: 42"
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_tool_sync_wrapper_in_running_loop():
|
||||||
|
"""Test the actual helper function from production code (Fix for Comment 1 & 3)."""
|
||||||
|
async def mock_coro(x: int):
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
return f"async_result: {x}"
|
||||||
|
|
||||||
|
# Test the real helper function exported from deerflow.mcp.tools
|
||||||
|
sync_func = _make_sync_tool_wrapper(mock_coro, "test_tool")
|
||||||
|
|
||||||
|
async def run_in_loop():
|
||||||
|
# This call should succeed due to ThreadPoolExecutor in the real helper
|
||||||
|
return sync_func(x=100)
|
||||||
|
|
||||||
|
# We run the async function that calls the sync func
|
||||||
|
result = asyncio.run(run_in_loop())
|
||||||
|
assert result == "async_result: 100"
|
||||||
|
|
||||||
|
|
||||||
|
def test_mcp_tool_sync_wrapper_exception_logging():
|
||||||
|
"""Test the actual helper's error logging (Fix for Comment 3)."""
|
||||||
|
async def error_coro():
|
||||||
|
raise ValueError("Tool failure")
|
||||||
|
|
||||||
|
sync_func = _make_sync_tool_wrapper(error_coro, "error_tool")
|
||||||
|
|
||||||
|
with patch("deerflow.mcp.tools.logger.error") as mock_log_error:
|
||||||
|
with pytest.raises(ValueError, match="Tool failure"):
|
||||||
|
sync_func()
|
||||||
|
mock_log_error.assert_called_once()
|
||||||
|
# Verify the tool name is in the log message
|
||||||
|
assert "error_tool" in mock_log_error.call_args[0][0]
|
||||||
Reference in New Issue
Block a user