From dae036f58311b2c6fc00c91969ea007bc133611f Mon Sep 17 00:00:00 2001 From: He Tao Date: Wed, 23 Apr 2025 14:38:04 +0800 Subject: [PATCH] feat: implement tools loading api --- Makefile | 2 +- docs/mcp_integrations.md | 6 +- pyproject.toml | 1 + src/server/app.py | 34 +++++++++ src/server/mcp_request.py | 45 +++++++++++ src/server/mcp_utils.py | 95 +++++++++++++++++++++++ uv.lock | 153 ++++++++++++++++++++++---------------- 7 files changed, 265 insertions(+), 71 deletions(-) create mode 100644 src/server/mcp_request.py create mode 100644 src/server/mcp_utils.py diff --git a/Makefile b/Makefile index c3ccf3e..ae87668 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ lint: uv run black --check . serve: - uv run server.py + uv run server.py --reload test: uv run pytest tests/ diff --git a/docs/mcp_integrations.md b/docs/mcp_integrations.md index 14cd5be..59213d6 100644 --- a/docs/mcp_integrations.md +++ b/docs/mcp_integrations.md @@ -11,10 +11,8 @@ For stdio type: { "type": "stdio", "command": "npx", - "args": ["@agentdeskai/browser-tools-mcp@1.2.0"] - "env": { - "MCP_SERVER_ID": "mcp-github-trending" - } + "args": ["-y", "tavily-mcp@0.1.3"], + "env": {"TAVILY_API_KEY": "tvly-dev-xxx"} } ``` diff --git a/pyproject.toml b/pyproject.toml index 3525990..8ede212 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "duckduckgo-search>=8.0.0", "inquirerpy>=0.3.4", "arxiv>=2.2.0", + "mcp>=1.6.0", ] [project.optional-dependencies] diff --git a/src/server/app.py b/src/server/app.py index b53e952..1482915 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -13,6 +13,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import Response, StreamingResponse from langchain_core.messages import AIMessageChunk, ToolMessage from langgraph.types import Command +from mcp import ClientSession from src.graph.builder import build_graph_with_memory from src.podcast.graph.builder import build_graph as build_podcast_graph @@ -24,6 +25,8 @@ from src.server.chat_request import ( GeneratePPTRequest, TTSRequest, ) +from src.server.mcp_request import MCPServerMetadataRequest, MCPServerMetadataResponse +from src.server.mcp_utils import load_mcp_tools from src.tools import VolcengineTTS logger = logging.getLogger(__name__) @@ -244,3 +247,34 @@ async def generate_ppt(request: GeneratePPTRequest): except Exception as e: logger.exception(f"Error occurred during ppt generation: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) + + +@app.post("/api/mcp/server/metadata", response_model=MCPServerMetadataResponse) +async def mcp_server_metadata(request: MCPServerMetadataRequest): + """Get information about an MCP server.""" + try: + # Load tools from the MCP server using the utility function + tools = await load_mcp_tools( + server_type=request.type, + command=request.command, + args=request.args, + url=request.url, + env=request.env, + ) + + # Create the response with tools + response = MCPServerMetadataResponse( + type=request.type, + command=request.command, + args=request.args, + url=request.url, + env=request.env, + tools=tools, + ) + + return response + except Exception as e: + if not isinstance(e, HTTPException): + logger.exception(f"Error in MCP server metadata endpoint: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + raise diff --git a/src/server/mcp_request.py b/src/server/mcp_request.py new file mode 100644 index 0000000..e82315f --- /dev/null +++ b/src/server/mcp_request.py @@ -0,0 +1,45 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +from typing import Dict, List, Optional + +from pydantic import BaseModel, Field + + +class MCPServerMetadataRequest(BaseModel): + """Request model for MCP server metadata.""" + + type: str = Field( + ..., description="The type of MCP server connection (stdio or sse)" + ) + command: Optional[str] = Field( + None, description="The command to execute (for stdio type)" + ) + args: Optional[List[str]] = Field( + None, description="Command arguments (for stdio type)" + ) + url: Optional[str] = Field( + None, description="The URL of the SSE server (for sse type)" + ) + env: Optional[Dict[str, str]] = Field(None, description="Environment variables") + + +class MCPServerMetadataResponse(BaseModel): + """Response model for MCP server metadata.""" + + type: str = Field( + ..., description="The type of MCP server connection (stdio or sse)" + ) + command: Optional[str] = Field( + None, description="The command to execute (for stdio type)" + ) + args: Optional[List[str]] = Field( + None, description="Command arguments (for stdio type)" + ) + url: Optional[str] = Field( + None, description="The URL of the SSE server (for sse type)" + ) + env: Optional[Dict[str, str]] = Field(None, description="Environment variables") + tools: List = Field( + default_factory=list, description="Available tools from the MCP server" + ) diff --git a/src/server/mcp_utils.py b/src/server/mcp_utils.py new file mode 100644 index 0000000..62a8243 --- /dev/null +++ b/src/server/mcp_utils.py @@ -0,0 +1,95 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +import logging +from datetime import timedelta +from typing import Any, Dict, List, Optional, Tuple + +from fastapi import HTTPException +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client +from mcp.client.sse import sse_client + +logger = logging.getLogger(__name__) + + +async def _get_tools_from_client_session(client_context_manager: Any) -> List: + """ + Helper function to get tools from a client session. + + Args: + client_context_manager: A context manager that returns (read, write) functions + + Returns: + List of available tools from the MCP server + + Raises: + Exception: If there's an error during the process + """ + async with client_context_manager as (read, write): + async with ClientSession( + read, write, read_timeout_seconds=timedelta(seconds=10) + ) as session: + # Initialize the connection + await session.initialize() + # List available tools + listed_tools = await session.list_tools() + return listed_tools.tools + + +async def load_mcp_tools( + server_type: str, + command: Optional[str] = None, + args: Optional[List[str]] = None, + url: Optional[str] = None, + env: Optional[Dict[str, str]] = None, +) -> List: + """ + Load tools from an MCP server. + + Args: + server_type: The type of MCP server connection (stdio or sse) + command: The command to execute (for stdio type) + args: Command arguments (for stdio type) + url: The URL of the SSE server (for sse type) + env: Environment variables + + Returns: + List of available tools from the MCP server + + Raises: + HTTPException: If there's an error loading the tools + """ + try: + if server_type == "stdio": + if not command: + raise HTTPException( + status_code=400, detail="Command is required for stdio type" + ) + + server_params = StdioServerParameters( + command=command, # Executable + args=args, # Optional command line arguments + env=env, # Optional environment variables + ) + + return await _get_tools_from_client_session(stdio_client(server_params)) + + elif server_type == "sse": + if not url: + raise HTTPException( + status_code=400, detail="URL is required for sse type" + ) + + return await _get_tools_from_client_session(sse_client(url=url)) + + else: + raise HTTPException( + status_code=400, detail=f"Unsupported server type: {server_type}" + ) + + except Exception as e: + if not isinstance(e, HTTPException): + logger.exception(f"Error loading MCP tools: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + raise diff --git a/uv.lock b/uv.lock index c2924c2..cf247f5 100644 --- a/uv.lock +++ b/uv.lock @@ -309,6 +309,74 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c3/be/d0d44e092656fe7a06b55e6103cbce807cdbdee17884a5367c68c9860853/dataclasses_json-0.6.7-py3-none-any.whl", hash = "sha256:0dbf33f26c8d5305befd61b39d2b3414e8a407bedc2834dea9b8d642666fb40a", size = 28686 }, ] +[[package]] +name = "deer-flow" +version = "0.1.0" +source = { editable = "." } +dependencies = [ + { name = "arxiv" }, + { name = "duckduckgo-search" }, + { name = "fastapi" }, + { name = "httpx" }, + { name = "inquirerpy" }, + { name = "jinja2" }, + { name = "json-repair" }, + { name = "langchain-community" }, + { name = "langchain-experimental" }, + { name = "langchain-openai" }, + { name = "langgraph" }, + { name = "litellm" }, + { name = "markdownify" }, + { name = "mcp" }, + { name = "numpy" }, + { name = "pandas" }, + { name = "python-dotenv" }, + { name = "readabilipy" }, + { name = "socksio" }, + { name = "sse-starlette" }, + { name = "uvicorn" }, + { name = "yfinance" }, +] + +[package.optional-dependencies] +dev = [ + { name = "black" }, +] +test = [ + { name = "pytest" }, + { name = "pytest-cov" }, +] + +[package.metadata] +requires-dist = [ + { name = "arxiv", specifier = ">=2.2.0" }, + { name = "black", marker = "extra == 'dev'", specifier = ">=24.2.0" }, + { name = "duckduckgo-search", specifier = ">=8.0.0" }, + { name = "fastapi", specifier = ">=0.110.0" }, + { name = "httpx", specifier = ">=0.28.1" }, + { name = "inquirerpy", specifier = ">=0.3.4" }, + { name = "jinja2", specifier = ">=3.1.3" }, + { name = "json-repair", specifier = ">=0.7.0" }, + { name = "langchain-community", specifier = ">=0.3.19" }, + { name = "langchain-experimental", specifier = ">=0.3.4" }, + { name = "langchain-openai", specifier = ">=0.3.8" }, + { name = "langgraph", specifier = ">=0.3.5" }, + { name = "litellm", specifier = ">=1.63.11" }, + { name = "markdownify", specifier = ">=1.1.0" }, + { name = "mcp", specifier = ">=1.6.0" }, + { name = "numpy", specifier = ">=2.2.3" }, + { name = "pandas", specifier = ">=2.2.3" }, + { name = "pytest", marker = "extra == 'test'", specifier = ">=7.4.0" }, + { name = "pytest-cov", marker = "extra == 'test'", specifier = ">=4.1.0" }, + { name = "python-dotenv", specifier = ">=1.0.1" }, + { name = "readabilipy", specifier = ">=0.3.0" }, + { name = "socksio", specifier = ">=1.0.0" }, + { name = "sse-starlette", specifier = ">=1.6.5" }, + { name = "uvicorn", specifier = ">=0.27.1" }, + { name = "yfinance", specifier = ">=0.2.54" }, +] +provides-extras = ["dev", "test"] + [[package]] name = "distro" version = "1.9.0" @@ -853,72 +921,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e0/09/3f909694aa0b104a611444959227832206864d92703e191a0f4b2a27d55b/langsmith-0.3.13-py3-none-any.whl", hash = "sha256:73aaf52bbc293b9415fff4f6dad68df40658081eb26c9cb2c7bd1ff57cedd695", size = 339683 }, ] -[[package]] -name = "deer-flow" -version = "0.1.0" -source = { editable = "." } -dependencies = [ - { name = "arxiv" }, - { name = "duckduckgo-search" }, - { name = "fastapi" }, - { name = "httpx" }, - { name = "inquirerpy" }, - { name = "jinja2" }, - { name = "json-repair" }, - { name = "langchain-community" }, - { name = "langchain-experimental" }, - { name = "langchain-openai" }, - { name = "langgraph" }, - { name = "litellm" }, - { name = "markdownify" }, - { name = "numpy" }, - { name = "pandas" }, - { name = "python-dotenv" }, - { name = "readabilipy" }, - { name = "socksio" }, - { name = "sse-starlette" }, - { name = "uvicorn" }, - { name = "yfinance" }, -] - -[package.optional-dependencies] -dev = [ - { name = "black" }, -] -test = [ - { name = "pytest" }, - { name = "pytest-cov" }, -] - -[package.metadata] -requires-dist = [ - { name = "arxiv", specifier = ">=2.2.0" }, - { name = "black", marker = "extra == 'dev'", specifier = ">=24.2.0" }, - { name = "duckduckgo-search", specifier = ">=8.0.0" }, - { name = "fastapi", specifier = ">=0.110.0" }, - { name = "httpx", specifier = ">=0.28.1" }, - { name = "inquirerpy", specifier = ">=0.3.4" }, - { name = "jinja2", specifier = ">=3.1.3" }, - { name = "json-repair", specifier = ">=0.7.0" }, - { name = "langchain-community", specifier = ">=0.3.19" }, - { name = "langchain-experimental", specifier = ">=0.3.4" }, - { name = "langchain-openai", specifier = ">=0.3.8" }, - { name = "langgraph", specifier = ">=0.3.5" }, - { name = "litellm", specifier = ">=1.63.11" }, - { name = "markdownify", specifier = ">=1.1.0" }, - { name = "numpy", specifier = ">=2.2.3" }, - { name = "pandas", specifier = ">=2.2.3" }, - { name = "pytest", marker = "extra == 'test'", specifier = ">=7.4.0" }, - { name = "pytest-cov", marker = "extra == 'test'", specifier = ">=4.1.0" }, - { name = "python-dotenv", specifier = ">=1.0.1" }, - { name = "readabilipy", specifier = ">=0.3.0" }, - { name = "socksio", specifier = ">=1.0.0" }, - { name = "sse-starlette", specifier = ">=1.6.5" }, - { name = "uvicorn", specifier = ">=0.27.1" }, - { name = "yfinance", specifier = ">=0.2.54" }, -] -provides-extras = ["dev", "test"] - [[package]] name = "litellm" version = "1.63.11" @@ -1046,6 +1048,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/34/75/51952c7b2d3873b44a0028b1bd26a25078c18f92f256608e8d1dc61b39fd/marshmallow-3.26.1-py3-none-any.whl", hash = "sha256:3350409f20a70a7e4e11a27661187b77cdcaeb20abca41c1454fe33636bea09c", size = 50878 }, ] +[[package]] +name = "mcp" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "httpx" }, + { name = "httpx-sse" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "sse-starlette" }, + { name = "starlette" }, + { name = "uvicorn" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/95/d2/f587cb965a56e992634bebc8611c5b579af912b74e04eb9164bd49527d21/mcp-1.6.0.tar.gz", hash = "sha256:d9324876de2c5637369f43161cd71eebfd803df5a95e46225cab8d280e366723", size = 200031 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/30/20a7f33b0b884a9d14dd3aa94ff1ac9da1479fe2ad66dd9e2736075d2506/mcp-1.6.0-py3-none-any.whl", hash = "sha256:7bd24c6ea042dbec44c754f100984d186620d8b841ec30f1b19eda9b93a634d0", size = 76077 }, +] + [[package]] name = "msgpack" version = "1.1.0"