mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-19 04:14:46 +08:00
feat(server): add MCP server configuration validation (#830)
* feat(server): add MCP server configuration validation Add comprehensive validation for MCP server configurations, inspired by Flowise's validateMCPServerConfig implementation. MCPServerConfig checks implemented: - Command allowlist validation (node, npx, python, docker, uvx, etc.) - Path traversal prevention (blocks ../, absolute paths, ~/) - Shell command injection prevention (blocks ; & | ` $ etc.) - Dangerous environment variable blocking (PATH, LD_PRELOAD, etc.) - URL validation for SSE/HTTP transports (scheme, credentials) - HTTP header injection prevention (blocks newlines) * fix the unit test error of test_chat_request * Added the related path cases as reviewer commented * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -3,7 +3,17 @@
|
|||||||
|
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
from src.server.mcp_validators import (
|
||||||
|
MCPValidationError,
|
||||||
|
validate_args_for_local_file_access,
|
||||||
|
validate_command,
|
||||||
|
validate_command_injection,
|
||||||
|
validate_environment_variables,
|
||||||
|
validate_headers,
|
||||||
|
validate_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MCPServerMetadataRequest(BaseModel):
|
class MCPServerMetadataRequest(BaseModel):
|
||||||
@@ -43,6 +53,62 @@ class MCPServerMetadataRequest(BaseModel):
|
|||||||
description="Optional SSE read timeout in seconds (for sse type, default: 30, range: 1-3600)"
|
description="Optional SSE read timeout in seconds (for sse type, default: 30, range: 1-3600)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_security(self) -> "MCPServerMetadataRequest":
|
||||||
|
"""Validate MCP server configuration for security issues."""
|
||||||
|
errors: List[str] = []
|
||||||
|
|
||||||
|
# Validate transport type
|
||||||
|
valid_transports = {"stdio", "sse", "streamable_http"}
|
||||||
|
if self.transport not in valid_transports:
|
||||||
|
errors.append(
|
||||||
|
f"Invalid transport type: {self.transport}. Must be one of: {', '.join(valid_transports)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate stdio-specific fields
|
||||||
|
if self.transport == "stdio":
|
||||||
|
if self.command:
|
||||||
|
try:
|
||||||
|
validate_command(self.command)
|
||||||
|
except MCPValidationError as e:
|
||||||
|
errors.append(e.message)
|
||||||
|
|
||||||
|
if self.args:
|
||||||
|
try:
|
||||||
|
validate_args_for_local_file_access(self.args)
|
||||||
|
except MCPValidationError as e:
|
||||||
|
errors.append(e.message)
|
||||||
|
|
||||||
|
try:
|
||||||
|
validate_command_injection(self.args)
|
||||||
|
except MCPValidationError as e:
|
||||||
|
errors.append(e.message)
|
||||||
|
|
||||||
|
if self.env:
|
||||||
|
try:
|
||||||
|
validate_environment_variables(self.env)
|
||||||
|
except MCPValidationError as e:
|
||||||
|
errors.append(e.message)
|
||||||
|
|
||||||
|
# Validate SSE/HTTP-specific fields
|
||||||
|
elif self.transport in ("sse", "streamable_http"):
|
||||||
|
if self.url:
|
||||||
|
try:
|
||||||
|
validate_url(self.url)
|
||||||
|
except MCPValidationError as e:
|
||||||
|
errors.append(e.message)
|
||||||
|
|
||||||
|
if self.headers:
|
||||||
|
try:
|
||||||
|
validate_headers(self.headers)
|
||||||
|
except MCPValidationError as e:
|
||||||
|
errors.append(e.message)
|
||||||
|
|
||||||
|
if errors:
|
||||||
|
raise ValueError("; ".join(errors))
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class MCPServerMetadataResponse(BaseModel):
|
class MCPServerMetadataResponse(BaseModel):
|
||||||
"""Response model for MCP server metadata."""
|
"""Response model for MCP server metadata."""
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ from mcp.client.sse import sse_client
|
|||||||
from mcp.client.stdio import stdio_client
|
from mcp.client.stdio import stdio_client
|
||||||
from mcp.client.streamable_http import streamablehttp_client
|
from mcp.client.streamable_http import streamablehttp_client
|
||||||
|
|
||||||
|
from src.server.mcp_validators import MCPValidationError, validate_mcp_server_config
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -75,6 +77,9 @@ async def load_mcp_tools(
|
|||||||
Raises:
|
Raises:
|
||||||
HTTPException: If there's an error loading the tools
|
HTTPException: If there's an error loading the tools
|
||||||
"""
|
"""
|
||||||
|
# MCP server configuration is validated at the request boundary (Pydantic model)
|
||||||
|
# to avoid duplicate validation logic here.
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if server_type == "stdio":
|
if server_type == "stdio":
|
||||||
if not command:
|
if not command:
|
||||||
|
|||||||
532
src/server/mcp_validators.py
Normal file
532
src/server/mcp_validators.py
Normal file
@@ -0,0 +1,532 @@
|
|||||||
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
"""
|
||||||
|
MCP Server Configuration Validators.
|
||||||
|
|
||||||
|
This module provides security validation for MCP server configurations,
|
||||||
|
inspired by Flowise's validateMCPServerConfig implementation. It prevents:
|
||||||
|
- Command injection attacks
|
||||||
|
- Path traversal attacks
|
||||||
|
- Unauthorized file access
|
||||||
|
- Dangerous environment variable modifications
|
||||||
|
|
||||||
|
Reference: https://github.com/FlowiseAI/Flowise/blob/main/packages/components/nodes/tools/MCP/core.ts
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPValidationError(Exception):
|
||||||
|
"""Exception raised when MCP server configuration validation fails."""
|
||||||
|
|
||||||
|
def __init__(self, message: str, field: Optional[str] = None):
|
||||||
|
self.message = message
|
||||||
|
self.field = field
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
# Allowed commands for stdio transport
|
||||||
|
# These are considered safe executable commands for MCP servers
|
||||||
|
ALLOWED_COMMANDS = frozenset([
|
||||||
|
"node",
|
||||||
|
"npx",
|
||||||
|
"python",
|
||||||
|
"python3",
|
||||||
|
"docker",
|
||||||
|
"uvx",
|
||||||
|
"uv",
|
||||||
|
"deno",
|
||||||
|
"bun",
|
||||||
|
])
|
||||||
|
|
||||||
|
# Dangerous environment variables that should not be modified
|
||||||
|
DANGEROUS_ENV_VARS = frozenset([
|
||||||
|
"PATH",
|
||||||
|
"LD_LIBRARY_PATH",
|
||||||
|
"DYLD_LIBRARY_PATH",
|
||||||
|
"LD_PRELOAD",
|
||||||
|
"DYLD_INSERT_LIBRARIES",
|
||||||
|
"PYTHONPATH",
|
||||||
|
"NODE_PATH",
|
||||||
|
"RUBYLIB",
|
||||||
|
"PERL5LIB",
|
||||||
|
])
|
||||||
|
|
||||||
|
# Shell metacharacters that could be used for injection
|
||||||
|
SHELL_METACHARACTERS = frozenset([
|
||||||
|
";",
|
||||||
|
"&",
|
||||||
|
"|",
|
||||||
|
"`",
|
||||||
|
"$",
|
||||||
|
"(",
|
||||||
|
")",
|
||||||
|
"{",
|
||||||
|
"}",
|
||||||
|
"[",
|
||||||
|
"]",
|
||||||
|
"<",
|
||||||
|
">",
|
||||||
|
"\n",
|
||||||
|
"\r",
|
||||||
|
])
|
||||||
|
|
||||||
|
# Dangerous file extensions that should not be directly accessed
|
||||||
|
DANGEROUS_EXTENSIONS = frozenset([
|
||||||
|
".exe",
|
||||||
|
".dll",
|
||||||
|
".so",
|
||||||
|
".dylib",
|
||||||
|
".bat",
|
||||||
|
".cmd",
|
||||||
|
".ps1",
|
||||||
|
".sh",
|
||||||
|
".bash",
|
||||||
|
".zsh",
|
||||||
|
".env",
|
||||||
|
".pem",
|
||||||
|
".key",
|
||||||
|
".crt",
|
||||||
|
".p12",
|
||||||
|
".pfx",
|
||||||
|
])
|
||||||
|
|
||||||
|
# Command chaining patterns
|
||||||
|
COMMAND_CHAINING_PATTERNS = [
|
||||||
|
"&&",
|
||||||
|
"||",
|
||||||
|
";;",
|
||||||
|
">>",
|
||||||
|
"<<",
|
||||||
|
"$(",
|
||||||
|
"<(",
|
||||||
|
">(",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Maximum argument length to prevent buffer overflow attacks
|
||||||
|
MAX_ARG_LENGTH = 1000
|
||||||
|
|
||||||
|
# Allowed URL schemes for SSE/HTTP transports
|
||||||
|
ALLOWED_URL_SCHEMES = frozenset(["http", "https"])
|
||||||
|
|
||||||
|
|
||||||
|
def validate_mcp_server_config(
|
||||||
|
transport: str,
|
||||||
|
command: Optional[str] = None,
|
||||||
|
args: Optional[List[str]] = None,
|
||||||
|
url: Optional[str] = None,
|
||||||
|
env: Optional[Dict[str, str]] = None,
|
||||||
|
headers: Optional[Dict[str, str]] = None,
|
||||||
|
strict: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Validate MCP server configuration for security issues.
|
||||||
|
|
||||||
|
This is the main entry point for MCP server validation. It orchestrates
|
||||||
|
all security checks based on the transport type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
transport: The type of MCP connection (stdio, sse, streamable_http)
|
||||||
|
command: The command to execute (for stdio transport)
|
||||||
|
args: Command arguments (for stdio transport)
|
||||||
|
url: The URL of the server (for sse/streamable_http transport)
|
||||||
|
env: Environment variables (for stdio transport)
|
||||||
|
headers: HTTP headers (for sse/streamable_http transport)
|
||||||
|
strict: If True, raise exceptions; if False, log warnings only
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPValidationError: If validation fails in strict mode
|
||||||
|
"""
|
||||||
|
errors: List[str] = []
|
||||||
|
|
||||||
|
# Validate transport type
|
||||||
|
valid_transports = {"stdio", "sse", "streamable_http"}
|
||||||
|
if transport not in valid_transports:
|
||||||
|
errors.append(f"Invalid transport type: {transport}. Must be one of: {', '.join(valid_transports)}")
|
||||||
|
|
||||||
|
# Transport-specific validation
|
||||||
|
if transport == "stdio":
|
||||||
|
# Validate command
|
||||||
|
if command:
|
||||||
|
try:
|
||||||
|
validate_command(command)
|
||||||
|
except MCPValidationError as e:
|
||||||
|
errors.append(e.message)
|
||||||
|
|
||||||
|
# Validate arguments
|
||||||
|
if args:
|
||||||
|
try:
|
||||||
|
validate_args_for_local_file_access(args)
|
||||||
|
except MCPValidationError as e:
|
||||||
|
errors.append(e.message)
|
||||||
|
|
||||||
|
try:
|
||||||
|
validate_command_injection(args)
|
||||||
|
except MCPValidationError as e:
|
||||||
|
errors.append(e.message)
|
||||||
|
|
||||||
|
# Validate environment variables
|
||||||
|
if env:
|
||||||
|
try:
|
||||||
|
validate_environment_variables(env)
|
||||||
|
except MCPValidationError as e:
|
||||||
|
errors.append(e.message)
|
||||||
|
|
||||||
|
elif transport in ("sse", "streamable_http"):
|
||||||
|
# Validate URL
|
||||||
|
if url:
|
||||||
|
try:
|
||||||
|
validate_url(url)
|
||||||
|
except MCPValidationError as e:
|
||||||
|
errors.append(e.message)
|
||||||
|
|
||||||
|
# Validate headers for injection
|
||||||
|
if headers:
|
||||||
|
try:
|
||||||
|
validate_headers(headers)
|
||||||
|
except MCPValidationError as e:
|
||||||
|
errors.append(e.message)
|
||||||
|
|
||||||
|
# Handle errors
|
||||||
|
if errors:
|
||||||
|
error_message = "; ".join(errors)
|
||||||
|
if strict:
|
||||||
|
raise MCPValidationError(error_message)
|
||||||
|
else:
|
||||||
|
logger.warning(f"MCP configuration validation warnings: {error_message}")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_command(command: str) -> None:
|
||||||
|
"""
|
||||||
|
Validate the command against an allowlist of safe executables.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command: The command to validate
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPValidationError: If the command is not in the allowlist
|
||||||
|
"""
|
||||||
|
if not command or not isinstance(command, str):
|
||||||
|
raise MCPValidationError("Command must be a non-empty string", field="command")
|
||||||
|
|
||||||
|
# Extract the base command (handle full paths)
|
||||||
|
# e.g., "/usr/bin/python3" -> "python3"
|
||||||
|
base_command = command.split("/")[-1].split("\\")[-1]
|
||||||
|
|
||||||
|
# Also handle .exe suffix on Windows
|
||||||
|
if base_command.endswith(".exe"):
|
||||||
|
base_command = base_command[:-4]
|
||||||
|
|
||||||
|
# Normalize to lowercase to handle case-insensitive filesystems (e.g., Windows)
|
||||||
|
normalized_command = base_command.lower()
|
||||||
|
|
||||||
|
if normalized_command not in ALLOWED_COMMANDS:
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Command '{command}' is not allowed. Allowed commands: {', '.join(sorted(ALLOWED_COMMANDS))}",
|
||||||
|
field="command",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_args_for_local_file_access(args: List[str]) -> None:
|
||||||
|
"""
|
||||||
|
Validate arguments to prevent path traversal and unauthorized file access.
|
||||||
|
|
||||||
|
Checks for:
|
||||||
|
- Absolute paths (starting with / or drive letters like C:)
|
||||||
|
- Directory traversal (../, ..\\)
|
||||||
|
- Local file access patterns (./, ~/)
|
||||||
|
- Dangerous file extensions
|
||||||
|
- Null bytes (security exploit)
|
||||||
|
- Excessively long arguments (buffer overflow protection)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: List of command arguments to validate
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPValidationError: If any argument contains dangerous patterns
|
||||||
|
"""
|
||||||
|
if not args:
|
||||||
|
return
|
||||||
|
|
||||||
|
for i, arg in enumerate(args):
|
||||||
|
if not isinstance(arg, str):
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Argument at index {i} must be a string, got {type(arg).__name__}",
|
||||||
|
field="args",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for excessively long arguments
|
||||||
|
if len(arg) > MAX_ARG_LENGTH:
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Argument at index {i} exceeds maximum length of {MAX_ARG_LENGTH} characters",
|
||||||
|
field="args",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for null bytes
|
||||||
|
if "\x00" in arg:
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Argument at index {i} contains null byte",
|
||||||
|
field="args",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for directory traversal
|
||||||
|
if ".." in arg:
|
||||||
|
# More specific check for actual traversal patterns
|
||||||
|
# Catches: "../", "..\", "/..", "\..", standalone "..", starts with "..", ends with ".."
|
||||||
|
if (
|
||||||
|
"../" in arg
|
||||||
|
or "..\\" in arg
|
||||||
|
or "/.." in arg
|
||||||
|
or "\\.." in arg
|
||||||
|
or arg == ".."
|
||||||
|
or arg.startswith("..")
|
||||||
|
or arg.endswith("..")
|
||||||
|
):
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Argument at index {i} contains directory traversal pattern: {arg[:50]}",
|
||||||
|
field="args",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for absolute paths (Unix-style)
|
||||||
|
# Be careful to allow flags like -f, --flag, etc. (e.g. "/-f").
|
||||||
|
# We reject all absolute Unix paths (including single-component ones like "/etc")
|
||||||
|
# to avoid access to potentially sensitive directories.
|
||||||
|
if arg.startswith("/") and not arg.startswith("/-"):
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Argument at index {i} contains absolute path: {arg[:50]}",
|
||||||
|
field="args",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for Windows absolute paths
|
||||||
|
if len(arg) >= 2 and arg[1] == ":" and arg[0].isalpha():
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Argument at index {i} contains Windows absolute path: {arg[:50]}",
|
||||||
|
field="args",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for home directory expansion
|
||||||
|
if arg.startswith("~/") or arg.startswith("~\\"):
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Argument at index {i} contains home directory reference: {arg[:50]}",
|
||||||
|
field="args",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for dangerous extensions in the argument
|
||||||
|
arg_lower = arg.lower()
|
||||||
|
for ext in DANGEROUS_EXTENSIONS:
|
||||||
|
if arg_lower.endswith(ext):
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Argument at index {i} references potentially dangerous file type: {ext}",
|
||||||
|
field="args",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_command_injection(args: List[str]) -> None:
|
||||||
|
"""
|
||||||
|
Validate arguments to prevent shell command injection.
|
||||||
|
|
||||||
|
Checks for:
|
||||||
|
- Shell metacharacters (; & | ` $ ( ) { } [ ] < > etc.)
|
||||||
|
- Command chaining patterns (&& || ;; etc.)
|
||||||
|
- Command substitution patterns ($() ``)
|
||||||
|
- Process substitution patterns (<() >())
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: List of command arguments to validate
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPValidationError: If any argument contains injection patterns
|
||||||
|
"""
|
||||||
|
if not args:
|
||||||
|
return
|
||||||
|
|
||||||
|
for i, arg in enumerate(args):
|
||||||
|
if not isinstance(arg, str):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check for shell metacharacters
|
||||||
|
for char in SHELL_METACHARACTERS:
|
||||||
|
if char in arg:
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Argument at index {i} contains shell metacharacter '{char}': {arg[:50]}",
|
||||||
|
field="args",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for command chaining patterns
|
||||||
|
for pattern in COMMAND_CHAINING_PATTERNS:
|
||||||
|
if pattern in arg:
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Argument at index {i} contains command chaining pattern '{pattern}': {arg[:50]}",
|
||||||
|
field="args",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_environment_variables(env: Dict[str, str]) -> None:
|
||||||
|
"""
|
||||||
|
Validate environment variables to prevent dangerous modifications.
|
||||||
|
|
||||||
|
Checks for:
|
||||||
|
- Modifications to PATH and library path variables
|
||||||
|
- Null bytes in values
|
||||||
|
- Excessively long values
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env: Dictionary of environment variables
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPValidationError: If any environment variable is dangerous
|
||||||
|
"""
|
||||||
|
if not env:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not isinstance(env, dict):
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Environment variables must be a dictionary, got {type(env).__name__}",
|
||||||
|
field="env",
|
||||||
|
)
|
||||||
|
|
||||||
|
for key, value in env.items():
|
||||||
|
# Validate key
|
||||||
|
if not isinstance(key, str):
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Environment variable key must be a string, got {type(key).__name__}",
|
||||||
|
field="env",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for dangerous environment variables
|
||||||
|
if key.upper() in DANGEROUS_ENV_VARS:
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Modification of environment variable '{key}' is not allowed for security reasons",
|
||||||
|
field="env",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate value
|
||||||
|
if not isinstance(value, str):
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Environment variable value for '{key}' must be a string, got {type(value).__name__}",
|
||||||
|
field="env",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for null bytes in value
|
||||||
|
if "\x00" in value:
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Environment variable '{key}' contains null byte",
|
||||||
|
field="env",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for excessively long values
|
||||||
|
if len(value) > MAX_ARG_LENGTH * 10: # Allow longer env values
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Environment variable '{key}' value exceeds maximum length",
|
||||||
|
field="env",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_url(url: str) -> None:
|
||||||
|
"""
|
||||||
|
Validate URL for SSE/HTTP transport.
|
||||||
|
|
||||||
|
Checks for:
|
||||||
|
- Valid URL format
|
||||||
|
- Allowed schemes (http, https)
|
||||||
|
- No credentials in URL
|
||||||
|
- No localhost/internal network access (optional, configurable)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The URL to validate
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPValidationError: If the URL is invalid or potentially dangerous
|
||||||
|
"""
|
||||||
|
if not url or not isinstance(url, str):
|
||||||
|
raise MCPValidationError("URL must be a non-empty string", field="url")
|
||||||
|
|
||||||
|
# Check for null bytes
|
||||||
|
if "\x00" in url:
|
||||||
|
raise MCPValidationError("URL contains null byte", field="url")
|
||||||
|
|
||||||
|
# Parse the URL
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
except Exception as e:
|
||||||
|
raise MCPValidationError(f"Invalid URL format: {e}", field="url")
|
||||||
|
|
||||||
|
# Check scheme
|
||||||
|
if parsed.scheme not in ALLOWED_URL_SCHEMES:
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"URL scheme '{parsed.scheme}' is not allowed. Allowed schemes: {', '.join(ALLOWED_URL_SCHEMES)}",
|
||||||
|
field="url",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for credentials in URL (security risk)
|
||||||
|
if parsed.username or parsed.password:
|
||||||
|
raise MCPValidationError(
|
||||||
|
"URL should not contain credentials. Use headers for authentication instead.",
|
||||||
|
field="url",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for valid host
|
||||||
|
if not parsed.netloc:
|
||||||
|
raise MCPValidationError("URL must have a valid host", field="url")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_headers(headers: Dict[str, str]) -> None:
|
||||||
|
"""
|
||||||
|
Validate HTTP headers for potential injection attacks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
headers: Dictionary of HTTP headers
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MCPValidationError: If any header contains dangerous patterns
|
||||||
|
"""
|
||||||
|
if not headers:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not isinstance(headers, dict):
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Headers must be a dictionary, got {type(headers).__name__}",
|
||||||
|
field="headers",
|
||||||
|
)
|
||||||
|
|
||||||
|
for key, value in headers.items():
|
||||||
|
# Validate key
|
||||||
|
if not isinstance(key, str):
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Header key must be a string, got {type(key).__name__}",
|
||||||
|
field="headers",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for newlines in header name (HTTP header injection)
|
||||||
|
if "\n" in key or "\r" in key:
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Header name '{key[:20]}' contains newline character (potential HTTP header injection)",
|
||||||
|
field="headers",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate value
|
||||||
|
if not isinstance(value, str):
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Header value for '{key}' must be a string, got {type(value).__name__}",
|
||||||
|
field="headers",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for newlines in header value (HTTP header injection)
|
||||||
|
if "\n" in value or "\r" in value:
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Header value for '{key}' contains newline character (potential HTTP header injection)",
|
||||||
|
field="headers",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for null bytes
|
||||||
|
if "\x00" in key or "\x00" in value:
|
||||||
|
raise MCPValidationError(
|
||||||
|
f"Header '{key}' contains null byte",
|
||||||
|
field="headers",
|
||||||
|
)
|
||||||
@@ -352,9 +352,9 @@ class TestMCPEndpoint:
|
|||||||
|
|
||||||
request_data = {
|
request_data = {
|
||||||
"transport": "stdio",
|
"transport": "stdio",
|
||||||
"command": "test_command",
|
"command": "node",
|
||||||
"args": ["arg1", "arg2"],
|
"args": ["server.js"],
|
||||||
"env": {"ENV_VAR": "value"},
|
"env": {"API_KEY": "test123"},
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post("/api/mcp/server/metadata", json=request_data)
|
response = client.post("/api/mcp/server/metadata", json=request_data)
|
||||||
@@ -362,7 +362,7 @@ class TestMCPEndpoint:
|
|||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
assert response_data["transport"] == "stdio"
|
assert response_data["transport"] == "stdio"
|
||||||
assert response_data["command"] == "test_command"
|
assert response_data["command"] == "node"
|
||||||
assert len(response_data["tools"]) == 1
|
assert len(response_data["tools"]) == 1
|
||||||
|
|
||||||
@patch("src.server.app.load_mcp_tools")
|
@patch("src.server.app.load_mcp_tools")
|
||||||
@@ -375,7 +375,7 @@ class TestMCPEndpoint:
|
|||||||
|
|
||||||
request_data = {
|
request_data = {
|
||||||
"transport": "stdio",
|
"transport": "stdio",
|
||||||
"command": "test_command",
|
"command": "node",
|
||||||
"timeout_seconds": 60,
|
"timeout_seconds": 60,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -424,9 +424,9 @@ class TestMCPEndpoint:
|
|||||||
|
|
||||||
request_data = {
|
request_data = {
|
||||||
"transport": "stdio",
|
"transport": "stdio",
|
||||||
"command": "test_command",
|
"command": "node",
|
||||||
"args": ["arg1", "arg2"],
|
"args": ["server.js"],
|
||||||
"env": {"ENV_VAR": "value"},
|
"env": {"API_KEY": "test123"},
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post("/api/mcp/server/metadata", json=request_data)
|
response = client.post("/api/mcp/server/metadata", json=request_data)
|
||||||
@@ -444,9 +444,9 @@ class TestMCPEndpoint:
|
|||||||
):
|
):
|
||||||
request_data = {
|
request_data = {
|
||||||
"transport": "stdio",
|
"transport": "stdio",
|
||||||
"command": "test_command",
|
"command": "node",
|
||||||
"args": ["arg1", "arg2"],
|
"args": ["server.js"],
|
||||||
"env": {"ENV_VAR": "value"},
|
"env": {"API_KEY": "test123"},
|
||||||
}
|
}
|
||||||
|
|
||||||
response = client.post("/api/mcp/server/metadata", json=request_data)
|
response = client.post("/api/mcp/server/metadata", json=request_data)
|
||||||
|
|||||||
@@ -163,6 +163,6 @@ async def test_load_mcp_tools_exception_handling(
|
|||||||
mock_stdio_client.return_value = MagicMock()
|
mock_stdio_client.return_value = MagicMock()
|
||||||
|
|
||||||
with pytest.raises(HTTPException) as exc:
|
with pytest.raises(HTTPException) as exc:
|
||||||
await mcp_utils.load_mcp_tools(server_type="stdio", command="foo") # Use await
|
await mcp_utils.load_mcp_tools(server_type="stdio", command="node") # Use await
|
||||||
assert exc.value.status_code == 500
|
assert exc.value.status_code == 500
|
||||||
assert "unexpected error" in exc.value.detail
|
assert "unexpected error" in exc.value.detail
|
||||||
|
|||||||
@@ -55,14 +55,14 @@ async def test_load_mcp_tools_stdio_success(
|
|||||||
|
|
||||||
result = await mcp_utils.load_mcp_tools(
|
result = await mcp_utils.load_mcp_tools(
|
||||||
server_type="stdio",
|
server_type="stdio",
|
||||||
command="echo",
|
command="node",
|
||||||
args=["foo"],
|
args=["server.js"],
|
||||||
env={"FOO": "BAR"},
|
env={"API_KEY": "test123"},
|
||||||
timeout_seconds=3,
|
timeout_seconds=3,
|
||||||
)
|
)
|
||||||
assert result == ["toolA"]
|
assert result == ["toolA"]
|
||||||
mock_StdioServerParameters.assert_called_once_with(
|
mock_StdioServerParameters.assert_called_once_with(
|
||||||
command="echo", args=["foo"], env={"FOO": "BAR"}
|
command="node", args=["server.js"], env={"API_KEY": "test123"}
|
||||||
)
|
)
|
||||||
mock_stdio_client.assert_called_once_with(params)
|
mock_stdio_client.assert_called_once_with(params)
|
||||||
mock_get_tools.assert_awaited_once_with(mock_client, 3)
|
mock_get_tools.assert_awaited_once_with(mock_client, 3)
|
||||||
@@ -165,7 +165,7 @@ async def test_load_mcp_tools_unsupported_type():
|
|||||||
with pytest.raises(HTTPException) as exc:
|
with pytest.raises(HTTPException) as exc:
|
||||||
await mcp_utils.load_mcp_tools(server_type="unknown")
|
await mcp_utils.load_mcp_tools(server_type="unknown")
|
||||||
assert exc.value.status_code == 400
|
assert exc.value.status_code == 400
|
||||||
assert "Unsupported server type" in exc.value.detail
|
assert "Invalid transport type" in exc.value.detail or "Unsupported server type" in exc.value.detail
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -180,6 +180,6 @@ async def test_load_mcp_tools_exception_handling(
|
|||||||
mock_stdio_client.return_value = MagicMock()
|
mock_stdio_client.return_value = MagicMock()
|
||||||
|
|
||||||
with pytest.raises(HTTPException) as exc:
|
with pytest.raises(HTTPException) as exc:
|
||||||
await mcp_utils.load_mcp_tools(server_type="stdio", command="foo")
|
await mcp_utils.load_mcp_tools(server_type="stdio", command="node")
|
||||||
assert exc.value.status_code == 500
|
assert exc.value.status_code == 500
|
||||||
assert "unexpected error" in exc.value.detail
|
assert "unexpected error" in exc.value.detail
|
||||||
|
|||||||
450
tests/unit/server/test_mcp_validators.py
Normal file
450
tests/unit/server/test_mcp_validators.py
Normal file
@@ -0,0 +1,450 @@
|
|||||||
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
"""
|
||||||
|
Unit tests for MCP server configuration validators.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- Command validation (allowlist)
|
||||||
|
- Argument validation (path traversal, command injection)
|
||||||
|
- Environment variable validation
|
||||||
|
- URL validation
|
||||||
|
- Header validation
|
||||||
|
- Full config validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.server.mcp_validators import (
|
||||||
|
ALLOWED_COMMANDS,
|
||||||
|
MCPValidationError,
|
||||||
|
validate_args_for_local_file_access,
|
||||||
|
validate_command,
|
||||||
|
validate_command_injection,
|
||||||
|
validate_environment_variables,
|
||||||
|
validate_headers,
|
||||||
|
validate_mcp_server_config,
|
||||||
|
validate_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateCommand:
|
||||||
|
"""Tests for validate_command function."""
|
||||||
|
|
||||||
|
def test_allowed_commands(self):
|
||||||
|
"""Test that all allowed commands pass validation."""
|
||||||
|
for cmd in ALLOWED_COMMANDS:
|
||||||
|
validate_command(cmd) # Should not raise
|
||||||
|
|
||||||
|
def test_allowed_command_with_path(self):
|
||||||
|
"""Test that commands with paths are validated by base name."""
|
||||||
|
validate_command("/usr/bin/python3")
|
||||||
|
validate_command("/usr/local/bin/node")
|
||||||
|
validate_command("C:\\Python\\python.exe")
|
||||||
|
|
||||||
|
def test_disallowed_command(self):
|
||||||
|
"""Test that disallowed commands raise an error."""
|
||||||
|
with pytest.raises(MCPValidationError) as exc_info:
|
||||||
|
validate_command("bash")
|
||||||
|
assert "not allowed" in exc_info.value.message
|
||||||
|
assert exc_info.value.field == "command"
|
||||||
|
|
||||||
|
def test_disallowed_dangerous_commands(self):
|
||||||
|
"""Test that dangerous commands are rejected."""
|
||||||
|
dangerous_commands = ["rm", "sudo", "chmod", "chown", "curl", "wget", "sh"]
|
||||||
|
for cmd in dangerous_commands:
|
||||||
|
with pytest.raises(MCPValidationError):
|
||||||
|
validate_command(cmd)
|
||||||
|
|
||||||
|
def test_empty_command(self):
|
||||||
|
"""Test that empty command raises an error."""
|
||||||
|
with pytest.raises(MCPValidationError):
|
||||||
|
validate_command("")
|
||||||
|
|
||||||
|
def test_none_command(self):
|
||||||
|
"""Test that None command raises an error."""
|
||||||
|
with pytest.raises(MCPValidationError):
|
||||||
|
validate_command(None)
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateArgsForLocalFileAccess:
|
||||||
|
"""Tests for validate_args_for_local_file_access function."""
|
||||||
|
|
||||||
|
def test_safe_args(self):
|
||||||
|
"""Test that safe arguments pass validation."""
|
||||||
|
safe_args = [
|
||||||
|
["--help"],
|
||||||
|
["-v", "--verbose"],
|
||||||
|
["package-name"],
|
||||||
|
["--config", "config.json"],
|
||||||
|
["run", "script.py"],
|
||||||
|
]
|
||||||
|
for args in safe_args:
|
||||||
|
validate_args_for_local_file_access(args) # Should not raise
|
||||||
|
|
||||||
|
def test_directory_traversal(self):
|
||||||
|
"""Test that directory traversal patterns are rejected."""
|
||||||
|
traversal_patterns = [
|
||||||
|
["../etc/passwd"],
|
||||||
|
["..\\windows\\system32"],
|
||||||
|
["../../secret"],
|
||||||
|
["foo/../bar/../../../etc/passwd"],
|
||||||
|
["foo/.."], # ".." at end after path separator
|
||||||
|
["bar\\.."], # ".." at end after Windows path separator
|
||||||
|
["path/to/foo/.."], # Longer path ending with ".."
|
||||||
|
]
|
||||||
|
for args in traversal_patterns:
|
||||||
|
with pytest.raises(MCPValidationError) as exc_info:
|
||||||
|
validate_args_for_local_file_access(args)
|
||||||
|
assert "traversal" in exc_info.value.message.lower()
|
||||||
|
|
||||||
|
def test_absolute_path_with_dangerous_extension(self):
|
||||||
|
"""Test that absolute paths with dangerous extensions are rejected."""
|
||||||
|
with pytest.raises(MCPValidationError):
|
||||||
|
validate_args_for_local_file_access(["/etc/passwd.sh"])
|
||||||
|
|
||||||
|
def test_windows_absolute_path(self):
|
||||||
|
"""Test that Windows absolute paths are rejected."""
|
||||||
|
with pytest.raises(MCPValidationError):
|
||||||
|
validate_args_for_local_file_access(["C:\\Windows\\system32"])
|
||||||
|
|
||||||
|
def test_home_directory_reference(self):
|
||||||
|
"""Test that home directory references are rejected."""
|
||||||
|
with pytest.raises(MCPValidationError):
|
||||||
|
validate_args_for_local_file_access(["~/secrets"])
|
||||||
|
|
||||||
|
with pytest.raises(MCPValidationError):
|
||||||
|
validate_args_for_local_file_access(["~\\secrets"])
|
||||||
|
|
||||||
|
def test_null_byte(self):
|
||||||
|
"""Test that null bytes in arguments are rejected."""
|
||||||
|
with pytest.raises(MCPValidationError) as exc_info:
|
||||||
|
validate_args_for_local_file_access(["file\x00.txt"])
|
||||||
|
assert "null byte" in exc_info.value.message.lower()
|
||||||
|
|
||||||
|
def test_excessively_long_argument(self):
|
||||||
|
"""Test that excessively long arguments are rejected."""
|
||||||
|
with pytest.raises(MCPValidationError) as exc_info:
|
||||||
|
validate_args_for_local_file_access(["a" * 1001])
|
||||||
|
assert "maximum length" in exc_info.value.message.lower()
|
||||||
|
|
||||||
|
def test_dangerous_extensions(self):
|
||||||
|
"""Test that dangerous file extensions are rejected."""
|
||||||
|
dangerous_files = [
|
||||||
|
["script.sh"],
|
||||||
|
["binary.exe"],
|
||||||
|
["library.dll"],
|
||||||
|
["secret.env"],
|
||||||
|
["key.pem"],
|
||||||
|
]
|
||||||
|
for args in dangerous_files:
|
||||||
|
with pytest.raises(MCPValidationError) as exc_info:
|
||||||
|
validate_args_for_local_file_access(args)
|
||||||
|
assert "dangerous file type" in exc_info.value.message.lower()
|
||||||
|
|
||||||
|
def test_empty_args(self):
|
||||||
|
"""Test that empty args list passes validation."""
|
||||||
|
validate_args_for_local_file_access([])
|
||||||
|
validate_args_for_local_file_access(None)
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateCommandInjection:
|
||||||
|
"""Tests for validate_command_injection function."""
|
||||||
|
|
||||||
|
def test_safe_args(self):
|
||||||
|
"""Test that safe arguments pass validation."""
|
||||||
|
safe_args = [
|
||||||
|
["--help"],
|
||||||
|
["package-name"],
|
||||||
|
["@scope/package"],
|
||||||
|
["file.json"],
|
||||||
|
]
|
||||||
|
for args in safe_args:
|
||||||
|
validate_command_injection(args) # Should not raise
|
||||||
|
|
||||||
|
def test_shell_metacharacters(self):
|
||||||
|
"""Test that shell metacharacters are rejected."""
|
||||||
|
metachar_args = [
|
||||||
|
["foo; rm -rf /"],
|
||||||
|
["foo & bar"],
|
||||||
|
["foo | cat /etc/passwd"],
|
||||||
|
["$(whoami)"],
|
||||||
|
["`id`"],
|
||||||
|
["foo > /etc/passwd"],
|
||||||
|
["foo < /etc/passwd"],
|
||||||
|
["${PATH}"],
|
||||||
|
]
|
||||||
|
for args in metachar_args:
|
||||||
|
with pytest.raises(MCPValidationError) as exc_info:
|
||||||
|
validate_command_injection(args)
|
||||||
|
assert "args" == exc_info.value.field
|
||||||
|
|
||||||
|
def test_command_chaining(self):
|
||||||
|
"""Test that command chaining patterns are rejected."""
|
||||||
|
chaining_args = [
|
||||||
|
["foo && bar"],
|
||||||
|
["foo || bar"],
|
||||||
|
["foo;; bar"],
|
||||||
|
["foo >> output"],
|
||||||
|
["foo << input"],
|
||||||
|
]
|
||||||
|
for args in chaining_args:
|
||||||
|
with pytest.raises(MCPValidationError):
|
||||||
|
validate_command_injection(args)
|
||||||
|
|
||||||
|
def test_backtick_injection(self):
|
||||||
|
"""Test that backtick command substitution is rejected."""
|
||||||
|
with pytest.raises(MCPValidationError):
|
||||||
|
validate_command_injection(["`whoami`"])
|
||||||
|
|
||||||
|
def test_process_substitution(self):
|
||||||
|
"""Test that process substitution is rejected."""
|
||||||
|
with pytest.raises(MCPValidationError):
|
||||||
|
validate_command_injection(["<(cat /etc/passwd)"])
|
||||||
|
|
||||||
|
with pytest.raises(MCPValidationError):
|
||||||
|
validate_command_injection([">(tee /tmp/out)"])
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateEnvironmentVariables:
|
||||||
|
"""Tests for validate_environment_variables function."""
|
||||||
|
|
||||||
|
def test_safe_env_vars(self):
|
||||||
|
"""Test that safe environment variables pass validation."""
|
||||||
|
safe_env = {
|
||||||
|
"API_KEY": "secret123",
|
||||||
|
"DEBUG": "true",
|
||||||
|
"MY_VARIABLE": "value",
|
||||||
|
}
|
||||||
|
validate_environment_variables(safe_env) # Should not raise
|
||||||
|
|
||||||
|
def test_dangerous_env_vars(self):
|
||||||
|
"""Test that dangerous environment variables are rejected."""
|
||||||
|
dangerous_vars = [
|
||||||
|
{"PATH": "/malicious/path"},
|
||||||
|
{"LD_LIBRARY_PATH": "/malicious/lib"},
|
||||||
|
{"DYLD_LIBRARY_PATH": "/malicious/lib"},
|
||||||
|
{"LD_PRELOAD": "/malicious/lib.so"},
|
||||||
|
{"PYTHONPATH": "/malicious/python"},
|
||||||
|
{"NODE_PATH": "/malicious/node"},
|
||||||
|
]
|
||||||
|
for env in dangerous_vars:
|
||||||
|
with pytest.raises(MCPValidationError) as exc_info:
|
||||||
|
validate_environment_variables(env)
|
||||||
|
assert "not allowed" in exc_info.value.message.lower()
|
||||||
|
|
||||||
|
def test_null_byte_in_value(self):
|
||||||
|
"""Test that null bytes in values are rejected."""
|
||||||
|
with pytest.raises(MCPValidationError):
|
||||||
|
validate_environment_variables({"KEY": "value\x00malicious"})
|
||||||
|
|
||||||
|
def test_empty_env(self):
|
||||||
|
"""Test that empty env dict passes validation."""
|
||||||
|
validate_environment_variables({})
|
||||||
|
validate_environment_variables(None)
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateUrl:
|
||||||
|
"""Tests for validate_url function."""
|
||||||
|
|
||||||
|
def test_valid_urls(self):
|
||||||
|
"""Test that valid URLs pass validation."""
|
||||||
|
valid_urls = [
|
||||||
|
"http://localhost:3000",
|
||||||
|
"https://api.example.com",
|
||||||
|
"http://192.168.1.1:8080/api",
|
||||||
|
"https://mcp.example.com/sse",
|
||||||
|
]
|
||||||
|
for url in valid_urls:
|
||||||
|
validate_url(url) # Should not raise
|
||||||
|
|
||||||
|
def test_invalid_scheme(self):
|
||||||
|
"""Test that invalid URL schemes are rejected."""
|
||||||
|
with pytest.raises(MCPValidationError) as exc_info:
|
||||||
|
validate_url("ftp://example.com")
|
||||||
|
assert "scheme" in exc_info.value.message.lower()
|
||||||
|
|
||||||
|
with pytest.raises(MCPValidationError):
|
||||||
|
validate_url("file:///etc/passwd")
|
||||||
|
|
||||||
|
def test_credentials_in_url(self):
|
||||||
|
"""Test that URLs with credentials are rejected."""
|
||||||
|
with pytest.raises(MCPValidationError) as exc_info:
|
||||||
|
validate_url("https://user:pass@example.com")
|
||||||
|
assert "credentials" in exc_info.value.message.lower()
|
||||||
|
|
||||||
|
def test_null_byte_in_url(self):
|
||||||
|
"""Test that null bytes in URL are rejected."""
|
||||||
|
with pytest.raises(MCPValidationError):
|
||||||
|
validate_url("https://example.com\x00/malicious")
|
||||||
|
|
||||||
|
def test_empty_url(self):
|
||||||
|
"""Test that empty URL raises an error."""
|
||||||
|
with pytest.raises(MCPValidationError):
|
||||||
|
validate_url("")
|
||||||
|
|
||||||
|
def test_no_host(self):
|
||||||
|
"""Test that URL without host raises an error."""
|
||||||
|
with pytest.raises(MCPValidationError):
|
||||||
|
validate_url("http:///path")
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateHeaders:
|
||||||
|
"""Tests for validate_headers function."""
|
||||||
|
|
||||||
|
def test_valid_headers(self):
|
||||||
|
"""Test that valid headers pass validation."""
|
||||||
|
valid_headers = {
|
||||||
|
"Authorization": "Bearer token123",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"X-Custom-Header": "value",
|
||||||
|
}
|
||||||
|
validate_headers(valid_headers) # Should not raise
|
||||||
|
|
||||||
|
def test_newline_in_header_name(self):
|
||||||
|
"""Test that newlines in header names are rejected (HTTP header injection)."""
|
||||||
|
with pytest.raises(MCPValidationError) as exc_info:
|
||||||
|
validate_headers({"X-Bad\nHeader": "value"})
|
||||||
|
assert "newline" in exc_info.value.message.lower()
|
||||||
|
|
||||||
|
def test_newline_in_header_value(self):
|
||||||
|
"""Test that newlines in header values are rejected (HTTP header injection)."""
|
||||||
|
with pytest.raises(MCPValidationError):
|
||||||
|
validate_headers({"X-Header": "value\r\nX-Injected: malicious"})
|
||||||
|
|
||||||
|
def test_null_byte_in_header(self):
|
||||||
|
"""Test that null bytes in headers are rejected."""
|
||||||
|
with pytest.raises(MCPValidationError):
|
||||||
|
validate_headers({"X-Header": "value\x00"})
|
||||||
|
|
||||||
|
def test_empty_headers(self):
|
||||||
|
"""Test that empty headers dict passes validation."""
|
||||||
|
validate_headers({})
|
||||||
|
validate_headers(None)
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateMCPServerConfig:
|
||||||
|
"""Tests for the main validate_mcp_server_config function."""
|
||||||
|
|
||||||
|
def test_valid_stdio_config(self):
|
||||||
|
"""Test valid stdio configuration."""
|
||||||
|
validate_mcp_server_config(
|
||||||
|
transport="stdio",
|
||||||
|
command="npx",
|
||||||
|
args=["@modelcontextprotocol/server-filesystem"],
|
||||||
|
env={"API_KEY": "secret"},
|
||||||
|
) # Should not raise
|
||||||
|
|
||||||
|
def test_valid_sse_config(self):
|
||||||
|
"""Test valid SSE configuration."""
|
||||||
|
validate_mcp_server_config(
|
||||||
|
transport="sse",
|
||||||
|
url="https://api.example.com/sse",
|
||||||
|
headers={"Authorization": "Bearer token"},
|
||||||
|
) # Should not raise
|
||||||
|
|
||||||
|
def test_valid_http_config(self):
|
||||||
|
"""Test valid streamable_http configuration."""
|
||||||
|
validate_mcp_server_config(
|
||||||
|
transport="streamable_http",
|
||||||
|
url="https://api.example.com/mcp",
|
||||||
|
) # Should not raise
|
||||||
|
|
||||||
|
def test_invalid_transport(self):
|
||||||
|
"""Test that invalid transport type raises an error."""
|
||||||
|
with pytest.raises(MCPValidationError) as exc_info:
|
||||||
|
validate_mcp_server_config(transport="invalid")
|
||||||
|
assert "Invalid transport type" in exc_info.value.message
|
||||||
|
|
||||||
|
def test_combined_validation_errors(self):
|
||||||
|
"""Test that multiple validation errors are combined."""
|
||||||
|
with pytest.raises(MCPValidationError) as exc_info:
|
||||||
|
validate_mcp_server_config(
|
||||||
|
transport="stdio",
|
||||||
|
command="bash", # Not allowed
|
||||||
|
args=["../etc/passwd"], # Path traversal
|
||||||
|
env={"PATH": "/malicious"}, # Dangerous env var
|
||||||
|
)
|
||||||
|
# All errors should be combined
|
||||||
|
assert "not allowed" in exc_info.value.message
|
||||||
|
assert "traversal" in exc_info.value.message.lower()
|
||||||
|
|
||||||
|
def test_non_strict_mode(self):
|
||||||
|
"""Test that non-strict mode logs warnings instead of raising."""
|
||||||
|
# Should not raise, but would log warnings
|
||||||
|
validate_mcp_server_config(
|
||||||
|
transport="stdio",
|
||||||
|
command="bash",
|
||||||
|
strict=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_stdio_with_dangerous_args(self):
|
||||||
|
"""Test stdio config with command injection attempt."""
|
||||||
|
with pytest.raises(MCPValidationError):
|
||||||
|
validate_mcp_server_config(
|
||||||
|
transport="stdio",
|
||||||
|
command="node",
|
||||||
|
args=["script.js; rm -rf /"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_sse_with_invalid_url(self):
|
||||||
|
"""Test SSE config with invalid URL."""
|
||||||
|
with pytest.raises(MCPValidationError):
|
||||||
|
validate_mcp_server_config(
|
||||||
|
transport="sse",
|
||||||
|
url="ftp://example.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPServerMetadataRequest:
|
||||||
|
"""Tests for Pydantic model validation."""
|
||||||
|
|
||||||
|
def test_valid_request(self):
|
||||||
|
"""Test that valid request passes validation."""
|
||||||
|
from src.server.mcp_request import MCPServerMetadataRequest
|
||||||
|
|
||||||
|
request = MCPServerMetadataRequest(
|
||||||
|
transport="stdio",
|
||||||
|
command="npx",
|
||||||
|
args=["@modelcontextprotocol/server-filesystem"],
|
||||||
|
)
|
||||||
|
assert request.transport == "stdio"
|
||||||
|
assert request.command == "npx"
|
||||||
|
|
||||||
|
def test_invalid_command_raises_validation_error(self):
|
||||||
|
"""Test that invalid command raises Pydantic ValidationError."""
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from src.server.mcp_request import MCPServerMetadataRequest
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
MCPServerMetadataRequest(
|
||||||
|
transport="stdio",
|
||||||
|
command="bash",
|
||||||
|
)
|
||||||
|
assert "not allowed" in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
def test_command_injection_raises_validation_error(self):
|
||||||
|
"""Test that command injection raises Pydantic ValidationError."""
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from src.server.mcp_request import MCPServerMetadataRequest
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
MCPServerMetadataRequest(
|
||||||
|
transport="stdio",
|
||||||
|
command="node",
|
||||||
|
args=["script.js; rm -rf /"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_invalid_url_raises_validation_error(self):
|
||||||
|
"""Test that invalid URL raises Pydantic ValidationError."""
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from src.server.mcp_request import MCPServerMetadataRequest
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
MCPServerMetadataRequest(
|
||||||
|
transport="sse",
|
||||||
|
url="ftp://example.com",
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user