diff --git a/backend/packages/harness/deerflow/mcp/tools.py b/backend/packages/harness/deerflow/mcp/tools.py index 02d7cb1..70c5e69 100644 --- a/backend/packages/harness/deerflow/mcp/tools.py +++ b/backend/packages/harness/deerflow/mcp/tools.py @@ -1,6 +1,11 @@ """Load MCP tools using langchain-mcp-adapters.""" +import asyncio +import atexit +import concurrent.futures import logging +from collections.abc import Callable +from typing import Any from langchain_core.tools import BaseTool @@ -10,6 +15,43 @@ from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_h logger = logging.getLogger(__name__) +# Global thread pool for sync tool invocation in async environments +_SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10, thread_name_prefix="mcp-sync-tool") + +# Register shutdown hook for the global executor +atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False)) + + +def _make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]: + """Build a synchronous wrapper for an asynchronous tool coroutine. + + Args: + coro: The tool's asynchronous coroutine. + tool_name: Name of the tool (for logging). + + Returns: + A synchronous function that correctly handles nested event loops. + """ + + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + try: + if loop is not None and loop.is_running(): + # Use global executor to avoid nested loop issues and improve performance + future = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs)) + return future.result() + else: + return asyncio.run(coro(*args, **kwargs)) + except Exception as e: + logger.error(f"Error invoking MCP tool '{tool_name}' via sync wrapper: {e}", exc_info=True) + raise + + return sync_wrapper + async def get_mcp_tools() -> list[BaseTool]: """Get all tools from enabled MCP servers. @@ -58,6 +100,11 @@ async def get_mcp_tools() -> list[BaseTool]: # Get all tools from all servers tools = await client.get_tools() logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers") + + # Patch tools to support sync invocation, as deerflow client streams synchronously + for tool in tools: + if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None: + tool.func = _make_sync_tool_wrapper(tool.coroutine, tool.name) return tools diff --git a/backend/tests/test_mcp_sync_wrapper.py b/backend/tests/test_mcp_sync_wrapper.py new file mode 100644 index 0000000..ab636b5 --- /dev/null +++ b/backend/tests/test_mcp_sync_wrapper.py @@ -0,0 +1,81 @@ +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from langchain_core.tools import StructuredTool +from pydantic import BaseModel, Field + +from deerflow.mcp.tools import _make_sync_tool_wrapper, get_mcp_tools + + +class MockArgs(BaseModel): + x: int = Field(..., description="test param") + + +def test_mcp_tool_sync_wrapper_generation(): + """Test that get_mcp_tools correctly adds a sync func to async-only tools.""" + async def mock_coro(x: int): + return f"result: {x}" + + mock_tool = StructuredTool( + name="test_tool", + description="test description", + args_schema=MockArgs, + func=None, # Sync func is missing + coroutine=mock_coro + ) + + mock_client_instance = MagicMock() + # Use AsyncMock for get_tools as it's awaited (Fix for Comment 5) + mock_client_instance.get_tools = AsyncMock(return_value=[mock_tool]) + + with patch("langchain_mcp_adapters.client.MultiServerMCPClient", return_value=mock_client_instance), \ + patch("deerflow.config.extensions_config.ExtensionsConfig.from_file"), \ + patch("deerflow.mcp.tools.build_servers_config", return_value={"test-server": {}}), \ + patch("deerflow.mcp.tools.get_initial_oauth_headers", new_callable=AsyncMock, return_value={}): + + # Run the async function manually with asyncio.run + tools = asyncio.run(get_mcp_tools()) + + assert len(tools) == 1 + patched_tool = tools[0] + + # Verify func is now populated + assert patched_tool.func is not None + + # Verify it works (sync call) + result = patched_tool.func(x=42) + assert result == "result: 42" + + +def test_mcp_tool_sync_wrapper_in_running_loop(): + """Test the actual helper function from production code (Fix for Comment 1 & 3).""" + async def mock_coro(x: int): + await asyncio.sleep(0.01) + return f"async_result: {x}" + + # Test the real helper function exported from deerflow.mcp.tools + sync_func = _make_sync_tool_wrapper(mock_coro, "test_tool") + + async def run_in_loop(): + # This call should succeed due to ThreadPoolExecutor in the real helper + return sync_func(x=100) + + # We run the async function that calls the sync func + result = asyncio.run(run_in_loop()) + assert result == "async_result: 100" + + +def test_mcp_tool_sync_wrapper_exception_logging(): + """Test the actual helper's error logging (Fix for Comment 3).""" + async def error_coro(): + raise ValueError("Tool failure") + + sync_func = _make_sync_tool_wrapper(error_coro, "error_tool") + + with patch("deerflow.mcp.tools.logger.error") as mock_log_error: + with pytest.raises(ValueError, match="Tool failure"): + sync_func() + mock_log_error.assert_called_once() + # Verify the tool name is in the log message + assert "error_tool" in mock_log_error.call_args[0][0]