mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 14:22:13 +08:00
feat(server): add MCP server configuration validation (#830)
* feat(server): add MCP server configuration validation Add comprehensive validation for MCP server configurations, inspired by Flowise's validateMCPServerConfig implementation. MCPServerConfig checks implemented: - Command allowlist validation (node, npx, python, docker, uvx, etc.) - Path traversal prevention (blocks ../, absolute paths, ~/) - Shell command injection prevention (blocks ; & | ` $ etc.) - Dangerous environment variable blocking (PATH, LD_PRELOAD, etc.) - URL validation for SSE/HTTP transports (scheme, credentials) - HTTP header injection prevention (blocks newlines) * fix the unit test error of test_chat_request * Added the related path cases as reviewer commented * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -3,7 +3,17 @@
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from src.server.mcp_validators import (
|
||||
MCPValidationError,
|
||||
validate_args_for_local_file_access,
|
||||
validate_command,
|
||||
validate_command_injection,
|
||||
validate_environment_variables,
|
||||
validate_headers,
|
||||
validate_url,
|
||||
)
|
||||
|
||||
|
||||
class MCPServerMetadataRequest(BaseModel):
|
||||
@@ -43,6 +53,62 @@ class MCPServerMetadataRequest(BaseModel):
|
||||
description="Optional SSE read timeout in seconds (for sse type, default: 30, range: 1-3600)"
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
def validate_security(self) -> "MCPServerMetadataRequest":
    """Validate the MCP server configuration for security issues.

    The transport-specific checks (command allowlist, argument checks,
    environment variables, URL and headers) are delegated to
    ``validate_mcp_server_config`` so that this model and the validators
    module share a single implementation instead of two copies that can
    drift apart.

    Returns:
        The model instance itself, unchanged, when every check passes.

    Raises:
        ValueError: If one or more checks fail. The message is the
            individual validation messages joined with "; ".
    """
    # Function-scope import: this file's module-level import only brings in
    # the individual validators, not the orchestrating entry point.
    from src.server.mcp_validators import validate_mcp_server_config

    try:
        validate_mcp_server_config(
            transport=self.transport,
            command=self.command,
            args=self.args,
            url=self.url,
            env=self.env,
            headers=self.headers,
            strict=True,  # collect every failure and raise rather than log
        )
    except MCPValidationError as e:
        # Re-raise as ValueError, the exception type this validator has
        # always raised for invalid configurations.
        raise ValueError(e.message) from e

    return self
|
||||
|
||||
|
||||
class MCPServerMetadataResponse(BaseModel):
|
||||
"""Response model for MCP server metadata."""
|
||||
|
||||
@@ -11,6 +11,8 @@ from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
from src.server.mcp_validators import MCPValidationError, validate_mcp_server_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -75,6 +77,9 @@ async def load_mcp_tools(
|
||||
Raises:
|
||||
HTTPException: If there's an error loading the tools
|
||||
"""
|
||||
# MCP server configuration is validated at the request boundary (Pydantic model)
|
||||
# to avoid duplicate validation logic here.
|
||||
|
||||
try:
|
||||
if server_type == "stdio":
|
||||
if not command:
|
||||
|
||||
532
src/server/mcp_validators.py
Normal file
532
src/server/mcp_validators.py
Normal file
@@ -0,0 +1,532 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
MCP Server Configuration Validators.
|
||||
|
||||
This module provides security validation for MCP server configurations,
|
||||
inspired by Flowise's validateMCPServerConfig implementation. It prevents:
|
||||
- Command injection attacks
|
||||
- Path traversal attacks
|
||||
- Unauthorized file access
|
||||
- Dangerous environment variable modifications
|
||||
|
||||
Reference: https://github.com/FlowiseAI/Flowise/blob/main/packages/components/nodes/tools/MCP/core.ts
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPValidationError(Exception):
    """Error signalling that an MCP server configuration failed validation.

    Attributes:
        message: Human-readable description of what was rejected.
        field: Name of the configuration field the failure relates to, or
            None when no single field applies.
    """

    def __init__(self, message: str, field: Optional[str] = None):
        # The message doubles as the standard exception argument so that
        # str(error) yields the same text as error.message.
        super().__init__(message)
        self.field = field
        self.message = message
|
||||
|
||||
|
||||
# Executables that a stdio-transport MCP server is permitted to launch.
# validate_command() compares the lowercased base name of the requested
# command against this set.
ALLOWED_COMMANDS = frozenset(
    {"node", "npx", "python", "python3", "docker", "uvx", "uv", "deno", "bun"}
)

# Environment variable names that a configuration may not set.
# validate_environment_variables() compares the upper-cased key against this set.
DANGEROUS_ENV_VARS = frozenset(
    {
        "PATH",
        "LD_LIBRARY_PATH",
        "DYLD_LIBRARY_PATH",
        "LD_PRELOAD",
        "DYLD_INSERT_LIBRARIES",
        "PYTHONPATH",
        "NODE_PATH",
        "RUBYLIB",
        "PERL5LIB",
    }
)

# Single characters treated as shell metacharacters when they appear
# anywhere inside an argument.
SHELL_METACHARACTERS = frozenset(
    {";", "&", "|", "`", "$", "(", ")", "{", "}", "[", "]", "<", ">", "\n", "\r"}
)

# File-name suffixes that arguments may not end with (matched case-insensitively
# by validate_args_for_local_file_access()).
DANGEROUS_EXTENSIONS = frozenset(
    {
        ".exe",
        ".dll",
        ".so",
        ".dylib",
        ".bat",
        ".cmd",
        ".ps1",
        ".sh",
        ".bash",
        ".zsh",
        ".env",
        ".pem",
        ".key",
        ".crt",
        ".p12",
        ".pfx",
    }
)

# Multi-character sequences used for command chaining, redirection and
# command/process substitution.
COMMAND_CHAINING_PATTERNS = ["&&", "||", ";;", ">>", "<<", "$(", "<(", ">("]

# Upper bound on the length of a single argument, in characters.
MAX_ARG_LENGTH = 1000

# URL schemes accepted for the SSE and streamable HTTP transports.
ALLOWED_URL_SCHEMES = frozenset({"http", "https"})
|
||||
|
||||
|
||||
def validate_mcp_server_config(
    transport: str,
    command: Optional[str] = None,
    args: Optional[List[str]] = None,
    url: Optional[str] = None,
    env: Optional[Dict[str, str]] = None,
    headers: Optional[Dict[str, str]] = None,
    strict: bool = True,
) -> None:
    """
    Validate MCP server configuration for security issues.

    This is the main entry point for MCP server validation. It runs every
    check that applies to the given transport, collects all failure
    messages, and reports them together.

    Args:
        transport: The type of MCP connection (stdio, sse, streamable_http)
        command: The command to execute (for stdio transport)
        args: Command arguments (for stdio transport)
        url: The URL of the server (for sse/streamable_http transport)
        env: Environment variables (for stdio transport)
        headers: HTTP headers (for sse/streamable_http transport)
        strict: If True, raise exceptions; if False, log warnings only

    Raises:
        MCPValidationError: If validation fails in strict mode. The message
            is the individual failure messages joined with "; ".
    """
    errors: List[str] = []

    # A tuple (rather than a set) keeps the order of the names in the error
    # message stable from run to run.
    valid_transports = ("stdio", "sse", "streamable_http")
    if transport not in valid_transports:
        errors.append(
            f"Invalid transport type: {transport}. Must be one of: {', '.join(valid_transports)}"
        )

    # Pair each validator with the value it inspects, per transport.
    if transport == "stdio":
        checks = [
            (validate_command, command),
            (validate_args_for_local_file_access, args),
            (validate_command_injection, args),
            (validate_environment_variables, env),
        ]
    elif transport in ("sse", "streamable_http"):
        checks = [
            (validate_url, url),
            (validate_headers, headers),
        ]
    else:
        checks = []

    for validator, value in checks:
        # Absent or empty values are skipped; only supplied values are checked.
        if not value:
            continue
        try:
            validator(value)
        except MCPValidationError as e:
            # Keep going so the caller sees every problem at once.
            errors.append(e.message)

    if not errors:
        return

    error_message = "; ".join(errors)
    if strict:
        raise MCPValidationError(error_message)
    logger.warning("MCP configuration validation warnings: %s", error_message)
|
||||
|
||||
|
||||
def validate_command(command: str) -> None:
    """
    Validate the command against an allowlist of safe executables.

    Only the final path component is compared, case-insensitively and with
    a trailing ".exe" removed, so "/usr/bin/python3" and
    "C:\\Python\\PYTHON.EXE" are both judged as "python3" / "python".

    Args:
        command: The command to validate

    Raises:
        MCPValidationError: If the command is empty, not a string, or its
            base name is not in the allowlist
    """
    if not command or not isinstance(command, str):
        raise MCPValidationError("Command must be a non-empty string", field="command")

    # Keep only the last component after either kind of path separator,
    # e.g. "/usr/bin/python3" -> "python3".
    base_command = command.split("/")[-1].split("\\")[-1]

    # Lowercase BEFORE stripping the Windows suffix so that "PYTHON.EXE" and
    # "Node.Exe" are normalized the same way as "python.exe".
    normalized_command = base_command.lower()
    if normalized_command.endswith(".exe"):
        normalized_command = normalized_command[: -len(".exe")]

    if normalized_command not in ALLOWED_COMMANDS:
        raise MCPValidationError(
            f"Command '{command}' is not allowed. Allowed commands: {', '.join(sorted(ALLOWED_COMMANDS))}",
            field="command",
        )
|
||||
|
||||
|
||||
def validate_args_for_local_file_access(args: List[str]) -> None:
    """
    Reject arguments that look like attempts to reach local files.

    Each argument is tested against these rules, in order; the first rule
    it breaks is reported:
    - must be a string
    - at most MAX_ARG_LENGTH characters
    - no null bytes
    - no directory traversal ("../", "..\\", "/..", "\\..", or a leading or
      trailing "..")
    - no Unix absolute path (a leading "/", except a leading "/-")
    - no Windows drive path (a letter followed by ":")
    - no home-directory prefix ("~/" or "~\\")
    - no suffix listed in DANGEROUS_EXTENSIONS (case-insensitive)

    Args:
        args: List of command arguments to validate

    Raises:
        MCPValidationError: If any argument breaks one of the rules
    """
    if not args:
        return

    for position, candidate in enumerate(args):
        problem = _describe_unsafe_arg(candidate)
        if problem is not None:
            raise MCPValidationError(f"Argument at index {position} {problem}", field="args")


def _describe_unsafe_arg(arg: object) -> Optional[str]:
    """Return the message tail for the first rule *arg* breaks, or None if it is acceptable."""
    if not isinstance(arg, str):
        return f"must be a string, got {type(arg).__name__}"

    if len(arg) > MAX_ARG_LENGTH:
        return f"exceeds maximum length of {MAX_ARG_LENGTH} characters"

    if "\x00" in arg:
        return "contains null byte"

    # Error messages echo at most the first 50 characters of the argument.
    preview = arg[:50]

    traversal_sequences = ("../", "..\\", "/..", "\\..")
    if (
        arg.startswith("..")
        or arg.endswith("..")
        or any(sequence in arg for sequence in traversal_sequences)
    ):
        return f"contains directory traversal pattern: {preview}"

    # A leading "/-" is let through so that slash-prefixed flags are not
    # mistaken for absolute paths.
    if arg.startswith("/") and not arg.startswith("/-"):
        return f"contains absolute path: {preview}"

    if len(arg) >= 2 and arg[1] == ":" and arg[0].isalpha():
        return f"contains Windows absolute path: {preview}"

    if arg.startswith(("~/", "~\\")):
        return f"contains home directory reference: {preview}"

    lowered = arg.lower()
    for ext in DANGEROUS_EXTENSIONS:
        if lowered.endswith(ext):
            return f"references potentially dangerous file type: {ext}"

    return None
|
||||
|
||||
|
||||
def validate_command_injection(args: List[str]) -> None:
    """
    Validate arguments to prevent shell command injection.

    Checks for:
    - Command chaining, redirection and substitution patterns
      (&& || ;; >> << $( <( >()
    - Shell metacharacters (; & | ` $ ( ) { } [ ] < > newline, carriage return)

    The multi-character patterns are tested first. Every one of them
    contains a single metacharacter, so testing single characters first
    would make the pattern check unreachable and its more specific message
    would never be reported. The set of rejected arguments is the same
    either way.

    Args:
        args: List of command arguments to validate

    Raises:
        MCPValidationError: If any argument contains injection patterns
    """
    if not args:
        return

    for i, arg in enumerate(args):
        # Non-string entries are skipped here; validate_args_for_local_file_access
        # is the check that rejects them.
        if not isinstance(arg, str):
            continue

        for pattern in COMMAND_CHAINING_PATTERNS:
            if pattern in arg:
                raise MCPValidationError(
                    f"Argument at index {i} contains command chaining pattern '{pattern}': {arg[:50]}",
                    field="args",
                )

        # Scan the argument left to right so the reported character is always
        # the first offending one, independent of set iteration order.
        offending = next((char for char in arg if char in SHELL_METACHARACTERS), None)
        if offending is not None:
            raise MCPValidationError(
                f"Argument at index {i} contains shell metacharacter '{offending}': {arg[:50]}",
                field="args",
            )
|
||||
|
||||
|
||||
def validate_environment_variables(env: Dict[str, str]) -> None:
    """
    Validate environment variables to prevent dangerous modifications.

    Checks for:
    - Non-string, empty, or malformed names (containing "=" or a null byte)
    - Modifications to PATH and library path variables (DANGEROUS_ENV_VARS,
      compared case-insensitively)
    - Non-string values
    - Null bytes in values
    - Excessively long values (more than MAX_ARG_LENGTH * 10 characters)

    Args:
        env: Dictionary of environment variables

    Raises:
        MCPValidationError: If any environment variable is dangerous or malformed
    """
    if not env:
        return

    if not isinstance(env, dict):
        raise MCPValidationError(
            f"Environment variables must be a dictionary, got {type(env).__name__}",
            field="env",
        )

    for key, value in env.items():
        # Validate key
        if not isinstance(key, str):
            raise MCPValidationError(
                f"Environment variable key must be a string, got {type(key).__name__}",
                field="env",
            )

        # A name that is empty or contains "=" or NUL is not a well-formed
        # environment variable name. It is rejected before the blocklist
        # lookup so that a padded name such as "PATH=x" cannot slip past it.
        if not key or "=" in key or "\x00" in key:
            raise MCPValidationError(
                f"Environment variable name {key[:20]!r} is invalid: "
                "names must be non-empty and must not contain '=' or null bytes",
                field="env",
            )

        # Check for dangerous environment variables
        if key.upper() in DANGEROUS_ENV_VARS:
            raise MCPValidationError(
                f"Modification of environment variable '{key}' is not allowed for security reasons",
                field="env",
            )

        # Validate value
        if not isinstance(value, str):
            raise MCPValidationError(
                f"Environment variable value for '{key}' must be a string, got {type(value).__name__}",
                field="env",
            )

        # Check for null bytes in value
        if "\x00" in value:
            raise MCPValidationError(
                f"Environment variable '{key}' contains null byte",
                field="env",
            )

        # Values are allowed to be ten times longer than a single argument.
        if len(value) > MAX_ARG_LENGTH * 10:
            raise MCPValidationError(
                f"Environment variable '{key}' value exceeds maximum length",
                field="env",
            )
|
||||
|
||||
|
||||
def validate_url(url: str) -> None:
    """
    Validate URL for SSE/HTTP transport.

    Checks for:
    - Non-empty string without null bytes
    - Parseable URL format
    - Allowed schemes (http, https)
    - No credentials in URL
    - A host component (a bare port such as "http://:3000" is rejected)

    This function does not restrict localhost or private-network hosts.

    Args:
        url: The URL to validate

    Raises:
        MCPValidationError: If the URL is invalid or potentially dangerous
    """
    if not url or not isinstance(url, str):
        raise MCPValidationError("URL must be a non-empty string", field="url")

    # Check for null bytes
    if "\x00" in url:
        raise MCPValidationError("URL contains null byte", field="url")

    # Parse the URL. urlparse signals malformed input (for example an
    # unbalanced IPv6 bracket) with ValueError.
    try:
        parsed = urlparse(url)
        hostname = parsed.hostname
    except ValueError as e:
        raise MCPValidationError(f"Invalid URL format: {e}", field="url") from e

    # Check scheme; the allowed list is sorted so the message is stable.
    if parsed.scheme not in ALLOWED_URL_SCHEMES:
        raise MCPValidationError(
            f"URL scheme '{parsed.scheme}' is not allowed. Allowed schemes: {', '.join(sorted(ALLOWED_URL_SCHEMES))}",
            field="url",
        )

    # Check for credentials in URL (security risk)
    if parsed.username or parsed.password:
        raise MCPValidationError(
            "URL should not contain credentials. Use headers for authentication instead.",
            field="url",
        )

    # Check for a real host. netloc alone is not enough: it is non-empty for
    # "http://:3000" even though no host name is present.
    if not hostname:
        raise MCPValidationError("URL must have a valid host", field="url")
|
||||
|
||||
|
||||
def validate_headers(headers: Dict[str, str]) -> None:
    """
    Reject HTTP headers that could be used for header injection.

    For every entry the name must be a string without line breaks, the
    value must be a string without line breaks, and neither may contain a
    null byte. The checks run in that order and the first failure is raised.

    Args:
        headers: Dictionary of HTTP headers

    Raises:
        MCPValidationError: If the mapping is not a dict or any header
            breaks one of the rules above
    """
    if not headers:
        return

    if not isinstance(headers, dict):
        raise MCPValidationError(
            f"Headers must be a dictionary, got {type(headers).__name__}",
            field="headers",
        )

    line_breaks = ("\n", "\r")

    for name, content in headers.items():
        # --- header name ---
        if not isinstance(name, str):
            raise MCPValidationError(
                f"Header key must be a string, got {type(name).__name__}",
                field="headers",
            )

        if any(brk in name for brk in line_breaks):
            # Only the first 20 characters of the name are echoed back.
            raise MCPValidationError(
                f"Header name '{name[:20]}' contains newline character (potential HTTP header injection)",
                field="headers",
            )

        # --- header value ---
        if not isinstance(content, str):
            raise MCPValidationError(
                f"Header value for '{name}' must be a string, got {type(content).__name__}",
                field="headers",
            )

        if any(brk in content for brk in line_breaks):
            raise MCPValidationError(
                f"Header value for '{name}' contains newline character (potential HTTP header injection)",
                field="headers",
            )

        # --- null bytes on either side ---
        if any("\x00" in part for part in (name, content)):
            raise MCPValidationError(
                f"Header '{name}' contains null byte",
                field="headers",
            )
|
||||
@@ -352,9 +352,9 @@ class TestMCPEndpoint:
|
||||
|
||||
request_data = {
|
||||
"transport": "stdio",
|
||||
"command": "test_command",
|
||||
"args": ["arg1", "arg2"],
|
||||
"env": {"ENV_VAR": "value"},
|
||||
"command": "node",
|
||||
"args": ["server.js"],
|
||||
"env": {"API_KEY": "test123"},
|
||||
}
|
||||
|
||||
response = client.post("/api/mcp/server/metadata", json=request_data)
|
||||
@@ -362,7 +362,7 @@ class TestMCPEndpoint:
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["transport"] == "stdio"
|
||||
assert response_data["command"] == "test_command"
|
||||
assert response_data["command"] == "node"
|
||||
assert len(response_data["tools"]) == 1
|
||||
|
||||
@patch("src.server.app.load_mcp_tools")
|
||||
@@ -375,7 +375,7 @@ class TestMCPEndpoint:
|
||||
|
||||
request_data = {
|
||||
"transport": "stdio",
|
||||
"command": "test_command",
|
||||
"command": "node",
|
||||
"timeout_seconds": 60,
|
||||
}
|
||||
|
||||
@@ -424,9 +424,9 @@ class TestMCPEndpoint:
|
||||
|
||||
request_data = {
|
||||
"transport": "stdio",
|
||||
"command": "test_command",
|
||||
"args": ["arg1", "arg2"],
|
||||
"env": {"ENV_VAR": "value"},
|
||||
"command": "node",
|
||||
"args": ["server.js"],
|
||||
"env": {"API_KEY": "test123"},
|
||||
}
|
||||
|
||||
response = client.post("/api/mcp/server/metadata", json=request_data)
|
||||
@@ -444,9 +444,9 @@ class TestMCPEndpoint:
|
||||
):
|
||||
request_data = {
|
||||
"transport": "stdio",
|
||||
"command": "test_command",
|
||||
"args": ["arg1", "arg2"],
|
||||
"env": {"ENV_VAR": "value"},
|
||||
"command": "node",
|
||||
"args": ["server.js"],
|
||||
"env": {"API_KEY": "test123"},
|
||||
}
|
||||
|
||||
response = client.post("/api/mcp/server/metadata", json=request_data)
|
||||
|
||||
@@ -163,6 +163,6 @@ async def test_load_mcp_tools_exception_handling(
|
||||
mock_stdio_client.return_value = MagicMock()
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await mcp_utils.load_mcp_tools(server_type="stdio", command="foo") # Use await
|
||||
await mcp_utils.load_mcp_tools(server_type="stdio", command="node") # Use await
|
||||
assert exc.value.status_code == 500
|
||||
assert "unexpected error" in exc.value.detail
|
||||
|
||||
@@ -55,14 +55,14 @@ async def test_load_mcp_tools_stdio_success(
|
||||
|
||||
result = await mcp_utils.load_mcp_tools(
|
||||
server_type="stdio",
|
||||
command="echo",
|
||||
args=["foo"],
|
||||
env={"FOO": "BAR"},
|
||||
command="node",
|
||||
args=["server.js"],
|
||||
env={"API_KEY": "test123"},
|
||||
timeout_seconds=3,
|
||||
)
|
||||
assert result == ["toolA"]
|
||||
mock_StdioServerParameters.assert_called_once_with(
|
||||
command="echo", args=["foo"], env={"FOO": "BAR"}
|
||||
command="node", args=["server.js"], env={"API_KEY": "test123"}
|
||||
)
|
||||
mock_stdio_client.assert_called_once_with(params)
|
||||
mock_get_tools.assert_awaited_once_with(mock_client, 3)
|
||||
@@ -165,7 +165,7 @@ async def test_load_mcp_tools_unsupported_type():
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await mcp_utils.load_mcp_tools(server_type="unknown")
|
||||
assert exc.value.status_code == 400
|
||||
assert "Unsupported server type" in exc.value.detail
|
||||
assert "Invalid transport type" in exc.value.detail or "Unsupported server type" in exc.value.detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -180,6 +180,6 @@ async def test_load_mcp_tools_exception_handling(
|
||||
mock_stdio_client.return_value = MagicMock()
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await mcp_utils.load_mcp_tools(server_type="stdio", command="foo")
|
||||
await mcp_utils.load_mcp_tools(server_type="stdio", command="node")
|
||||
assert exc.value.status_code == 500
|
||||
assert "unexpected error" in exc.value.detail
|
||||
|
||||
450
tests/unit/server/test_mcp_validators.py
Normal file
450
tests/unit/server/test_mcp_validators.py
Normal file
@@ -0,0 +1,450 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for MCP server configuration validators.
|
||||
|
||||
Tests cover:
|
||||
- Command validation (allowlist)
|
||||
- Argument validation (path traversal, command injection)
|
||||
- Environment variable validation
|
||||
- URL validation
|
||||
- Header validation
|
||||
- Full config validation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.server.mcp_validators import (
|
||||
ALLOWED_COMMANDS,
|
||||
MCPValidationError,
|
||||
validate_args_for_local_file_access,
|
||||
validate_command,
|
||||
validate_command_injection,
|
||||
validate_environment_variables,
|
||||
validate_headers,
|
||||
validate_mcp_server_config,
|
||||
validate_url,
|
||||
)
|
||||
|
||||
|
||||
class TestValidateCommand:
    """Exercise the allowlist enforcement performed by validate_command."""

    def test_allowed_commands(self):
        """Every allowlisted executable name is accepted."""
        for allowed in ALLOWED_COMMANDS:
            validate_command(allowed)  # must not raise

    def test_allowed_command_with_path(self):
        """A path-qualified command is judged by its final component."""
        validate_command("/usr/bin/python3")
        validate_command("/usr/local/bin/node")
        validate_command("C:\\Python\\python.exe")

    def test_disallowed_command(self):
        """A command outside the allowlist is rejected and tagged with its field."""
        with pytest.raises(MCPValidationError) as raised:
            validate_command("bash")

        failure = raised.value
        assert "not allowed" in failure.message
        assert failure.field == "command"

    def test_disallowed_dangerous_commands(self):
        """Well-known dangerous executables are all rejected."""
        for risky in ["rm", "sudo", "chmod", "chown", "curl", "wget", "sh"]:
            with pytest.raises(MCPValidationError):
                validate_command(risky)

    def test_empty_command(self):
        """An empty string is not a valid command."""
        with pytest.raises(MCPValidationError):
            validate_command("")

    def test_none_command(self):
        """None is not a valid command."""
        with pytest.raises(MCPValidationError):
            validate_command(None)
|
||||
|
||||
|
||||
class TestValidateArgsForLocalFileAccess:
    """Exercise the path and file-access rules of validate_args_for_local_file_access."""

    def test_safe_args(self):
        """Ordinary flags, package names and relative file names are accepted."""
        acceptable = [
            ["--help"],
            ["-v", "--verbose"],
            ["package-name"],
            ["--config", "config.json"],
            ["run", "script.py"],
        ]
        for arg_list in acceptable:
            validate_args_for_local_file_access(arg_list)  # must not raise

    def test_directory_traversal(self):
        """Any form of ".." path traversal is rejected with a traversal message."""
        rejected = [
            ["../etc/passwd"],
            ["..\\windows\\system32"],
            ["../../secret"],
            ["foo/../bar/../../../etc/passwd"],
            ["foo/.."],  # trailing ".." after a Unix separator
            ["bar\\.."],  # trailing ".." after a Windows separator
            ["path/to/foo/.."],  # longer path with a trailing ".."
        ]
        for arg_list in rejected:
            with pytest.raises(MCPValidationError) as raised:
                validate_args_for_local_file_access(arg_list)
            assert "traversal" in raised.value.message.lower()

    def test_absolute_path_with_dangerous_extension(self):
        """An absolute path that also has a dangerous extension is rejected."""
        with pytest.raises(MCPValidationError):
            validate_args_for_local_file_access(["/etc/passwd.sh"])

    def test_windows_absolute_path(self):
        """A Windows drive-letter path is rejected."""
        with pytest.raises(MCPValidationError):
            validate_args_for_local_file_access(["C:\\Windows\\system32"])

    def test_home_directory_reference(self):
        """Home-directory prefixes are rejected with either separator."""
        with pytest.raises(MCPValidationError):
            validate_args_for_local_file_access(["~/secrets"])

        with pytest.raises(MCPValidationError):
            validate_args_for_local_file_access(["~\\secrets"])

    def test_null_byte(self):
        """An embedded null byte is rejected and named in the message."""
        with pytest.raises(MCPValidationError) as raised:
            validate_args_for_local_file_access(["file\x00.txt"])
        assert "null byte" in raised.value.message.lower()

    def test_excessively_long_argument(self):
        """An argument one character over the limit is rejected."""
        with pytest.raises(MCPValidationError) as raised:
            validate_args_for_local_file_access(["a" * 1001])
        assert "maximum length" in raised.value.message.lower()

    def test_dangerous_extensions(self):
        """File names ending in a blocked extension are rejected."""
        rejected = [
            ["script.sh"],
            ["binary.exe"],
            ["library.dll"],
            ["secret.env"],
            ["key.pem"],
        ]
        for arg_list in rejected:
            with pytest.raises(MCPValidationError) as raised:
                validate_args_for_local_file_access(arg_list)
            assert "dangerous file type" in raised.value.message.lower()

    def test_empty_args(self):
        """An empty list and None are both accepted."""
        validate_args_for_local_file_access([])
        validate_args_for_local_file_access(None)
|
||||
|
||||
|
||||
class TestValidateCommandInjection:
    """Exercise the shell-injection rules of validate_command_injection."""

    def test_safe_args(self):
        """Flags, plain and scoped package names, and file names are accepted."""
        acceptable = [
            ["--help"],
            ["package-name"],
            ["@scope/package"],
            ["file.json"],
        ]
        for arg_list in acceptable:
            validate_command_injection(arg_list)  # must not raise

    def test_shell_metacharacters(self):
        """Arguments containing shell metacharacters are rejected on the args field."""
        rejected = [
            ["foo; rm -rf /"],
            ["foo & bar"],
            ["foo | cat /etc/passwd"],
            ["$(whoami)"],
            ["`id`"],
            ["foo > /etc/passwd"],
            ["foo < /etc/passwd"],
            ["${PATH}"],
        ]
        for arg_list in rejected:
            with pytest.raises(MCPValidationError) as raised:
                validate_command_injection(arg_list)
            assert "args" == raised.value.field

    def test_command_chaining(self):
        """Command chaining and redirection sequences are rejected."""
        rejected = [
            ["foo && bar"],
            ["foo || bar"],
            ["foo;; bar"],
            ["foo >> output"],
            ["foo << input"],
        ]
        for arg_list in rejected:
            with pytest.raises(MCPValidationError):
                validate_command_injection(arg_list)

    def test_backtick_injection(self):
        """Backtick command substitution is rejected."""
        with pytest.raises(MCPValidationError):
            validate_command_injection(["`whoami`"])

    def test_process_substitution(self):
        """Both input and output process substitution are rejected."""
        with pytest.raises(MCPValidationError):
            validate_command_injection(["<(cat /etc/passwd)"])

        with pytest.raises(MCPValidationError):
            validate_command_injection([">(tee /tmp/out)"])
|
||||
|
||||
|
||||
class TestValidateEnvironmentVariables:
    """Exercise ``validate_environment_variables`` with various env mappings."""

    def test_safe_env_vars(self):
        """Ordinary application variables are accepted."""
        harmless_env = dict(
            API_KEY="secret123",
            DEBUG="true",
            MY_VARIABLE="value",
        )
        # No exception is expected.
        validate_environment_variables(harmless_env)

    def test_dangerous_env_vars(self):
        """Loader/search-path variables raise with a "not allowed" message."""
        blocked_pairs = (
            ("PATH", "/malicious/path"),
            ("LD_LIBRARY_PATH", "/malicious/lib"),
            ("DYLD_LIBRARY_PATH", "/malicious/lib"),
            ("LD_PRELOAD", "/malicious/lib.so"),
            ("PYTHONPATH", "/malicious/python"),
            ("NODE_PATH", "/malicious/node"),
        )
        for var_name, var_value in blocked_pairs:
            # Each variable is submitted alone so a rejection is attributable.
            with pytest.raises(MCPValidationError) as caught:
                validate_environment_variables({var_name: var_value})
            assert "not allowed" in caught.value.message.lower()

    def test_null_byte_in_value(self):
        """A NUL byte embedded in a value raises."""
        tainted_env = {"KEY": "value\x00malicious"}
        with pytest.raises(MCPValidationError):
            validate_environment_variables(tainted_env)

    def test_empty_env(self):
        """Both an empty mapping and ``None`` are accepted without raising."""
        for absent_env in ({}, None):
            validate_environment_variables(absent_env)
|
||||
|
||||
|
||||
class TestValidateUrl:
    """Exercise ``validate_url`` with acceptable and unacceptable URLs."""

    def test_valid_urls(self):
        """http/https URLs with a host are accepted."""
        accepted = (
            "http://localhost:3000",
            "https://api.example.com",
            "http://192.168.1.1:8080/api",
            "https://mcp.example.com/sse",
        )
        for candidate in accepted:
            # No exception is expected.
            validate_url(candidate)

    def test_invalid_scheme(self):
        """Non-http(s) schemes raise; the ftp error message mentions "scheme"."""
        with pytest.raises(MCPValidationError) as caught:
            validate_url("ftp://example.com")
        assert "scheme" in caught.value.message.lower()

        # For file:// only the exception type is checked, not the message.
        with pytest.raises(MCPValidationError):
            validate_url("file:///etc/passwd")

    def test_credentials_in_url(self):
        """A URL with embedded user:password raises, mentioning "credentials"."""
        with pytest.raises(MCPValidationError) as caught:
            validate_url("https://user:pass@example.com")
        assert "credentials" in caught.value.message.lower()

    def test_null_byte_in_url(self):
        """A NUL byte inside the URL raises."""
        tainted_url = "https://example.com\x00/malicious"
        with pytest.raises(MCPValidationError):
            validate_url(tainted_url)

    def test_empty_url(self):
        """The empty string raises."""
        with pytest.raises(MCPValidationError):
            validate_url("")

    def test_no_host(self):
        """A URL with a scheme and path but no host raises."""
        hostless_url = "http:///path"
        with pytest.raises(MCPValidationError):
            validate_url(hostless_url)
|
||||
|
||||
|
||||
class TestValidateHeaders:
    """Exercise ``validate_headers`` with clean and tainted header mappings."""

    def test_valid_headers(self):
        """Typical request headers are accepted."""
        clean_headers = {
            "Authorization": "Bearer token123",
            "Content-Type": "application/json",
            "X-Custom-Header": "value",
        }
        # No exception is expected.
        validate_headers(clean_headers)

    def test_newline_in_header_name(self):
        """A newline in a header name raises, with "newline" in the message."""
        tainted = {"X-Bad\nHeader": "value"}
        with pytest.raises(MCPValidationError) as caught:
            validate_headers(tainted)
        assert "newline" in caught.value.message.lower()

    def test_newline_in_header_value(self):
        """A CRLF sequence in a header value (header injection attempt) raises."""
        tainted = {"X-Header": "value\r\nX-Injected: malicious"}
        with pytest.raises(MCPValidationError):
            validate_headers(tainted)

    def test_null_byte_in_header(self):
        """A NUL byte in a header value raises."""
        tainted = {"X-Header": "value\x00"}
        with pytest.raises(MCPValidationError):
            validate_headers(tainted)

    def test_empty_headers(self):
        """Both an empty mapping and ``None`` are accepted without raising."""
        for absent_headers in ({}, None):
            validate_headers(absent_headers)
|
||||
|
||||
|
||||
class TestValidateMCPServerConfig:
    """Exercise the top-level ``validate_mcp_server_config`` entry point."""

    def test_valid_stdio_config(self):
        """A stdio config with command, args and env is accepted."""
        stdio_config = {
            "transport": "stdio",
            "command": "npx",
            "args": ["@modelcontextprotocol/server-filesystem"],
            "env": {"API_KEY": "secret"},
        }
        # No exception is expected.
        validate_mcp_server_config(**stdio_config)

    def test_valid_sse_config(self):
        """An SSE config with an https URL and a header is accepted."""
        sse_config = {
            "transport": "sse",
            "url": "https://api.example.com/sse",
            "headers": {"Authorization": "Bearer token"},
        }
        # No exception is expected.
        validate_mcp_server_config(**sse_config)

    def test_valid_http_config(self):
        """A streamable_http config with an https URL is accepted."""
        # No exception is expected.
        validate_mcp_server_config(
            transport="streamable_http", url="https://api.example.com/mcp"
        )

    def test_invalid_transport(self):
        """An unknown transport raises with "Invalid transport type"."""
        with pytest.raises(MCPValidationError) as caught:
            validate_mcp_server_config(transport="invalid")
        assert "Invalid transport type" in caught.value.message

    def test_combined_validation_errors(self):
        """A config with several violations raises one error covering them.

        Only the "not allowed" and "traversal" fragments of the message are
        checked here.
        """
        multi_fault_config = {
            "transport": "stdio",
            "command": "bash",  # command that is not allowed
            "args": ["../etc/passwd"],  # path traversal
            "env": {"PATH": "/malicious"},  # dangerous env var
        }
        with pytest.raises(MCPValidationError) as caught:
            validate_mcp_server_config(**multi_fault_config)

        # The individual failures are expected to share one message.
        combined_message = caught.value.message
        assert "not allowed" in combined_message
        assert "traversal" in combined_message.lower()

    def test_non_strict_mode(self):
        """With ``strict=False`` a config using ``bash`` does not raise.

        NOTE(review): the original docstring says warnings are logged
        instead; this test does not inspect any log output.
        """
        validate_mcp_server_config(transport="stdio", command="bash", strict=False)

    def test_stdio_with_dangerous_args(self):
        """A stdio config whose args contain a command-injection attempt raises."""
        injected_config = {
            "transport": "stdio",
            "command": "node",
            "args": ["script.js; rm -rf /"],
        }
        with pytest.raises(MCPValidationError):
            validate_mcp_server_config(**injected_config)

    def test_sse_with_invalid_url(self):
        """An SSE config with an ftp:// URL raises."""
        with pytest.raises(MCPValidationError):
            validate_mcp_server_config(transport="sse", url="ftp://example.com")
|
||||
|
||||
|
||||
class TestMCPServerMetadataRequest:
    """Check validation when constructing ``MCPServerMetadataRequest``."""

    def test_valid_request(self):
        """A well-formed stdio request is constructed and keeps its fields."""
        from src.server.mcp_request import MCPServerMetadataRequest

        payload = {
            "transport": "stdio",
            "command": "npx",
            "args": ["@modelcontextprotocol/server-filesystem"],
        }
        request = MCPServerMetadataRequest(**payload)
        assert (request.transport, request.command) == ("stdio", "npx")

    def test_invalid_command_raises_validation_error(self):
        """``bash`` as command raises ``ValidationError`` saying "not allowed"."""
        from pydantic import ValidationError

        from src.server.mcp_request import MCPServerMetadataRequest

        with pytest.raises(ValidationError) as caught:
            MCPServerMetadataRequest(transport="stdio", command="bash")
        rendered = str(caught.value)
        assert "not allowed" in rendered.lower()

    def test_command_injection_raises_validation_error(self):
        """Args carrying a command-injection attempt raise ``ValidationError``."""
        from pydantic import ValidationError

        from src.server.mcp_request import MCPServerMetadataRequest

        payload = {
            "transport": "stdio",
            "command": "node",
            "args": ["script.js; rm -rf /"],
        }
        with pytest.raises(ValidationError):
            MCPServerMetadataRequest(**payload)

    def test_invalid_url_raises_validation_error(self):
        """An SSE request with an ftp:// URL raises ``ValidationError``."""
        from pydantic import ValidationError

        from src.server.mcp_request import MCPServerMetadataRequest

        with pytest.raises(ValidationError):
            MCPServerMetadataRequest(transport="sse", url="ftp://example.com")
|
||||
Reference in New Issue
Block a user