mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-13 02:24:44 +08:00
feat: implement tools loading api
This commit is contained in:
@@ -13,6 +13,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from langchain_core.messages import AIMessageChunk, ToolMessage
|
||||
from langgraph.types import Command
|
||||
from mcp import ClientSession
|
||||
|
||||
from src.graph.builder import build_graph_with_memory
|
||||
from src.podcast.graph.builder import build_graph as build_podcast_graph
|
||||
@@ -24,6 +25,8 @@ from src.server.chat_request import (
|
||||
GeneratePPTRequest,
|
||||
TTSRequest,
|
||||
)
|
||||
from src.server.mcp_request import MCPServerMetadataRequest, MCPServerMetadataResponse
|
||||
from src.server.mcp_utils import load_mcp_tools
|
||||
from src.tools import VolcengineTTS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -244,3 +247,34 @@ async def generate_ppt(request: GeneratePPTRequest):
|
||||
except Exception as e:
|
||||
logger.exception(f"Error occurred during ppt generation: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/api/mcp/server/metadata", response_model=MCPServerMetadataResponse)
|
||||
async def mcp_server_metadata(request: MCPServerMetadataRequest):
|
||||
"""Get information about an MCP server."""
|
||||
try:
|
||||
# Load tools from the MCP server using the utility function
|
||||
tools = await load_mcp_tools(
|
||||
server_type=request.type,
|
||||
command=request.command,
|
||||
args=request.args,
|
||||
url=request.url,
|
||||
env=request.env,
|
||||
)
|
||||
|
||||
# Create the response with tools
|
||||
response = MCPServerMetadataResponse(
|
||||
type=request.type,
|
||||
command=request.command,
|
||||
args=request.args,
|
||||
url=request.url,
|
||||
env=request.env,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
if not isinstance(e, HTTPException):
|
||||
logger.exception(f"Error in MCP server metadata endpoint: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise
|
||||
|
||||
45
src/server/mcp_request.py
Normal file
45
src/server/mcp_request.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MCPServerMetadataRequest(BaseModel):
|
||||
"""Request model for MCP server metadata."""
|
||||
|
||||
type: str = Field(
|
||||
..., description="The type of MCP server connection (stdio or sse)"
|
||||
)
|
||||
command: Optional[str] = Field(
|
||||
None, description="The command to execute (for stdio type)"
|
||||
)
|
||||
args: Optional[List[str]] = Field(
|
||||
None, description="Command arguments (for stdio type)"
|
||||
)
|
||||
url: Optional[str] = Field(
|
||||
None, description="The URL of the SSE server (for sse type)"
|
||||
)
|
||||
env: Optional[Dict[str, str]] = Field(None, description="Environment variables")
|
||||
|
||||
|
||||
class MCPServerMetadataResponse(BaseModel):
|
||||
"""Response model for MCP server metadata."""
|
||||
|
||||
type: str = Field(
|
||||
..., description="The type of MCP server connection (stdio or sse)"
|
||||
)
|
||||
command: Optional[str] = Field(
|
||||
None, description="The command to execute (for stdio type)"
|
||||
)
|
||||
args: Optional[List[str]] = Field(
|
||||
None, description="Command arguments (for stdio type)"
|
||||
)
|
||||
url: Optional[str] = Field(
|
||||
None, description="The URL of the SSE server (for sse type)"
|
||||
)
|
||||
env: Optional[Dict[str, str]] = Field(None, description="Environment variables")
|
||||
tools: List = Field(
|
||||
default_factory=list, description="Available tools from the MCP server"
|
||||
)
|
||||
95
src/server/mcp_utils.py
Normal file
95
src/server/mcp_utils.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from fastapi import HTTPException
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _get_tools_from_client_session(client_context_manager: Any) -> List:
|
||||
"""
|
||||
Helper function to get tools from a client session.
|
||||
|
||||
Args:
|
||||
client_context_manager: A context manager that returns (read, write) functions
|
||||
|
||||
Returns:
|
||||
List of available tools from the MCP server
|
||||
|
||||
Raises:
|
||||
Exception: If there's an error during the process
|
||||
"""
|
||||
async with client_context_manager as (read, write):
|
||||
async with ClientSession(
|
||||
read, write, read_timeout_seconds=timedelta(seconds=10)
|
||||
) as session:
|
||||
# Initialize the connection
|
||||
await session.initialize()
|
||||
# List available tools
|
||||
listed_tools = await session.list_tools()
|
||||
return listed_tools.tools
|
||||
|
||||
|
||||
async def load_mcp_tools(
|
||||
server_type: str,
|
||||
command: Optional[str] = None,
|
||||
args: Optional[List[str]] = None,
|
||||
url: Optional[str] = None,
|
||||
env: Optional[Dict[str, str]] = None,
|
||||
) -> List:
|
||||
"""
|
||||
Load tools from an MCP server.
|
||||
|
||||
Args:
|
||||
server_type: The type of MCP server connection (stdio or sse)
|
||||
command: The command to execute (for stdio type)
|
||||
args: Command arguments (for stdio type)
|
||||
url: The URL of the SSE server (for sse type)
|
||||
env: Environment variables
|
||||
|
||||
Returns:
|
||||
List of available tools from the MCP server
|
||||
|
||||
Raises:
|
||||
HTTPException: If there's an error loading the tools
|
||||
"""
|
||||
try:
|
||||
if server_type == "stdio":
|
||||
if not command:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Command is required for stdio type"
|
||||
)
|
||||
|
||||
server_params = StdioServerParameters(
|
||||
command=command, # Executable
|
||||
args=args, # Optional command line arguments
|
||||
env=env, # Optional environment variables
|
||||
)
|
||||
|
||||
return await _get_tools_from_client_session(stdio_client(server_params))
|
||||
|
||||
elif server_type == "sse":
|
||||
if not url:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="URL is required for sse type"
|
||||
)
|
||||
|
||||
return await _get_tools_from_client_session(sse_client(url=url))
|
||||
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Unsupported server type: {server_type}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if not isinstance(e, HTTPException):
|
||||
logger.exception(f"Error loading MCP tools: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise
|
||||
Reference in New Issue
Block a user