feat(mcp): add OAuth support for HTTP/SSE MCP servers (#908)

add oauth schema to MCP server config (extensions_config.json)
support client_credentials and refresh_token grants
implement token manager with caching and pre-expiry refresh
inject OAuth Authorization header for MCP tool discovery and tool calls
extend MCP gateway config models to read/write OAuth settings
update docs and examples for OAuth configuration
add unit tests for token fetch/cache and header injection
This commit is contained in:
Willem Jiang
2026-03-01 22:38:58 +08:00
committed by GitHub
parent 80316c131e
commit a2f91c7594
11 changed files with 497 additions and 20 deletions

View File

@@ -3,11 +3,34 @@
import json
import os
from pathlib import Path
from typing import Any
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field
class McpOAuthConfig(BaseModel):
"""OAuth configuration for an MCP server (HTTP/SSE transports)."""
enabled: bool = Field(default=True, description="Whether OAuth token injection is enabled")
token_url: str = Field(description="OAuth token endpoint URL")
grant_type: Literal["client_credentials", "refresh_token"] = Field(
default="client_credentials",
description="OAuth grant type",
)
client_id: str | None = Field(default=None, description="OAuth client ID")
client_secret: str | None = Field(default=None, description="OAuth client secret")
refresh_token: str | None = Field(default=None, description="OAuth refresh token (for refresh_token grant)")
scope: str | None = Field(default=None, description="OAuth scope")
audience: str | None = Field(default=None, description="OAuth audience (provider-specific)")
token_field: str = Field(default="access_token", description="Field name containing access token in token response")
token_type_field: str = Field(default="token_type", description="Field name containing token type in token response")
expires_in_field: str = Field(default="expires_in", description="Field name containing expiry (seconds) in token response")
default_token_type: str = Field(default="Bearer", description="Default token type when missing in token response")
refresh_skew_seconds: int = Field(default=60, description="Refresh token this many seconds before expiry")
extra_token_params: dict[str, str] = Field(default_factory=dict, description="Additional form params sent to token endpoint")
model_config = ConfigDict(extra="allow")
class McpServerConfig(BaseModel):
"""Configuration for a single MCP server."""
@@ -18,6 +41,7 @@ class McpServerConfig(BaseModel):
env: dict[str, str] = Field(default_factory=dict, description="Environment variables for the MCP server")
url: str | None = Field(default=None, description="URL of the MCP server (for sse or http type)")
headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers to send (for sse or http type)")
oauth: McpOAuthConfig | None = Field(default=None, description="OAuth configuration (for sse or http type)")
description: str = Field(default="", description="Human-readable description of what this MCP server provides")
model_config = ConfigDict(extra="allow")

View File

@@ -1,6 +1,7 @@
import json
import logging
from pathlib import Path
from typing import Literal
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
@@ -11,6 +12,25 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["mcp"])
class McpOAuthConfigResponse(BaseModel):
"""OAuth configuration for an MCP server."""
enabled: bool = Field(default=True, description="Whether OAuth token injection is enabled")
token_url: str = Field(default="", description="OAuth token endpoint URL")
grant_type: Literal["client_credentials", "refresh_token"] = Field(default="client_credentials", description="OAuth grant type")
client_id: str | None = Field(default=None, description="OAuth client ID")
client_secret: str | None = Field(default=None, description="OAuth client secret")
refresh_token: str | None = Field(default=None, description="OAuth refresh token")
scope: str | None = Field(default=None, description="OAuth scope")
audience: str | None = Field(default=None, description="OAuth audience")
token_field: str = Field(default="access_token", description="Token response field containing access token")
token_type_field: str = Field(default="token_type", description="Token response field containing token type")
expires_in_field: str = Field(default="expires_in", description="Token response field containing expires-in seconds")
default_token_type: str = Field(default="Bearer", description="Default token type when response omits token_type")
refresh_skew_seconds: int = Field(default=60, description="Refresh this many seconds before expiry")
extra_token_params: dict[str, str] = Field(default_factory=dict, description="Additional form params sent to token endpoint")
class McpServerConfigResponse(BaseModel):
"""Response model for MCP server configuration."""
@@ -21,6 +41,7 @@ class McpServerConfigResponse(BaseModel):
env: dict[str, str] = Field(default_factory=dict, description="Environment variables for the MCP server")
url: str | None = Field(default=None, description="URL of the MCP server (for sse or http type)")
headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers to send (for sse or http type)")
oauth: McpOAuthConfigResponse | None = Field(default=None, description="OAuth configuration for MCP HTTP/SSE servers")
description: str = Field(default="", description="Human-readable description of what this MCP server provides")

150
backend/src/mcp/oauth.py Normal file
View File

@@ -0,0 +1,150 @@
"""OAuth token support for MCP HTTP/SSE servers."""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from typing import Any
from src.config.extensions_config import ExtensionsConfig, McpOAuthConfig
logger = logging.getLogger(__name__)
@dataclass
class _OAuthToken:
"""Cached OAuth token."""
access_token: str
token_type: str
expires_at: datetime
class OAuthTokenManager:
"""Acquire/cache/refresh OAuth tokens for MCP servers."""
def __init__(self, oauth_by_server: dict[str, McpOAuthConfig]):
self._oauth_by_server = oauth_by_server
self._tokens: dict[str, _OAuthToken] = {}
self._locks: dict[str, asyncio.Lock] = {name: asyncio.Lock() for name in oauth_by_server}
@classmethod
def from_extensions_config(cls, extensions_config: ExtensionsConfig) -> OAuthTokenManager:
oauth_by_server: dict[str, McpOAuthConfig] = {}
for server_name, server_config in extensions_config.get_enabled_mcp_servers().items():
if server_config.oauth and server_config.oauth.enabled:
oauth_by_server[server_name] = server_config.oauth
return cls(oauth_by_server)
def has_oauth_servers(self) -> bool:
return bool(self._oauth_by_server)
def oauth_server_names(self) -> list[str]:
return list(self._oauth_by_server.keys())
async def get_authorization_header(self, server_name: str) -> str | None:
oauth = self._oauth_by_server.get(server_name)
if not oauth:
return None
token = self._tokens.get(server_name)
if token and not self._is_expiring(token, oauth):
return f"{token.token_type} {token.access_token}"
lock = self._locks[server_name]
async with lock:
token = self._tokens.get(server_name)
if token and not self._is_expiring(token, oauth):
return f"{token.token_type} {token.access_token}"
fresh = await self._fetch_token(oauth)
self._tokens[server_name] = fresh
logger.info(f"Refreshed OAuth access token for MCP server: {server_name}")
return f"{fresh.token_type} {fresh.access_token}"
@staticmethod
def _is_expiring(token: _OAuthToken, oauth: McpOAuthConfig) -> bool:
now = datetime.now(UTC)
return token.expires_at <= now + timedelta(seconds=max(oauth.refresh_skew_seconds, 0))
async def _fetch_token(self, oauth: McpOAuthConfig) -> _OAuthToken:
import httpx # pyright: ignore[reportMissingImports]
data: dict[str, str] = {
"grant_type": oauth.grant_type,
**oauth.extra_token_params,
}
if oauth.scope:
data["scope"] = oauth.scope
if oauth.audience:
data["audience"] = oauth.audience
if oauth.grant_type == "client_credentials":
if not oauth.client_id or not oauth.client_secret:
raise ValueError("OAuth client_credentials requires client_id and client_secret")
data["client_id"] = oauth.client_id
data["client_secret"] = oauth.client_secret
elif oauth.grant_type == "refresh_token":
if not oauth.refresh_token:
raise ValueError("OAuth refresh_token grant requires refresh_token")
data["refresh_token"] = oauth.refresh_token
if oauth.client_id:
data["client_id"] = oauth.client_id
if oauth.client_secret:
data["client_secret"] = oauth.client_secret
else:
raise ValueError(f"Unsupported OAuth grant type: {oauth.grant_type}")
async with httpx.AsyncClient(timeout=15.0) as client:
response = await client.post(oauth.token_url, data=data)
response.raise_for_status()
payload = response.json()
access_token = payload.get(oauth.token_field)
if not access_token:
raise ValueError(f"OAuth token response missing '{oauth.token_field}'")
token_type = str(payload.get(oauth.token_type_field, oauth.default_token_type) or oauth.default_token_type)
expires_in_raw = payload.get(oauth.expires_in_field, 3600)
try:
expires_in = int(expires_in_raw)
except (TypeError, ValueError):
expires_in = 3600
expires_at = datetime.now(UTC) + timedelta(seconds=max(expires_in, 1))
return _OAuthToken(access_token=access_token, token_type=token_type, expires_at=expires_at)
def build_oauth_tool_interceptor(extensions_config: ExtensionsConfig) -> Any | None:
"""Build a tool interceptor that injects OAuth Authorization headers."""
token_manager = OAuthTokenManager.from_extensions_config(extensions_config)
if not token_manager.has_oauth_servers():
return None
async def oauth_interceptor(request: Any, handler: Any) -> Any:
header = await token_manager.get_authorization_header(request.server_name)
if not header:
return await handler(request)
updated_headers = dict(request.headers or {})
updated_headers["Authorization"] = header
return await handler(request.override(headers=updated_headers))
return oauth_interceptor
async def get_initial_oauth_headers(extensions_config: ExtensionsConfig) -> dict[str, str]:
"""Get initial OAuth Authorization headers for MCP server connections."""
token_manager = OAuthTokenManager.from_extensions_config(extensions_config)
if not token_manager.has_oauth_servers():
return {}
headers: dict[str, str] = {}
for server_name in token_manager.oauth_server_names():
headers[server_name] = await token_manager.get_authorization_header(server_name) or ""
return {name: value for name, value in headers.items() if value}

View File

@@ -6,6 +6,7 @@ from langchain_core.tools import BaseTool
from src.config.extensions_config import ExtensionsConfig
from src.mcp.client import build_servers_config
from src.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
logger = logging.getLogger(__name__)
@@ -36,7 +37,23 @@ async def get_mcp_tools() -> list[BaseTool]:
try:
# Create the multi-server MCP client
logger.info(f"Initializing MCP client with {len(servers_config)} server(s)")
client = MultiServerMCPClient(servers_config)
# Inject initial OAuth headers for server connections (tool discovery/session init)
initial_oauth_headers = await get_initial_oauth_headers(extensions_config)
for server_name, auth_header in initial_oauth_headers.items():
if server_name not in servers_config:
continue
if servers_config[server_name].get("transport") in ("sse", "http"):
existing_headers = dict(servers_config[server_name].get("headers", {}))
existing_headers["Authorization"] = auth_header
servers_config[server_name]["headers"] = existing_headers
tool_interceptors = []
oauth_interceptor = build_oauth_tool_interceptor(extensions_config)
if oauth_interceptor is not None:
tool_interceptors.append(oauth_interceptor)
client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors)
# Get all tools from all servers
tools = await client.get_tools()