mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
105 lines
3.3 KiB
Python
105 lines
3.3 KiB
Python
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import logging
|
|
from datetime import timedelta
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
from fastapi import HTTPException
|
|
from mcp import ClientSession, StdioServerParameters
|
|
from mcp.client.stdio import stdio_client
|
|
from mcp.client.sse import sse_client
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def _get_tools_from_client_session(
    client_context_manager: Any, timeout_seconds: int = 10
) -> List:
    """
    Open a client session over the given transport and list its tools.

    Args:
        client_context_manager: An async context manager that yields a
            (read, write) pair when entered
        timeout_seconds: Read timeout, in seconds, handed to the session

    Returns:
        The ``tools`` attribute of the session's list-tools response

    Raises:
        Exception: Anything raised while entering the transport, initializing
            the session, or listing tools is propagated unchanged
    """
    async with client_context_manager as (reader, writer):
        # Build the timeout only after the transport has been entered, so
        # evaluation order matches a direct inline construction.
        read_timeout = timedelta(seconds=timeout_seconds)
        session_cm = ClientSession(
            reader, writer, read_timeout_seconds=read_timeout
        )
        async with session_cm as session:
            # The session must be initialized before any request is sent.
            await session.initialize()
            response = await session.list_tools()
            return response.tools
|
|
|
|
|
|
async def load_mcp_tools(
    server_type: str,
    command: Optional[str] = None,
    args: Optional[List[str]] = None,
    url: Optional[str] = None,
    env: Optional[Dict[str, str]] = None,
    timeout_seconds: int = 60,  # Longer default timeout for first-time executions
) -> List:
    """
    Load tools from an MCP server.

    Args:
        server_type: The type of MCP server connection (stdio or sse)
        command: The command to execute (required for stdio type)
        args: Command arguments (for stdio type); None is treated as no
            arguments
        url: The URL of the SSE server (required for sse type)
        env: Environment variables (only passed to the stdio server parameters)
        timeout_seconds: Timeout in seconds forwarded to the client session
            (default: 60 for first-time executions)

    Returns:
        List of available tools from the MCP server

    Raises:
        HTTPException: 400 if a required argument is missing or the server
            type is unsupported; 500 (chained to the original error) if
            anything else fails while loading the tools
    """
    try:
        if server_type == "stdio":
            if not command:
                raise HTTPException(
                    status_code=400, detail="Command is required for stdio type"
                )

            server_params = StdioServerParameters(
                command=command,  # Executable
                # The parameters model expects a list here, so a missing
                # argument list is passed as empty rather than as None.
                args=args or [],
                env=env,  # Optional environment variables
            )

            return await _get_tools_from_client_session(
                stdio_client(server_params), timeout_seconds
            )

        elif server_type == "sse":
            if not url:
                raise HTTPException(
                    status_code=400, detail="URL is required for sse type"
                )

            return await _get_tools_from_client_session(
                sse_client(url=url), timeout_seconds
            )

        else:
            raise HTTPException(
                status_code=400, detail=f"Unsupported server type: {server_type}"
            )

    except HTTPException:
        # Validation errors raised above already carry the right status code.
        raise
    except Exception as e:
        logger.exception("Error loading MCP tools: %s", e)
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=str(e)) from e
|