mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-21 05:14:45 +08:00
feat(mcp): add OAuth support for HTTP/SSE MCP servers (#908)
add oauth schema to MCP server config (extensions_config.json) support client_credentials and refresh_token grants implement token manager with caching and pre-expiry refresh inject OAuth Authorization header for MCP tool discovery and tool calls extend MCP gateway config models to read/write OAuth settings update docs and examples for OAuth configuration add unit tests for token fetch/cache and header injection
This commit is contained in:
150
backend/src/mcp/oauth.py
Normal file
150
backend/src/mcp/oauth.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""OAuth token support for MCP HTTP/SSE servers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from src.config.extensions_config import ExtensionsConfig, McpOAuthConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _OAuthToken:
|
||||
"""Cached OAuth token."""
|
||||
|
||||
access_token: str
|
||||
token_type: str
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
class OAuthTokenManager:
|
||||
"""Acquire/cache/refresh OAuth tokens for MCP servers."""
|
||||
|
||||
def __init__(self, oauth_by_server: dict[str, McpOAuthConfig]):
|
||||
self._oauth_by_server = oauth_by_server
|
||||
self._tokens: dict[str, _OAuthToken] = {}
|
||||
self._locks: dict[str, asyncio.Lock] = {name: asyncio.Lock() for name in oauth_by_server}
|
||||
|
||||
@classmethod
|
||||
def from_extensions_config(cls, extensions_config: ExtensionsConfig) -> OAuthTokenManager:
|
||||
oauth_by_server: dict[str, McpOAuthConfig] = {}
|
||||
for server_name, server_config in extensions_config.get_enabled_mcp_servers().items():
|
||||
if server_config.oauth and server_config.oauth.enabled:
|
||||
oauth_by_server[server_name] = server_config.oauth
|
||||
return cls(oauth_by_server)
|
||||
|
||||
def has_oauth_servers(self) -> bool:
|
||||
return bool(self._oauth_by_server)
|
||||
|
||||
def oauth_server_names(self) -> list[str]:
|
||||
return list(self._oauth_by_server.keys())
|
||||
|
||||
async def get_authorization_header(self, server_name: str) -> str | None:
|
||||
oauth = self._oauth_by_server.get(server_name)
|
||||
if not oauth:
|
||||
return None
|
||||
|
||||
token = self._tokens.get(server_name)
|
||||
if token and not self._is_expiring(token, oauth):
|
||||
return f"{token.token_type} {token.access_token}"
|
||||
|
||||
lock = self._locks[server_name]
|
||||
async with lock:
|
||||
token = self._tokens.get(server_name)
|
||||
if token and not self._is_expiring(token, oauth):
|
||||
return f"{token.token_type} {token.access_token}"
|
||||
|
||||
fresh = await self._fetch_token(oauth)
|
||||
self._tokens[server_name] = fresh
|
||||
logger.info(f"Refreshed OAuth access token for MCP server: {server_name}")
|
||||
return f"{fresh.token_type} {fresh.access_token}"
|
||||
|
||||
@staticmethod
|
||||
def _is_expiring(token: _OAuthToken, oauth: McpOAuthConfig) -> bool:
|
||||
now = datetime.now(UTC)
|
||||
return token.expires_at <= now + timedelta(seconds=max(oauth.refresh_skew_seconds, 0))
|
||||
|
||||
async def _fetch_token(self, oauth: McpOAuthConfig) -> _OAuthToken:
|
||||
import httpx # pyright: ignore[reportMissingImports]
|
||||
|
||||
data: dict[str, str] = {
|
||||
"grant_type": oauth.grant_type,
|
||||
**oauth.extra_token_params,
|
||||
}
|
||||
|
||||
if oauth.scope:
|
||||
data["scope"] = oauth.scope
|
||||
if oauth.audience:
|
||||
data["audience"] = oauth.audience
|
||||
|
||||
if oauth.grant_type == "client_credentials":
|
||||
if not oauth.client_id or not oauth.client_secret:
|
||||
raise ValueError("OAuth client_credentials requires client_id and client_secret")
|
||||
data["client_id"] = oauth.client_id
|
||||
data["client_secret"] = oauth.client_secret
|
||||
elif oauth.grant_type == "refresh_token":
|
||||
if not oauth.refresh_token:
|
||||
raise ValueError("OAuth refresh_token grant requires refresh_token")
|
||||
data["refresh_token"] = oauth.refresh_token
|
||||
if oauth.client_id:
|
||||
data["client_id"] = oauth.client_id
|
||||
if oauth.client_secret:
|
||||
data["client_secret"] = oauth.client_secret
|
||||
else:
|
||||
raise ValueError(f"Unsupported OAuth grant type: {oauth.grant_type}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||
response = await client.post(oauth.token_url, data=data)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
|
||||
access_token = payload.get(oauth.token_field)
|
||||
if not access_token:
|
||||
raise ValueError(f"OAuth token response missing '{oauth.token_field}'")
|
||||
|
||||
token_type = str(payload.get(oauth.token_type_field, oauth.default_token_type) or oauth.default_token_type)
|
||||
|
||||
expires_in_raw = payload.get(oauth.expires_in_field, 3600)
|
||||
try:
|
||||
expires_in = int(expires_in_raw)
|
||||
except (TypeError, ValueError):
|
||||
expires_in = 3600
|
||||
|
||||
expires_at = datetime.now(UTC) + timedelta(seconds=max(expires_in, 1))
|
||||
return _OAuthToken(access_token=access_token, token_type=token_type, expires_at=expires_at)
|
||||
|
||||
|
||||
def build_oauth_tool_interceptor(extensions_config: ExtensionsConfig) -> Any | None:
|
||||
"""Build a tool interceptor that injects OAuth Authorization headers."""
|
||||
token_manager = OAuthTokenManager.from_extensions_config(extensions_config)
|
||||
if not token_manager.has_oauth_servers():
|
||||
return None
|
||||
|
||||
async def oauth_interceptor(request: Any, handler: Any) -> Any:
|
||||
header = await token_manager.get_authorization_header(request.server_name)
|
||||
if not header:
|
||||
return await handler(request)
|
||||
|
||||
updated_headers = dict(request.headers or {})
|
||||
updated_headers["Authorization"] = header
|
||||
return await handler(request.override(headers=updated_headers))
|
||||
|
||||
return oauth_interceptor
|
||||
|
||||
|
||||
async def get_initial_oauth_headers(extensions_config: ExtensionsConfig) -> dict[str, str]:
|
||||
"""Get initial OAuth Authorization headers for MCP server connections."""
|
||||
token_manager = OAuthTokenManager.from_extensions_config(extensions_config)
|
||||
if not token_manager.has_oauth_servers():
|
||||
return {}
|
||||
|
||||
headers: dict[str, str] = {}
|
||||
for server_name in token_manager.oauth_server_names():
|
||||
headers[server_name] = await token_manager.get_authorization_header(server_name) or ""
|
||||
|
||||
return {name: value for name, value in headers.items() if value}
|
||||
@@ -6,6 +6,7 @@ from langchain_core.tools import BaseTool
|
||||
|
||||
from src.config.extensions_config import ExtensionsConfig
|
||||
from src.mcp.client import build_servers_config
|
||||
from src.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -36,7 +37,23 @@ async def get_mcp_tools() -> list[BaseTool]:
|
||||
try:
|
||||
# Create the multi-server MCP client
|
||||
logger.info(f"Initializing MCP client with {len(servers_config)} server(s)")
|
||||
client = MultiServerMCPClient(servers_config)
|
||||
|
||||
# Inject initial OAuth headers for server connections (tool discovery/session init)
|
||||
initial_oauth_headers = await get_initial_oauth_headers(extensions_config)
|
||||
for server_name, auth_header in initial_oauth_headers.items():
|
||||
if server_name not in servers_config:
|
||||
continue
|
||||
if servers_config[server_name].get("transport") in ("sse", "http"):
|
||||
existing_headers = dict(servers_config[server_name].get("headers", {}))
|
||||
existing_headers["Authorization"] = auth_header
|
||||
servers_config[server_name]["headers"] = existing_headers
|
||||
|
||||
tool_interceptors = []
|
||||
oauth_interceptor = build_oauth_tool_interceptor(extensions_config)
|
||||
if oauth_interceptor is not None:
|
||||
tool_interceptors.append(oauth_interceptor)
|
||||
|
||||
client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors)
|
||||
|
||||
# Get all tools from all servers
|
||||
tools = await client.get_tools()
|
||||
|
||||
Reference in New Issue
Block a user