fix(mcp): implement sync invocation wrapper for async MCP tools (#1287)

* fix(mcp): implement sync invocation wrapper for async MCP tools

Since DeerFlowClient streams synchronously, invoking async-only MCP tools
(loaded via langchain-mcp-adapters) resulted in a NotImplementedError.
This commit bridges the sync/async gap by dynamically injecting a `func`
wrapper into `StructuredTool` instances that only have a `coroutine`.

Key changes:
- Added `sync_wrapper` in `get_mcp_tools` to execute async tool calls.
- Handled nested event loops by delegating to a global `ThreadPoolExecutor`
  when an event loop is already running, avoiding `RuntimeError`.
- Added detailed error logging within the wrapper for better transparency.
- Added comprehensive test coverage in `test_mcp_sync_wrapper.py` verifying
  tool patching, event loop behavior, and exception propagation.

* refactor(mcp): extract sync wrapper to module level and fix test mocks

Addressed PR review comments:
- Extracted _make_sync_tool_wrapper to module level to avoid nested func definitions.
- Refactored tests to use the actual production helper instead of duplicating logic.
- Fixed AsyncMock patching for awaited dependencies in tests.
- Added atexit hook for graceful thread pool shutdown.
- Fixed PEP8 blank line formatting in tests.

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
knukn
2026-03-24 22:38:01 +08:00
committed by GitHub
parent 6bf526748d
commit a9940c391c
2 changed files with 128 additions and 0 deletions

View File

@@ -1,6 +1,11 @@
"""Load MCP tools using langchain-mcp-adapters."""
import asyncio
import atexit
import concurrent.futures
import logging
from collections.abc import Callable
from typing import Any
from langchain_core.tools import BaseTool
@@ -10,6 +15,43 @@ from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_h
# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)

# Shared worker pool used by the sync tool wrappers when the calling thread
# already has a running event loop.
_SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(
    thread_name_prefix="mcp-sync-tool",
    max_workers=10,
)


def _shutdown_sync_tool_executor() -> None:
    """Shut the shared pool down at interpreter exit without waiting on it."""
    _SYNC_TOOL_EXECUTOR.shutdown(wait=False)


# Make sure the pool is released when the interpreter exits.
atexit.register(_shutdown_sync_tool_executor)
def _make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]:
    """Create a blocking callable that drives an async tool coroutine to completion.

    Args:
        coro: Coroutine function implementing the tool; it is called with the
            wrapper's own positional and keyword arguments.
        tool_name: Tool name, included in the error log message.

    Returns:
        A synchronous function that returns the coroutine's result. When the
        calling thread has no running event loop the coroutine is executed
        with ``asyncio.run`` directly; otherwise ``asyncio.run`` is submitted
        to the shared thread pool and the caller blocks on the future.

    Raises:
        Exception: Whatever the coroutine (or its scheduling) raises is logged
            with the tool name and then re-raised unchanged.
    """

    def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
        # get_running_loop() raises RuntimeError when this thread has no
        # active loop; treat that as "safe to call asyncio.run here".
        try:
            active_loop = asyncio.get_running_loop()
        except RuntimeError:
            active_loop = None

        try:
            if active_loop is None or not active_loop.is_running():
                return asyncio.run(coro(*args, **kwargs))

            # asyncio.run() cannot be nested inside a running loop, so hand the
            # coroutine to a pool thread and block until it finishes.
            pending = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs))
            return pending.result()
        except Exception as e:
            logger.error(f"Error invoking MCP tool '{tool_name}' via sync wrapper: {e}", exc_info=True)
            raise

    return sync_wrapper
async def get_mcp_tools() -> list[BaseTool]:
"""Get all tools from enabled MCP servers.
@@ -58,6 +100,11 @@ async def get_mcp_tools() -> list[BaseTool]:
# Get all tools from all servers
tools = await client.get_tools()
logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")
# Patch tools to support sync invocation, as deerflow client streams synchronously
for tool in tools:
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
tool.func = _make_sync_tool_wrapper(tool.coroutine, tool.name)
return tools

View File

@@ -0,0 +1,81 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from deerflow.mcp.tools import _make_sync_tool_wrapper, get_mcp_tools
class MockArgs(BaseModel):
    # Argument schema passed as ``args_schema`` to the StructuredTool built in
    # the tests below. Documented with comments rather than a docstring so the
    # model definition itself stays untouched.

    # Single integer argument; ``...`` as the default marks it as required.
    x: int = Field(..., description="test param")
def test_mcp_tool_sync_wrapper_generation():
    """get_mcp_tools should attach a sync ``func`` to tools that only have a coroutine."""

    async def mock_coro(x: int):
        return f"result: {x}"

    # Tool that exposes only the async entry point; the sync ``func`` is unset.
    async_only_tool = StructuredTool(
        name="test_tool",
        description="test description",
        args_schema=MockArgs,
        func=None,
        coroutine=mock_coro,
    )

    # get_tools() is awaited by the code under test, so it must be an AsyncMock.
    fake_client = MagicMock()
    fake_client.get_tools = AsyncMock(return_value=[async_only_tool])

    with patch("langchain_mcp_adapters.client.MultiServerMCPClient", return_value=fake_client), \
         patch("deerflow.config.extensions_config.ExtensionsConfig.from_file"), \
         patch("deerflow.mcp.tools.build_servers_config", return_value={"test-server": {}}), \
         patch("deerflow.mcp.tools.get_initial_oauth_headers", new_callable=AsyncMock, return_value={}):
        # Drive the async loader from this synchronous test.
        loaded_tools = asyncio.run(get_mcp_tools())

    assert len(loaded_tools) == 1
    patched_tool = loaded_tools[0]

    # The loader must have filled in the missing sync entry point...
    assert patched_tool.func is not None

    # ...and calling it synchronously must yield the coroutine's result.
    assert patched_tool.func(x=42) == "result: 42"
def test_mcp_tool_sync_wrapper_in_running_loop():
    """The production wrapper must work when called from inside a running event loop."""

    async def mock_coro(x: int):
        await asyncio.sleep(0.01)
        return f"async_result: {x}"

    # Exercise the real helper imported from deerflow.mcp.tools.
    wrapper = _make_sync_tool_wrapper(mock_coro, "test_tool")

    async def call_from_loop():
        # Synchronous call made while asyncio.run's loop is active on this thread.
        return wrapper(x=100)

    outcome = asyncio.run(call_from_loop())
    assert outcome == "async_result: 100"
def test_mcp_tool_sync_wrapper_exception_logging():
    """A failing coroutine must be logged with the tool name and re-raised."""

    async def error_coro():
        raise ValueError("Tool failure")

    wrapper = _make_sync_tool_wrapper(error_coro, "error_tool")

    with patch("deerflow.mcp.tools.logger.error") as error_spy, \
         pytest.raises(ValueError, match="Tool failure"):
        wrapper()

    # Exactly one error record, and its message names the failing tool.
    error_spy.assert_called_once()
    logged_message = error_spy.call_args[0][0]
    assert "error_tool" in logged_message