diff --git a/README.md b/README.md index 45a35af..59e4007 100644 --- a/README.md +++ b/README.md @@ -20,21 +20,31 @@ Learn more and see **real demos** on our official website. ## Table of Contents -- [Quick Start](#quick-start) -- [Sandbox Mode](#sandbox-mode) -- [From Deep Research to Super Agent Harness](#from-deep-research-to-super-agent-harness) -- [Core Features](#core-features) - - [Skills & Tools](#skills--tools) - - [Sub-Agents](#sub-agents) - - [Sandbox & File System](#sandbox--file-system) - - [Context Engineering](#context-engineering) - - [Long-Term Memory](#long-term-memory) -- [Recommended Models](#recommended-models) -- [Documentation](#documentation) -- [Contributing](#contributing) -- [License](#license) -- [Acknowledgments](#acknowledgments) -- [Star History](#star-history) +- [🦌 DeerFlow - 2.0](#-deerflow---20) + - [Official Website](#offiical-website) + - [Table of Contents](#table-of-contents) + - [Quick Start](#quick-start) + - [Configuration](#configuration) + - [Running the Application](#running-the-application) + - [Option 1: Docker (Recommended)](#option-1-docker-recommended) + - [Option 2: Local Development](#option-2-local-development) + - [Advanced](#advanced) + - [Sandbox Mode](#sandbox-mode) + - [MCP Server](#mcp-server) + - [From Deep Research to Super Agent Harness](#from-deep-research-to-super-agent-harness) + - [Core Features](#core-features) + - [Skills \& Tools](#skills--tools) + - [Sub-Agents](#sub-agents) + - [Sandbox \& File System](#sandbox--file-system) + - [Context Engineering](#context-engineering) + - [Long-Term Memory](#long-term-memory) + - [Recommended Models](#recommended-models) + - [Documentation](#documentation) + - [Contributing](#contributing) + - [License](#license) + - [Acknowledgments](#acknowledgments) + - [Key Contributors](#key-contributors) + - [Star History](#star-history) ## Quick Start @@ -155,6 +165,7 @@ See the [Sandbox Configuration Guide](backend/docs/CONFIGURATION.md#sandbox) to #### MCP Server 
DeerFlow supports configurable MCP servers and skills to extend its capabilities. +For HTTP/SSE MCP servers, OAuth token flows are supported (`client_credentials`, `refresh_token`). See the [MCP Server Guide](backend/docs/MCP_SERVER.md) for detailed instructions. ## From Deep Research to Super Agent Harness diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index 31b1728..25e0139 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -224,6 +224,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` → - **Lazy initialization**: Tools loaded on first use via `get_cached_mcp_tools()` - **Cache invalidation**: Detects config file changes via mtime comparison - **Transports**: stdio (command-based), SSE, HTTP +- **OAuth (HTTP/SSE)**: Supports token endpoint flows (`client_credentials`, `refresh_token`) with automatic token refresh + Authorization header injection - **Runtime updates**: Gateway API saves to extensions_config.json; LangGraph detects via mtime ### Skills System (`src/skills/`) @@ -287,7 +288,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` → - `memory` - Memory system (enabled, storage_path, debounce_seconds, model_name, max_facts, fact_confidence_threshold, injection_enabled, max_injection_tokens) **`extensions_config.json`**: -- `mcpServers` - Map of server name → config (enabled, type, command, args, env, url, headers, description) +- `mcpServers` - Map of server name → config (enabled, type, command, args, env, url, headers, oauth, description) - `skills` - Map of skill name → state (enabled) Both can be modified at runtime via Gateway API endpoints or `DeerFlowClient` methods. 
diff --git a/backend/README.md b/backend/README.md index d6c231f..1b96eb6 100644 --- a/backend/README.md +++ b/backend/README.md @@ -264,6 +264,18 @@ MCP servers and skill states in a single file: "command": "npx", "args": ["-y", "@modelcontextprotocol/server-github"], "env": {"GITHUB_TOKEN": "$GITHUB_TOKEN"} + }, + "secure-http": { + "enabled": true, + "type": "http", + "url": "https://api.example.com/mcp", + "oauth": { + "enabled": true, + "token_url": "https://auth.example.com/oauth/token", + "grant_type": "client_credentials", + "client_id": "$MCP_OAUTH_CLIENT_ID", + "client_secret": "$MCP_OAUTH_CLIENT_SECRET" + } } }, "skills": { diff --git a/backend/docs/API.md b/backend/docs/API.md index 358257d..6d2255a 100644 --- a/backend/docs/API.md +++ b/backend/docs/API.md @@ -503,6 +503,8 @@ All APIs return errors in a consistent format: Currently, DeerFlow does not implement authentication. All APIs are accessible without credentials. +Note: This is about DeerFlow API authentication. MCP outbound connections can still use OAuth for configured HTTP/SSE MCP servers. + For production deployments, it is recommended to: 1. Use Nginx for basic auth or OAuth integration 2. Deploy behind a VPN or private network diff --git a/backend/docs/MCP_SERVER.md b/backend/docs/MCP_SERVER.md index 4fbd727..efe2ea0 100644 --- a/backend/docs/MCP_SERVER.md +++ b/backend/docs/MCP_SERVER.md @@ -14,6 +14,37 @@ DeerFlow supports configurable MCP servers and skills to extend its capabilities 3. Configure each server’s command, arguments, and environment variables as needed. 4. Restart the application to load and register MCP tools. +## OAuth Support (HTTP/SSE MCP Servers) + +For `http` and `sse` MCP servers, DeerFlow supports OAuth token acquisition and automatic token refresh. 
+ +- Supported grants: `client_credentials`, `refresh_token` +- Configure per-server `oauth` block in `extensions_config.json` +- Secrets should be provided via environment variables (for example: `$MCP_OAUTH_CLIENT_SECRET`) + +Example: + +```json +{ + "mcpServers": { + "secure-http-server": { + "enabled": true, + "type": "http", + "url": "https://api.example.com/mcp", + "oauth": { + "enabled": true, + "token_url": "https://auth.example.com/oauth/token", + "grant_type": "client_credentials", + "client_id": "$MCP_OAUTH_CLIENT_ID", + "client_secret": "$MCP_OAUTH_CLIENT_SECRET", + "scope": "mcp.read", + "refresh_skew_seconds": 60 + } + } + } +} +``` + ## How It Works MCP servers expose tools that are automatically discovered and integrated into DeerFlow’s agent system at runtime. Once enabled, these tools become available to agents without additional code changes. diff --git a/backend/src/config/extensions_config.py b/backend/src/config/extensions_config.py index 8b8f3e4..be32925 100644 --- a/backend/src/config/extensions_config.py +++ b/backend/src/config/extensions_config.py @@ -3,11 +3,34 @@ import json import os from pathlib import Path -from typing import Any +from typing import Any, Literal from pydantic import BaseModel, ConfigDict, Field +class McpOAuthConfig(BaseModel): + """OAuth configuration for an MCP server (HTTP/SSE transports).""" + + enabled: bool = Field(default=True, description="Whether OAuth token injection is enabled") + token_url: str = Field(description="OAuth token endpoint URL") + grant_type: Literal["client_credentials", "refresh_token"] = Field( + default="client_credentials", + description="OAuth grant type", + ) + client_id: str | None = Field(default=None, description="OAuth client ID") + client_secret: str | None = Field(default=None, description="OAuth client secret") + refresh_token: str | None = Field(default=None, description="OAuth refresh token (for refresh_token grant)") + scope: str | None = Field(default=None, 
description="OAuth scope") + audience: str | None = Field(default=None, description="OAuth audience (provider-specific)") + token_field: str = Field(default="access_token", description="Field name containing access token in token response") + token_type_field: str = Field(default="token_type", description="Field name containing token type in token response") + expires_in_field: str = Field(default="expires_in", description="Field name containing expiry (seconds) in token response") + default_token_type: str = Field(default="Bearer", description="Default token type when missing in token response") + refresh_skew_seconds: int = Field(default=60, description="Refresh token this many seconds before expiry") + extra_token_params: dict[str, str] = Field(default_factory=dict, description="Additional form params sent to token endpoint") + model_config = ConfigDict(extra="allow") + + class McpServerConfig(BaseModel): """Configuration for a single MCP server.""" @@ -18,6 +41,7 @@ class McpServerConfig(BaseModel): env: dict[str, str] = Field(default_factory=dict, description="Environment variables for the MCP server") url: str | None = Field(default=None, description="URL of the MCP server (for sse or http type)") headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers to send (for sse or http type)") + oauth: McpOAuthConfig | None = Field(default=None, description="OAuth configuration (for sse or http type)") description: str = Field(default="", description="Human-readable description of what this MCP server provides") model_config = ConfigDict(extra="allow") diff --git a/backend/src/gateway/routers/mcp.py b/backend/src/gateway/routers/mcp.py index f53b47e..60efc2c 100644 --- a/backend/src/gateway/routers/mcp.py +++ b/backend/src/gateway/routers/mcp.py @@ -1,6 +1,7 @@ import json import logging from pathlib import Path +from typing import Literal from fastapi import APIRouter, HTTPException from pydantic import BaseModel, Field @@ -11,6 +12,25 @@ 
logger = logging.getLogger(__name__) router = APIRouter(prefix="/api", tags=["mcp"]) +class McpOAuthConfigResponse(BaseModel): + """OAuth configuration for an MCP server.""" + + enabled: bool = Field(default=True, description="Whether OAuth token injection is enabled") + token_url: str = Field(default="", description="OAuth token endpoint URL") + grant_type: Literal["client_credentials", "refresh_token"] = Field(default="client_credentials", description="OAuth grant type") + client_id: str | None = Field(default=None, description="OAuth client ID") + client_secret: str | None = Field(default=None, description="OAuth client secret") + refresh_token: str | None = Field(default=None, description="OAuth refresh token") + scope: str | None = Field(default=None, description="OAuth scope") + audience: str | None = Field(default=None, description="OAuth audience") + token_field: str = Field(default="access_token", description="Token response field containing access token") + token_type_field: str = Field(default="token_type", description="Token response field containing token type") + expires_in_field: str = Field(default="expires_in", description="Token response field containing expires-in seconds") + default_token_type: str = Field(default="Bearer", description="Default token type when response omits token_type") + refresh_skew_seconds: int = Field(default=60, description="Refresh this many seconds before expiry") + extra_token_params: dict[str, str] = Field(default_factory=dict, description="Additional form params sent to token endpoint") + + class McpServerConfigResponse(BaseModel): """Response model for MCP server configuration.""" @@ -21,6 +41,7 @@ class McpServerConfigResponse(BaseModel): env: dict[str, str] = Field(default_factory=dict, description="Environment variables for the MCP server") url: str | None = Field(default=None, description="URL of the MCP server (for sse or http type)") headers: dict[str, str] = Field(default_factory=dict, description="HTTP 
headers to send (for sse or http type)") + oauth: McpOAuthConfigResponse | None = Field(default=None, description="OAuth configuration for MCP HTTP/SSE servers") description: str = Field(default="", description="Human-readable description of what this MCP server provides") diff --git a/backend/src/mcp/oauth.py b/backend/src/mcp/oauth.py new file mode 100644 index 0000000..44d5a04 --- /dev/null +++ b/backend/src/mcp/oauth.py @@ -0,0 +1,150 @@ +"""OAuth token support for MCP HTTP/SSE servers.""" + +from __future__ import annotations + +import asyncio +import logging +from dataclasses import dataclass +from datetime import UTC, datetime, timedelta +from typing import Any + +from src.config.extensions_config import ExtensionsConfig, McpOAuthConfig + +logger = logging.getLogger(__name__) + + +@dataclass +class _OAuthToken: + """Cached OAuth token.""" + + access_token: str + token_type: str + expires_at: datetime + + +class OAuthTokenManager: + """Acquire/cache/refresh OAuth tokens for MCP servers.""" + + def __init__(self, oauth_by_server: dict[str, McpOAuthConfig]): + self._oauth_by_server = oauth_by_server + self._tokens: dict[str, _OAuthToken] = {} + self._locks: dict[str, asyncio.Lock] = {name: asyncio.Lock() for name in oauth_by_server} + + @classmethod + def from_extensions_config(cls, extensions_config: ExtensionsConfig) -> OAuthTokenManager: + oauth_by_server: dict[str, McpOAuthConfig] = {} + for server_name, server_config in extensions_config.get_enabled_mcp_servers().items(): + if server_config.oauth and server_config.oauth.enabled: + oauth_by_server[server_name] = server_config.oauth + return cls(oauth_by_server) + + def has_oauth_servers(self) -> bool: + return bool(self._oauth_by_server) + + def oauth_server_names(self) -> list[str]: + return list(self._oauth_by_server.keys()) + + async def get_authorization_header(self, server_name: str) -> str | None: + oauth = self._oauth_by_server.get(server_name) + if not oauth: + return None + + token = 
self._tokens.get(server_name) + if token and not self._is_expiring(token, oauth): + return f"{token.token_type} {token.access_token}" + + lock = self._locks[server_name] + async with lock: + token = self._tokens.get(server_name) + if token and not self._is_expiring(token, oauth): + return f"{token.token_type} {token.access_token}" + + fresh = await self._fetch_token(oauth) + self._tokens[server_name] = fresh + logger.info(f"Refreshed OAuth access token for MCP server: {server_name}") + return f"{fresh.token_type} {fresh.access_token}" + + @staticmethod + def _is_expiring(token: _OAuthToken, oauth: McpOAuthConfig) -> bool: + now = datetime.now(UTC) + return token.expires_at <= now + timedelta(seconds=max(oauth.refresh_skew_seconds, 0)) + + async def _fetch_token(self, oauth: McpOAuthConfig) -> _OAuthToken: + import httpx # pyright: ignore[reportMissingImports] + + data: dict[str, str] = { + "grant_type": oauth.grant_type, + **oauth.extra_token_params, + } + + if oauth.scope: + data["scope"] = oauth.scope + if oauth.audience: + data["audience"] = oauth.audience + + if oauth.grant_type == "client_credentials": + if not oauth.client_id or not oauth.client_secret: + raise ValueError("OAuth client_credentials requires client_id and client_secret") + data["client_id"] = oauth.client_id + data["client_secret"] = oauth.client_secret + elif oauth.grant_type == "refresh_token": + if not oauth.refresh_token: + raise ValueError("OAuth refresh_token grant requires refresh_token") + data["refresh_token"] = oauth.refresh_token + if oauth.client_id: + data["client_id"] = oauth.client_id + if oauth.client_secret: + data["client_secret"] = oauth.client_secret + else: + raise ValueError(f"Unsupported OAuth grant type: {oauth.grant_type}") + + async with httpx.AsyncClient(timeout=15.0) as client: + response = await client.post(oauth.token_url, data=data) + response.raise_for_status() + payload = response.json() + + access_token = payload.get(oauth.token_field) + if not access_token: + 
raise ValueError(f"OAuth token response missing '{oauth.token_field}'") + + token_type = str(payload.get(oauth.token_type_field, oauth.default_token_type) or oauth.default_token_type) + + expires_in_raw = payload.get(oauth.expires_in_field, 3600) + try: + expires_in = int(expires_in_raw) + except (TypeError, ValueError): + expires_in = 3600 + + expires_at = datetime.now(UTC) + timedelta(seconds=max(expires_in, 1)) + return _OAuthToken(access_token=access_token, token_type=token_type, expires_at=expires_at) + + +def build_oauth_tool_interceptor(extensions_config: ExtensionsConfig) -> Any | None: + """Build a tool interceptor that injects OAuth Authorization headers.""" + token_manager = OAuthTokenManager.from_extensions_config(extensions_config) + if not token_manager.has_oauth_servers(): + return None + + async def oauth_interceptor(request: Any, handler: Any) -> Any: + header = await token_manager.get_authorization_header(request.server_name) + if not header: + return await handler(request) + + updated_headers = dict(request.headers or {}) + updated_headers["Authorization"] = header + return await handler(request.override(headers=updated_headers)) + + return oauth_interceptor + + +async def get_initial_oauth_headers(extensions_config: ExtensionsConfig) -> dict[str, str]: + """Get initial OAuth Authorization headers for MCP server connections.""" + token_manager = OAuthTokenManager.from_extensions_config(extensions_config) + if not token_manager.has_oauth_servers(): + return {} + + headers: dict[str, str] = {} + for server_name in token_manager.oauth_server_names(): + headers[server_name] = await token_manager.get_authorization_header(server_name) or "" + + return {name: value for name, value in headers.items() if value} diff --git a/backend/src/mcp/tools.py b/backend/src/mcp/tools.py index 9f9889c..cb74029 100644 --- a/backend/src/mcp/tools.py +++ b/backend/src/mcp/tools.py @@ -6,6 +6,7 @@ from langchain_core.tools import BaseTool from src.config.extensions_config 
import ExtensionsConfig from src.mcp.client import build_servers_config +from src.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers logger = logging.getLogger(__name__) @@ -36,7 +37,23 @@ async def get_mcp_tools() -> list[BaseTool]: try: # Create the multi-server MCP client logger.info(f"Initializing MCP client with {len(servers_config)} server(s)") - client = MultiServerMCPClient(servers_config) + + # Inject initial OAuth headers for server connections (tool discovery/session init) + initial_oauth_headers = await get_initial_oauth_headers(extensions_config) + for server_name, auth_header in initial_oauth_headers.items(): + if server_name not in servers_config: + continue + if servers_config[server_name].get("transport") in ("sse", "http"): + existing_headers = dict(servers_config[server_name].get("headers", {})) + existing_headers["Authorization"] = auth_header + servers_config[server_name]["headers"] = existing_headers + + tool_interceptors = [] + oauth_interceptor = build_oauth_tool_interceptor(extensions_config) + if oauth_interceptor is not None: + tool_interceptors.append(oauth_interceptor) + + client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors) # Get all tools from all servers tools = await client.get_tools() diff --git a/backend/tests/test_mcp_oauth.py b/backend/tests/test_mcp_oauth.py new file mode 100644 index 0000000..b89ef03 --- /dev/null +++ b/backend/tests/test_mcp_oauth.py @@ -0,0 +1,191 @@ +"""Tests for MCP OAuth support.""" + +from __future__ import annotations + +import asyncio +from typing import Any + +from src.config.extensions_config import ExtensionsConfig +from src.mcp.oauth import OAuthTokenManager, build_oauth_tool_interceptor, get_initial_oauth_headers + + +class _MockResponse: + def __init__(self, payload: dict[str, Any]): + self._payload = payload + + def raise_for_status(self) -> None: + return None + + def json(self) -> dict[str, Any]: + return self._payload + + +class 
_MockAsyncClient: + def __init__(self, payload: dict[str, Any], post_calls: list[dict[str, Any]], **kwargs): + self._payload = payload + self._post_calls = post_calls + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def post(self, url: str, data: dict[str, Any]): + self._post_calls.append({"url": url, "data": data}) + return _MockResponse(self._payload) + + +def test_oauth_token_manager_fetches_and_caches_token(monkeypatch): + post_calls: list[dict[str, Any]] = [] + + def _client_factory(*args, **kwargs): + return _MockAsyncClient( + payload={ + "access_token": "token-123", + "token_type": "Bearer", + "expires_in": 3600, + }, + post_calls=post_calls, + **kwargs, + ) + + monkeypatch.setattr("httpx.AsyncClient", _client_factory) + + config = ExtensionsConfig.model_validate( + { + "mcpServers": { + "secure-http": { + "enabled": True, + "type": "http", + "url": "https://api.example.com/mcp", + "oauth": { + "enabled": True, + "token_url": "https://auth.example.com/oauth/token", + "grant_type": "client_credentials", + "client_id": "client-id", + "client_secret": "client-secret", + }, + } + } + } + ) + + manager = OAuthTokenManager.from_extensions_config(config) + + first = asyncio.run(manager.get_authorization_header("secure-http")) + second = asyncio.run(manager.get_authorization_header("secure-http")) + + assert first == "Bearer token-123" + assert second == "Bearer token-123" + assert len(post_calls) == 1 + assert post_calls[0]["url"] == "https://auth.example.com/oauth/token" + assert post_calls[0]["data"]["grant_type"] == "client_credentials" + + +def test_build_oauth_interceptor_injects_authorization_header(monkeypatch): + post_calls: list[dict[str, Any]] = [] + + def _client_factory(*args, **kwargs): + return _MockAsyncClient( + payload={ + "access_token": "token-abc", + "token_type": "Bearer", + "expires_in": 3600, + }, + post_calls=post_calls, + **kwargs, + ) + + 
monkeypatch.setattr("httpx.AsyncClient", _client_factory) + + config = ExtensionsConfig.model_validate( + { + "mcpServers": { + "secure-sse": { + "enabled": True, + "type": "sse", + "url": "https://api.example.com/mcp", + "oauth": { + "enabled": True, + "token_url": "https://auth.example.com/oauth/token", + "grant_type": "client_credentials", + "client_id": "client-id", + "client_secret": "client-secret", + }, + } + } + } + ) + + interceptor = build_oauth_tool_interceptor(config) + assert interceptor is not None + + class _Request: + def __init__(self): + self.server_name = "secure-sse" + self.headers = {"X-Test": "1"} + + def override(self, **kwargs): + updated = _Request() + updated.server_name = self.server_name + updated.headers = kwargs.get("headers") + return updated + + captured: dict[str, Any] = {} + + async def _handler(request): + captured["headers"] = request.headers + return "ok" + + result = asyncio.run(interceptor(_Request(), _handler)) + + assert result == "ok" + assert captured["headers"]["Authorization"] == "Bearer token-abc" + assert captured["headers"]["X-Test"] == "1" + + +def test_get_initial_oauth_headers(monkeypatch): + post_calls: list[dict[str, Any]] = [] + + def _client_factory(*args, **kwargs): + return _MockAsyncClient( + payload={ + "access_token": "token-initial", + "token_type": "Bearer", + "expires_in": 3600, + }, + post_calls=post_calls, + **kwargs, + ) + + monkeypatch.setattr("httpx.AsyncClient", _client_factory) + + config = ExtensionsConfig.model_validate( + { + "mcpServers": { + "secure-http": { + "enabled": True, + "type": "http", + "url": "https://api.example.com/mcp", + "oauth": { + "enabled": True, + "token_url": "https://auth.example.com/oauth/token", + "grant_type": "client_credentials", + "client_id": "client-id", + "client_secret": "client-secret", + }, + }, + "no-oauth": { + "enabled": True, + "type": "http", + "url": "https://example.com/mcp", + }, + } + } + ) + + headers = 
asyncio.run(get_initial_oauth_headers(config)) + + assert headers == {"secure-http": "Bearer token-initial"} + assert len(post_calls) == 1 diff --git a/extensions_config.example.json b/extensions_config.example.json index 567610b..833ef3b 100644 --- a/extensions_config.example.json +++ b/extensions_config.example.json @@ -32,7 +32,17 @@ "headers": { "Authorization": "Bearer $API_TOKEN", "X-Custom-Header": "value" - } + }, + "oauth": { + "enabled": true, + "token_url": "https://auth.example.com/oauth/token", + "grant_type": "client_credentials", + "client_id": "$MCP_OAUTH_CLIENT_ID", + "client_secret": "$MCP_OAUTH_CLIENT_SECRET", + "scope": "mcp.read mcp.write", + "audience": "https://api.example.com", + "refresh_skew_seconds": 60 + } }, "my-http-server": { "type": "http", @@ -40,7 +50,14 @@ "headers": { "Authorization": "Bearer $API_TOKEN", "X-Custom-Header": "value" - } + }, + "oauth": { + "enabled": true, + "token_url": "https://auth.example.com/oauth/token", + "grant_type": "client_credentials", + "client_id": "$MCP_OAUTH_CLIENT_ID", + "client_secret": "$MCP_OAUTH_CLIENT_SECRET" + } } }, "skills": {