mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-02 22:02:13 +08:00
feat(mcp): add OAuth support for HTTP/SSE MCP servers (#908)
add oauth schema to MCP server config (extensions_config.json) support client_credentials and refresh_token grants implement token manager with caching and pre-expiry refresh inject OAuth Authorization header for MCP tool discovery and tool calls extend MCP gateway config models to read/write OAuth settings update docs and examples for OAuth configuration add unit tests for token fetch/cache and header injection
This commit is contained in:
41
README.md
41
README.md
@@ -20,21 +20,31 @@ Learn more and see **real demos** on our official website.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Quick Start](#quick-start)
|
||||
- [Sandbox Mode](#sandbox-mode)
|
||||
- [From Deep Research to Super Agent Harness](#from-deep-research-to-super-agent-harness)
|
||||
- [Core Features](#core-features)
|
||||
- [Skills & Tools](#skills--tools)
|
||||
- [Sub-Agents](#sub-agents)
|
||||
- [Sandbox & File System](#sandbox--file-system)
|
||||
- [Context Engineering](#context-engineering)
|
||||
- [Long-Term Memory](#long-term-memory)
|
||||
- [Recommended Models](#recommended-models)
|
||||
- [Documentation](#documentation)
|
||||
- [Contributing](#contributing)
|
||||
- [License](#license)
|
||||
- [Acknowledgments](#acknowledgments)
|
||||
- [Star History](#star-history)
|
||||
- [🦌 DeerFlow - 2.0](#-deerflow---20)
|
||||
- [Offiical Website](#offiical-website)
|
||||
- [Table of Contents](#table-of-contents)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Configuration](#configuration)
|
||||
- [Running the Application](#running-the-application)
|
||||
- [Option 1: Docker (Recommended)](#option-1-docker-recommended)
|
||||
- [Option 2: Local Development](#option-2-local-development)
|
||||
- [Advanced](#advanced)
|
||||
- [Sandbox Mode](#sandbox-mode)
|
||||
- [MCP Server](#mcp-server)
|
||||
- [From Deep Research to Super Agent Harness](#from-deep-research-to-super-agent-harness)
|
||||
- [Core Features](#core-features)
|
||||
- [Skills \& Tools](#skills--tools)
|
||||
- [Sub-Agents](#sub-agents)
|
||||
- [Sandbox \& File System](#sandbox--file-system)
|
||||
- [Context Engineering](#context-engineering)
|
||||
- [Long-Term Memory](#long-term-memory)
|
||||
- [Recommended Models](#recommended-models)
|
||||
- [Documentation](#documentation)
|
||||
- [Contributing](#contributing)
|
||||
- [License](#license)
|
||||
- [Acknowledgments](#acknowledgments)
|
||||
- [Key Contributors](#key-contributors)
|
||||
- [Star History](#star-history)
|
||||
|
||||
## Quick Start
|
||||
|
||||
@@ -155,6 +165,7 @@ See the [Sandbox Configuration Guide](backend/docs/CONFIGURATION.md#sandbox) to
|
||||
#### MCP Server
|
||||
|
||||
DeerFlow supports configurable MCP servers and skills to extend its capabilities.
|
||||
For HTTP/SSE MCP servers, OAuth token flows are supported (`client_credentials`, `refresh_token`).
|
||||
See the [MCP Server Guide](backend/docs/MCP_SERVER.md) for detailed instructions.
|
||||
|
||||
## From Deep Research to Super Agent Harness
|
||||
|
||||
@@ -224,6 +224,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
|
||||
- **Lazy initialization**: Tools loaded on first use via `get_cached_mcp_tools()`
|
||||
- **Cache invalidation**: Detects config file changes via mtime comparison
|
||||
- **Transports**: stdio (command-based), SSE, HTTP
|
||||
- **OAuth (HTTP/SSE)**: Supports token endpoint flows (`client_credentials`, `refresh_token`) with automatic token refresh + Authorization header injection
|
||||
- **Runtime updates**: Gateway API saves to extensions_config.json; LangGraph detects via mtime
|
||||
|
||||
### Skills System (`src/skills/`)
|
||||
@@ -287,7 +288,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
|
||||
- `memory` - Memory system (enabled, storage_path, debounce_seconds, model_name, max_facts, fact_confidence_threshold, injection_enabled, max_injection_tokens)
|
||||
|
||||
**`extensions_config.json`**:
|
||||
- `mcpServers` - Map of server name → config (enabled, type, command, args, env, url, headers, description)
|
||||
- `mcpServers` - Map of server name → config (enabled, type, command, args, env, url, headers, oauth, description)
|
||||
- `skills` - Map of skill name → state (enabled)
|
||||
|
||||
Both can be modified at runtime via Gateway API endpoints or `DeerFlowClient` methods.
|
||||
|
||||
@@ -264,6 +264,18 @@ MCP servers and skill states in a single file:
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-github"],
|
||||
"env": {"GITHUB_TOKEN": "$GITHUB_TOKEN"}
|
||||
},
|
||||
"secure-http": {
|
||||
"enabled": true,
|
||||
"type": "http",
|
||||
"url": "https://api.example.com/mcp",
|
||||
"oauth": {
|
||||
"enabled": true,
|
||||
"token_url": "https://auth.example.com/oauth/token",
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": "$MCP_OAUTH_CLIENT_ID",
|
||||
"client_secret": "$MCP_OAUTH_CLIENT_SECRET"
|
||||
}
|
||||
}
|
||||
},
|
||||
"skills": {
|
||||
|
||||
@@ -503,6 +503,8 @@ All APIs return errors in a consistent format:
|
||||
|
||||
Currently, DeerFlow does not implement authentication. All APIs are accessible without credentials.
|
||||
|
||||
Note: This is about DeerFlow API authentication. MCP outbound connections can still use OAuth for configured HTTP/SSE MCP servers.
|
||||
|
||||
For production deployments, it is recommended to:
|
||||
1. Use Nginx for basic auth or OAuth integration
|
||||
2. Deploy behind a VPN or private network
|
||||
|
||||
@@ -14,6 +14,37 @@ DeerFlow supports configurable MCP servers and skills to extend its capabilities
|
||||
3. Configure each server’s command, arguments, and environment variables as needed.
|
||||
4. Restart the application to load and register MCP tools.
|
||||
|
||||
## OAuth Support (HTTP/SSE MCP Servers)
|
||||
|
||||
For `http` and `sse` MCP servers, DeerFlow supports OAuth token acquisition and automatic token refresh.
|
||||
|
||||
- Supported grants: `client_credentials`, `refresh_token`
|
||||
- Configure per-server `oauth` block in `extensions_config.json`
|
||||
- Secrets should be provided via environment variables (for example: `$MCP_OAUTH_CLIENT_SECRET`)
|
||||
|
||||
Example:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"secure-http-server": {
|
||||
"enabled": true,
|
||||
"type": "http",
|
||||
"url": "https://api.example.com/mcp",
|
||||
"oauth": {
|
||||
"enabled": true,
|
||||
"token_url": "https://auth.example.com/oauth/token",
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": "$MCP_OAUTH_CLIENT_ID",
|
||||
"client_secret": "$MCP_OAUTH_CLIENT_SECRET",
|
||||
"scope": "mcp.read",
|
||||
"refresh_skew_seconds": 60
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
MCP servers expose tools that are automatically discovered and integrated into DeerFlow’s agent system at runtime. Once enabled, these tools become available to agents without additional code changes.
|
||||
|
||||
@@ -3,11 +3,34 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class McpOAuthConfig(BaseModel):
|
||||
"""OAuth configuration for an MCP server (HTTP/SSE transports)."""
|
||||
|
||||
enabled: bool = Field(default=True, description="Whether OAuth token injection is enabled")
|
||||
token_url: str = Field(description="OAuth token endpoint URL")
|
||||
grant_type: Literal["client_credentials", "refresh_token"] = Field(
|
||||
default="client_credentials",
|
||||
description="OAuth grant type",
|
||||
)
|
||||
client_id: str | None = Field(default=None, description="OAuth client ID")
|
||||
client_secret: str | None = Field(default=None, description="OAuth client secret")
|
||||
refresh_token: str | None = Field(default=None, description="OAuth refresh token (for refresh_token grant)")
|
||||
scope: str | None = Field(default=None, description="OAuth scope")
|
||||
audience: str | None = Field(default=None, description="OAuth audience (provider-specific)")
|
||||
token_field: str = Field(default="access_token", description="Field name containing access token in token response")
|
||||
token_type_field: str = Field(default="token_type", description="Field name containing token type in token response")
|
||||
expires_in_field: str = Field(default="expires_in", description="Field name containing expiry (seconds) in token response")
|
||||
default_token_type: str = Field(default="Bearer", description="Default token type when missing in token response")
|
||||
refresh_skew_seconds: int = Field(default=60, description="Refresh token this many seconds before expiry")
|
||||
extra_token_params: dict[str, str] = Field(default_factory=dict, description="Additional form params sent to token endpoint")
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class McpServerConfig(BaseModel):
|
||||
"""Configuration for a single MCP server."""
|
||||
|
||||
@@ -18,6 +41,7 @@ class McpServerConfig(BaseModel):
|
||||
env: dict[str, str] = Field(default_factory=dict, description="Environment variables for the MCP server")
|
||||
url: str | None = Field(default=None, description="URL of the MCP server (for sse or http type)")
|
||||
headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers to send (for sse or http type)")
|
||||
oauth: McpOAuthConfig | None = Field(default=None, description="OAuth configuration (for sse or http type)")
|
||||
description: str = Field(default="", description="Human-readable description of what this MCP server provides")
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -11,6 +12,25 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api", tags=["mcp"])
|
||||
|
||||
|
||||
class McpOAuthConfigResponse(BaseModel):
|
||||
"""OAuth configuration for an MCP server."""
|
||||
|
||||
enabled: bool = Field(default=True, description="Whether OAuth token injection is enabled")
|
||||
token_url: str = Field(default="", description="OAuth token endpoint URL")
|
||||
grant_type: Literal["client_credentials", "refresh_token"] = Field(default="client_credentials", description="OAuth grant type")
|
||||
client_id: str | None = Field(default=None, description="OAuth client ID")
|
||||
client_secret: str | None = Field(default=None, description="OAuth client secret")
|
||||
refresh_token: str | None = Field(default=None, description="OAuth refresh token")
|
||||
scope: str | None = Field(default=None, description="OAuth scope")
|
||||
audience: str | None = Field(default=None, description="OAuth audience")
|
||||
token_field: str = Field(default="access_token", description="Token response field containing access token")
|
||||
token_type_field: str = Field(default="token_type", description="Token response field containing token type")
|
||||
expires_in_field: str = Field(default="expires_in", description="Token response field containing expires-in seconds")
|
||||
default_token_type: str = Field(default="Bearer", description="Default token type when response omits token_type")
|
||||
refresh_skew_seconds: int = Field(default=60, description="Refresh this many seconds before expiry")
|
||||
extra_token_params: dict[str, str] = Field(default_factory=dict, description="Additional form params sent to token endpoint")
|
||||
|
||||
|
||||
class McpServerConfigResponse(BaseModel):
|
||||
"""Response model for MCP server configuration."""
|
||||
|
||||
@@ -21,6 +41,7 @@ class McpServerConfigResponse(BaseModel):
|
||||
env: dict[str, str] = Field(default_factory=dict, description="Environment variables for the MCP server")
|
||||
url: str | None = Field(default=None, description="URL of the MCP server (for sse or http type)")
|
||||
headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers to send (for sse or http type)")
|
||||
oauth: McpOAuthConfigResponse | None = Field(default=None, description="OAuth configuration for MCP HTTP/SSE servers")
|
||||
description: str = Field(default="", description="Human-readable description of what this MCP server provides")
|
||||
|
||||
|
||||
|
||||
150
backend/src/mcp/oauth.py
Normal file
150
backend/src/mcp/oauth.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""OAuth token support for MCP HTTP/SSE servers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from src.config.extensions_config import ExtensionsConfig, McpOAuthConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _OAuthToken:
|
||||
"""Cached OAuth token."""
|
||||
|
||||
access_token: str
|
||||
token_type: str
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
class OAuthTokenManager:
|
||||
"""Acquire/cache/refresh OAuth tokens for MCP servers."""
|
||||
|
||||
def __init__(self, oauth_by_server: dict[str, McpOAuthConfig]):
|
||||
self._oauth_by_server = oauth_by_server
|
||||
self._tokens: dict[str, _OAuthToken] = {}
|
||||
self._locks: dict[str, asyncio.Lock] = {name: asyncio.Lock() for name in oauth_by_server}
|
||||
|
||||
@classmethod
|
||||
def from_extensions_config(cls, extensions_config: ExtensionsConfig) -> OAuthTokenManager:
|
||||
oauth_by_server: dict[str, McpOAuthConfig] = {}
|
||||
for server_name, server_config in extensions_config.get_enabled_mcp_servers().items():
|
||||
if server_config.oauth and server_config.oauth.enabled:
|
||||
oauth_by_server[server_name] = server_config.oauth
|
||||
return cls(oauth_by_server)
|
||||
|
||||
def has_oauth_servers(self) -> bool:
|
||||
return bool(self._oauth_by_server)
|
||||
|
||||
def oauth_server_names(self) -> list[str]:
|
||||
return list(self._oauth_by_server.keys())
|
||||
|
||||
async def get_authorization_header(self, server_name: str) -> str | None:
|
||||
oauth = self._oauth_by_server.get(server_name)
|
||||
if not oauth:
|
||||
return None
|
||||
|
||||
token = self._tokens.get(server_name)
|
||||
if token and not self._is_expiring(token, oauth):
|
||||
return f"{token.token_type} {token.access_token}"
|
||||
|
||||
lock = self._locks[server_name]
|
||||
async with lock:
|
||||
token = self._tokens.get(server_name)
|
||||
if token and not self._is_expiring(token, oauth):
|
||||
return f"{token.token_type} {token.access_token}"
|
||||
|
||||
fresh = await self._fetch_token(oauth)
|
||||
self._tokens[server_name] = fresh
|
||||
logger.info(f"Refreshed OAuth access token for MCP server: {server_name}")
|
||||
return f"{fresh.token_type} {fresh.access_token}"
|
||||
|
||||
@staticmethod
|
||||
def _is_expiring(token: _OAuthToken, oauth: McpOAuthConfig) -> bool:
|
||||
now = datetime.now(UTC)
|
||||
return token.expires_at <= now + timedelta(seconds=max(oauth.refresh_skew_seconds, 0))
|
||||
|
||||
async def _fetch_token(self, oauth: McpOAuthConfig) -> _OAuthToken:
|
||||
import httpx # pyright: ignore[reportMissingImports]
|
||||
|
||||
data: dict[str, str] = {
|
||||
"grant_type": oauth.grant_type,
|
||||
**oauth.extra_token_params,
|
||||
}
|
||||
|
||||
if oauth.scope:
|
||||
data["scope"] = oauth.scope
|
||||
if oauth.audience:
|
||||
data["audience"] = oauth.audience
|
||||
|
||||
if oauth.grant_type == "client_credentials":
|
||||
if not oauth.client_id or not oauth.client_secret:
|
||||
raise ValueError("OAuth client_credentials requires client_id and client_secret")
|
||||
data["client_id"] = oauth.client_id
|
||||
data["client_secret"] = oauth.client_secret
|
||||
elif oauth.grant_type == "refresh_token":
|
||||
if not oauth.refresh_token:
|
||||
raise ValueError("OAuth refresh_token grant requires refresh_token")
|
||||
data["refresh_token"] = oauth.refresh_token
|
||||
if oauth.client_id:
|
||||
data["client_id"] = oauth.client_id
|
||||
if oauth.client_secret:
|
||||
data["client_secret"] = oauth.client_secret
|
||||
else:
|
||||
raise ValueError(f"Unsupported OAuth grant type: {oauth.grant_type}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||
response = await client.post(oauth.token_url, data=data)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
|
||||
access_token = payload.get(oauth.token_field)
|
||||
if not access_token:
|
||||
raise ValueError(f"OAuth token response missing '{oauth.token_field}'")
|
||||
|
||||
token_type = str(payload.get(oauth.token_type_field, oauth.default_token_type) or oauth.default_token_type)
|
||||
|
||||
expires_in_raw = payload.get(oauth.expires_in_field, 3600)
|
||||
try:
|
||||
expires_in = int(expires_in_raw)
|
||||
except (TypeError, ValueError):
|
||||
expires_in = 3600
|
||||
|
||||
expires_at = datetime.now(UTC) + timedelta(seconds=max(expires_in, 1))
|
||||
return _OAuthToken(access_token=access_token, token_type=token_type, expires_at=expires_at)
|
||||
|
||||
|
||||
def build_oauth_tool_interceptor(extensions_config: ExtensionsConfig) -> Any | None:
|
||||
"""Build a tool interceptor that injects OAuth Authorization headers."""
|
||||
token_manager = OAuthTokenManager.from_extensions_config(extensions_config)
|
||||
if not token_manager.has_oauth_servers():
|
||||
return None
|
||||
|
||||
async def oauth_interceptor(request: Any, handler: Any) -> Any:
|
||||
header = await token_manager.get_authorization_header(request.server_name)
|
||||
if not header:
|
||||
return await handler(request)
|
||||
|
||||
updated_headers = dict(request.headers or {})
|
||||
updated_headers["Authorization"] = header
|
||||
return await handler(request.override(headers=updated_headers))
|
||||
|
||||
return oauth_interceptor
|
||||
|
||||
|
||||
async def get_initial_oauth_headers(extensions_config: ExtensionsConfig) -> dict[str, str]:
|
||||
"""Get initial OAuth Authorization headers for MCP server connections."""
|
||||
token_manager = OAuthTokenManager.from_extensions_config(extensions_config)
|
||||
if not token_manager.has_oauth_servers():
|
||||
return {}
|
||||
|
||||
headers: dict[str, str] = {}
|
||||
for server_name in token_manager.oauth_server_names():
|
||||
headers[server_name] = await token_manager.get_authorization_header(server_name) or ""
|
||||
|
||||
return {name: value for name, value in headers.items() if value}
|
||||
@@ -6,6 +6,7 @@ from langchain_core.tools import BaseTool
|
||||
|
||||
from src.config.extensions_config import ExtensionsConfig
|
||||
from src.mcp.client import build_servers_config
|
||||
from src.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -36,7 +37,23 @@ async def get_mcp_tools() -> list[BaseTool]:
|
||||
try:
|
||||
# Create the multi-server MCP client
|
||||
logger.info(f"Initializing MCP client with {len(servers_config)} server(s)")
|
||||
client = MultiServerMCPClient(servers_config)
|
||||
|
||||
# Inject initial OAuth headers for server connections (tool discovery/session init)
|
||||
initial_oauth_headers = await get_initial_oauth_headers(extensions_config)
|
||||
for server_name, auth_header in initial_oauth_headers.items():
|
||||
if server_name not in servers_config:
|
||||
continue
|
||||
if servers_config[server_name].get("transport") in ("sse", "http"):
|
||||
existing_headers = dict(servers_config[server_name].get("headers", {}))
|
||||
existing_headers["Authorization"] = auth_header
|
||||
servers_config[server_name]["headers"] = existing_headers
|
||||
|
||||
tool_interceptors = []
|
||||
oauth_interceptor = build_oauth_tool_interceptor(extensions_config)
|
||||
if oauth_interceptor is not None:
|
||||
tool_interceptors.append(oauth_interceptor)
|
||||
|
||||
client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors)
|
||||
|
||||
# Get all tools from all servers
|
||||
tools = await client.get_tools()
|
||||
|
||||
191
backend/tests/test_mcp_oauth.py
Normal file
191
backend/tests/test_mcp_oauth.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""Tests for MCP OAuth support."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from src.config.extensions_config import ExtensionsConfig
|
||||
from src.mcp.oauth import OAuthTokenManager, build_oauth_tool_interceptor, get_initial_oauth_headers
|
||||
|
||||
|
||||
class _MockResponse:
|
||||
def __init__(self, payload: dict[str, Any]):
|
||||
self._payload = payload
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
return None
|
||||
|
||||
def json(self) -> dict[str, Any]:
|
||||
return self._payload
|
||||
|
||||
|
||||
class _MockAsyncClient:
|
||||
def __init__(self, payload: dict[str, Any], post_calls: list[dict[str, Any]], **kwargs):
|
||||
self._payload = payload
|
||||
self._post_calls = post_calls
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def post(self, url: str, data: dict[str, Any]):
|
||||
self._post_calls.append({"url": url, "data": data})
|
||||
return _MockResponse(self._payload)
|
||||
|
||||
|
||||
def test_oauth_token_manager_fetches_and_caches_token(monkeypatch):
|
||||
post_calls: list[dict[str, Any]] = []
|
||||
|
||||
def _client_factory(*args, **kwargs):
|
||||
return _MockAsyncClient(
|
||||
payload={
|
||||
"access_token": "token-123",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
},
|
||||
post_calls=post_calls,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
monkeypatch.setattr("httpx.AsyncClient", _client_factory)
|
||||
|
||||
config = ExtensionsConfig.model_validate(
|
||||
{
|
||||
"mcpServers": {
|
||||
"secure-http": {
|
||||
"enabled": True,
|
||||
"type": "http",
|
||||
"url": "https://api.example.com/mcp",
|
||||
"oauth": {
|
||||
"enabled": True,
|
||||
"token_url": "https://auth.example.com/oauth/token",
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": "client-id",
|
||||
"client_secret": "client-secret",
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
manager = OAuthTokenManager.from_extensions_config(config)
|
||||
|
||||
first = asyncio.run(manager.get_authorization_header("secure-http"))
|
||||
second = asyncio.run(manager.get_authorization_header("secure-http"))
|
||||
|
||||
assert first == "Bearer token-123"
|
||||
assert second == "Bearer token-123"
|
||||
assert len(post_calls) == 1
|
||||
assert post_calls[0]["url"] == "https://auth.example.com/oauth/token"
|
||||
assert post_calls[0]["data"]["grant_type"] == "client_credentials"
|
||||
|
||||
|
||||
def test_build_oauth_interceptor_injects_authorization_header(monkeypatch):
|
||||
post_calls: list[dict[str, Any]] = []
|
||||
|
||||
def _client_factory(*args, **kwargs):
|
||||
return _MockAsyncClient(
|
||||
payload={
|
||||
"access_token": "token-abc",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
},
|
||||
post_calls=post_calls,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
monkeypatch.setattr("httpx.AsyncClient", _client_factory)
|
||||
|
||||
config = ExtensionsConfig.model_validate(
|
||||
{
|
||||
"mcpServers": {
|
||||
"secure-sse": {
|
||||
"enabled": True,
|
||||
"type": "sse",
|
||||
"url": "https://api.example.com/mcp",
|
||||
"oauth": {
|
||||
"enabled": True,
|
||||
"token_url": "https://auth.example.com/oauth/token",
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": "client-id",
|
||||
"client_secret": "client-secret",
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
interceptor = build_oauth_tool_interceptor(config)
|
||||
assert interceptor is not None
|
||||
|
||||
class _Request:
|
||||
def __init__(self):
|
||||
self.server_name = "secure-sse"
|
||||
self.headers = {"X-Test": "1"}
|
||||
|
||||
def override(self, **kwargs):
|
||||
updated = _Request()
|
||||
updated.server_name = self.server_name
|
||||
updated.headers = kwargs.get("headers")
|
||||
return updated
|
||||
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
async def _handler(request):
|
||||
captured["headers"] = request.headers
|
||||
return "ok"
|
||||
|
||||
result = asyncio.run(interceptor(_Request(), _handler))
|
||||
|
||||
assert result == "ok"
|
||||
assert captured["headers"]["Authorization"] == "Bearer token-abc"
|
||||
assert captured["headers"]["X-Test"] == "1"
|
||||
|
||||
|
||||
def test_get_initial_oauth_headers(monkeypatch):
|
||||
post_calls: list[dict[str, Any]] = []
|
||||
|
||||
def _client_factory(*args, **kwargs):
|
||||
return _MockAsyncClient(
|
||||
payload={
|
||||
"access_token": "token-initial",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
},
|
||||
post_calls=post_calls,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
monkeypatch.setattr("httpx.AsyncClient", _client_factory)
|
||||
|
||||
config = ExtensionsConfig.model_validate(
|
||||
{
|
||||
"mcpServers": {
|
||||
"secure-http": {
|
||||
"enabled": True,
|
||||
"type": "http",
|
||||
"url": "https://api.example.com/mcp",
|
||||
"oauth": {
|
||||
"enabled": True,
|
||||
"token_url": "https://auth.example.com/oauth/token",
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": "client-id",
|
||||
"client_secret": "client-secret",
|
||||
},
|
||||
},
|
||||
"no-oauth": {
|
||||
"enabled": True,
|
||||
"type": "http",
|
||||
"url": "https://example.com/mcp",
|
||||
},
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
headers = asyncio.run(get_initial_oauth_headers(config))
|
||||
|
||||
assert headers == {"secure-http": "Bearer token-initial"}
|
||||
assert len(post_calls) == 1
|
||||
@@ -32,7 +32,17 @@
|
||||
"headers": {
|
||||
"Authorization": "Bearer $API_TOKEN",
|
||||
"X-Custom-Header": "value"
|
||||
}
|
||||
},
|
||||
"oauth": {
|
||||
"enabled": true,
|
||||
"token_url": "https://auth.example.com/oauth/token",
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": "$MCP_OAUTH_CLIENT_ID",
|
||||
"client_secret": "$MCP_OAUTH_CLIENT_SECRET",
|
||||
"scope": "mcp.read mcp.write",
|
||||
"audience": "https://api.example.com",
|
||||
"refresh_skew_seconds": 60
|
||||
}
|
||||
},
|
||||
"my-http-server": {
|
||||
"type": "http",
|
||||
@@ -40,7 +50,14 @@
|
||||
"headers": {
|
||||
"Authorization": "Bearer $API_TOKEN",
|
||||
"X-Custom-Header": "value"
|
||||
}
|
||||
},
|
||||
"oauth": {
|
||||
"enabled": true,
|
||||
"token_url": "https://auth.example.com/oauth/token",
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": "$MCP_OAUTH_CLIENT_ID",
|
||||
"client_secret": "$MCP_OAUTH_CLIENT_SECRET"
|
||||
}
|
||||
}
|
||||
},
|
||||
"skills": {
|
||||
|
||||
Reference in New Issue
Block a user