fix(mcp): implement sync invocation wrapper for async MCP tools (#1287)

* fix(mcp): implement sync invocation wrapper for async MCP tools

Since DeerFlowClient streams synchronously, invoking async-only MCP tools
(loaded via langchain-mcp-adapters) resulted in a NotImplementedError.
This commit bridges the sync/async gap by dynamically injecting a `func`
wrapper into `StructuredTool` instances that only have a `coroutine`.

Key changes:
- Added `sync_wrapper` in `get_mcp_tools` to execute async tool calls.
- Handled nested event loops by delegating to a global `ThreadPoolExecutor`
  when an event loop is already running, avoiding `RuntimeError`.
- Added detailed error logging within the wrapper for better transparency.
- Added comprehensive test coverage in `test_mcp_sync_wrapper.py` verifying
  tool patching, event loop behavior, and exception propagation.

* refactor(mcp): extract sync wrapper to module level and fix test mocks

Addressed PR review comments:
- Extracted _make_sync_tool_wrapper to module level to avoid nested func definitions.
- Refactored tests to use the actual production helper instead of duplicating logic.
- Fixed AsyncMock patching for awaited dependencies in tests.
- Added atexit hook for graceful thread pool shutdown.
- Fixed PEP8 blank line formatting in tests.

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
knukn
2026-03-24 22:38:01 +08:00
committed by GitHub
parent 6bf526748d
commit a9940c391c
2 changed files with 128 additions and 0 deletions

View File

@@ -1,6 +1,11 @@
"""Load MCP tools using langchain-mcp-adapters."""
import asyncio
import atexit
import concurrent.futures
import logging
from collections.abc import Callable
from typing import Any
from langchain_core.tools import BaseTool
@@ -10,6 +15,43 @@ from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_h
logger = logging.getLogger(__name__)
# Global thread pool for sync tool invocation in async environments
_SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10, thread_name_prefix="mcp-sync-tool")
# Register shutdown hook for the global executor
atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False))
def _make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]:
"""Build a synchronous wrapper for an asynchronous tool coroutine.
Args:
coro: The tool's asynchronous coroutine.
tool_name: Name of the tool (for logging).
Returns:
A synchronous function that correctly handles nested event loops.
"""
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
try:
if loop is not None and loop.is_running():
# Use global executor to avoid nested loop issues and improve performance
future = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs))
return future.result()
else:
return asyncio.run(coro(*args, **kwargs))
except Exception as e:
logger.error(f"Error invoking MCP tool '{tool_name}' via sync wrapper: {e}", exc_info=True)
raise
return sync_wrapper
async def get_mcp_tools() -> list[BaseTool]:
"""Get all tools from enabled MCP servers.
@@ -58,6 +100,11 @@ async def get_mcp_tools() -> list[BaseTool]:
# Get all tools from all servers
tools = await client.get_tools()
logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")
# Patch tools to support sync invocation, as deerflow client streams synchronously
for tool in tools:
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
tool.func = _make_sync_tool_wrapper(tool.coroutine, tool.name)
return tools