feat(server): add MCP server configuration validation (#830)

* feat(server): add MCP server configuration validation

Add comprehensive validation for MCP server configurations,
inspired by Flowise's validateMCPServerConfig implementation.

MCPServerConfig checks implemented:
- Command allowlist validation (node, npx, python, docker, uvx, etc.)
- Path traversal prevention (blocks ../, absolute paths, ~/)
- Shell command injection prevention (blocks ; & | ` $ etc.)
- Dangerous environment variable blocking (PATH, LD_PRELOAD, etc.)
- URL validation for SSE/HTTP transports (scheme, credentials)
- HTTP header injection prevention (blocks newlines)

* Fix the unit test error in test_chat_request

* Add the path-related cases requested by the reviewer

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Willem Jiang
2026-01-24 17:32:17 +08:00
committed by GitHub
parent c0849af37e
commit 612bddd3fb
7 changed files with 1072 additions and 19 deletions

View File

@@ -3,7 +3,17 @@
from typing import Dict, List, Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from src.server.mcp_validators import (
MCPValidationError,
validate_args_for_local_file_access,
validate_command,
validate_command_injection,
validate_environment_variables,
validate_headers,
validate_url,
)
class MCPServerMetadataRequest(BaseModel):
@@ -43,6 +53,62 @@ class MCPServerMetadataRequest(BaseModel):
description="Optional SSE read timeout in seconds (for sse type, default: 30, range: 1-3600)"
)
@model_validator(mode="after")
def validate_security(self) -> "MCPServerMetadataRequest":
    """Run the transport-specific security validators over this request.

    Every failing validator contributes its message; the messages are
    joined with "; " and raised together as one ``ValueError``.

    Returns:
        The model instance itself when no validator reports a problem.

    Raises:
        ValueError: If the transport is unknown or any validator fails.
    """
    problems: List[str] = []

    def run_check(validator, value) -> None:
        # Falsy (unset or empty) fields are skipped rather than validated.
        if not value:
            return
        try:
            validator(value)
        except MCPValidationError as exc:
            problems.append(exc.message)

    # Validate transport type
    valid_transports = {"stdio", "sse", "streamable_http"}
    if self.transport not in valid_transports:
        problems.append(
            f"Invalid transport type: {self.transport}. Must be one of: {', '.join(valid_transports)}"
        )

    if self.transport == "stdio":
        # Process-launch settings: the command, its arguments, its environment.
        run_check(validate_command, self.command)
        run_check(validate_args_for_local_file_access, self.args)
        run_check(validate_command_injection, self.args)
        run_check(validate_environment_variables, self.env)
    elif self.transport in ("sse", "streamable_http"):
        # Network settings: the endpoint URL and the extra HTTP headers.
        run_check(validate_url, self.url)
        run_check(validate_headers, self.headers)

    if problems:
        raise ValueError("; ".join(problems))
    return self
class MCPServerMetadataResponse(BaseModel):
"""Response model for MCP server metadata."""

View File

@@ -11,6 +11,8 @@ from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
from src.server.mcp_validators import MCPValidationError, validate_mcp_server_config
logger = logging.getLogger(__name__)
@@ -75,6 +77,9 @@ async def load_mcp_tools(
Raises:
HTTPException: If there's an error loading the tools
"""
# MCP server configuration is validated at the request boundary (Pydantic model)
# to avoid duplicate validation logic here.
try:
if server_type == "stdio":
if not command:

View File

@@ -0,0 +1,532 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
MCP Server Configuration Validators.
This module provides security validation for MCP server configurations,
inspired by Flowise's validateMCPServerConfig implementation. It prevents:
- Command injection attacks
- Path traversal attacks
- Unauthorized file access
- Dangerous environment variable modifications
Reference: https://github.com/FlowiseAI/Flowise/blob/main/packages/components/nodes/tools/MCP/core.ts
"""
import logging
from typing import Dict, List, Optional
from urllib.parse import urlparse
logger = logging.getLogger(__name__)
class MCPValidationError(Exception):
    """Signals that an MCP server configuration failed a security check.

    Attributes:
        message: Human-readable description of the failure.
        field: Name of the configuration field that failed, if known.
    """

    def __init__(self, message: str, field: Optional[str] = None):
        # Hand the message to Exception first so str(exc) returns it.
        super().__init__(message)
        self.field = field
        self.message = message
# Executables that may be launched for the stdio transport.
# validate_command() compares the lowercased base name of the requested
# command against this set and rejects anything that is not listed.
ALLOWED_COMMANDS = frozenset([
"node",
"npx",
"python",
"python3",
"docker",
"uvx",
"uv",
"deno",
"bun",
])
# Environment variables a request is not allowed to set.
# validate_environment_variables() upper-cases each key before looking it up
# here, so the match is case-insensitive.
DANGEROUS_ENV_VARS = frozenset([
"PATH",
"LD_LIBRARY_PATH",
"DYLD_LIBRARY_PATH",
"LD_PRELOAD",
"DYLD_INSERT_LIBRARIES",
"PYTHONPATH",
"NODE_PATH",
"RUBYLIB",
"PERL5LIB",
])
# Single characters that cause an argument to be rejected by
# validate_command_injection() when they appear anywhere in it.
SHELL_METACHARACTERS = frozenset([
";",
"&",
"|",
"`",
"$",
"(",
")",
"{",
"}",
"[",
"]",
"<",
">",
"\n",
"\r",
])
# File name suffixes that cause an argument to be rejected.
# validate_args_for_local_file_access() matches them with endswith() against
# the lowercased argument.
DANGEROUS_EXTENSIONS = frozenset([
".exe",
".dll",
".so",
".dylib",
".bat",
".cmd",
".ps1",
".sh",
".bash",
".zsh",
".env",
".pem",
".key",
".crt",
".p12",
".pfx",
])
# Multi-character shell operators (chaining, redirection, command and
# process substitution).
# NOTE: every pattern below contains at least one character that is also
# listed in SHELL_METACHARACTERS.
COMMAND_CHAINING_PATTERNS = [
"&&",
"||",
";;",
">>",
"<<",
"$(",
"<(",
">(",
]
# Maximum number of characters allowed in a single command argument.
# validate_environment_variables() allows ten times this length for a value.
MAX_ARG_LENGTH = 1000
# URL schemes accepted by validate_url() for the SSE/HTTP transports
ALLOWED_URL_SCHEMES = frozenset(["http", "https"])
def validate_mcp_server_config(
    transport: str,
    command: Optional[str] = None,
    args: Optional[List[str]] = None,
    url: Optional[str] = None,
    env: Optional[Dict[str, str]] = None,
    headers: Optional[Dict[str, str]] = None,
    strict: bool = True,
) -> None:
    """
    Validate MCP server configuration for security issues.

    This is the main entry point for MCP server validation. It runs every
    validator that applies to the given transport and collects all failure
    messages instead of stopping at the first one.

    Only values that are provided (truthy) are validated; this function does
    not enforce that a command or URL is present.

    Args:
        transport: The type of MCP connection (stdio, sse, streamable_http)
        command: The command to execute (for stdio transport)
        args: Command arguments (for stdio transport)
        url: The URL of the server (for sse/streamable_http transport)
        env: Environment variables (for stdio transport)
        headers: HTTP headers (for sse/streamable_http transport)
        strict: If True, raise exceptions; if False, log warnings only

    Raises:
        MCPValidationError: If validation fails in strict mode. The message
            is the "; "-joined list of all individual failures.
    """
    errors: List[str] = []

    def _collect(validator, value) -> None:
        """Run one validator and record its failure message, if any."""
        try:
            validator(value)
        except MCPValidationError as e:
            errors.append(e.message)

    # A sorted tuple (rather than a set) keeps the order of the names in the
    # error message identical from run to run.
    valid_transports = ("sse", "stdio", "streamable_http")
    if transport not in valid_transports:
        errors.append(f"Invalid transport type: {transport}. Must be one of: {', '.join(valid_transports)}")

    # Transport-specific validation
    if transport == "stdio":
        if command:
            _collect(validate_command, command)
        if args:
            # Both argument validators always run so that path problems and
            # injection problems are reported together.
            _collect(validate_args_for_local_file_access, args)
            _collect(validate_command_injection, args)
        if env:
            _collect(validate_environment_variables, env)
    elif transport in ("sse", "streamable_http"):
        if url:
            _collect(validate_url, url)
        if headers:
            _collect(validate_headers, headers)

    if not errors:
        return

    error_message = "; ".join(errors)
    if strict:
        raise MCPValidationError(error_message)
    # Non-strict mode: report but do not block. Lazy %-formatting defers
    # building the string until the record is actually emitted.
    logger.warning("MCP configuration validation warnings: %s", error_message)
def validate_command(command: str) -> None:
    """
    Validate the command against an allowlist of safe executables.

    Only the base name of the command is compared: any directory part
    (separated by "/" or "\\") and a trailing ".exe" are ignored, and the
    comparison is case-insensitive.

    Args:
        command: The command to validate

    Raises:
        MCPValidationError: If the command is empty, not a string, or its
            base name is not in the allowlist
    """
    if not command or not isinstance(command, str):
        raise MCPValidationError("Command must be a non-empty string", field="command")

    # Extract the base command (handle full paths)
    # e.g., "/usr/bin/python3" -> "python3"
    base_command = command.split("/")[-1].split("\\")[-1]

    # Lowercase BEFORE stripping the Windows suffix so that "PYTHON.EXE" and
    # "Node.Exe" are treated the same as "python.exe" and "node.exe".
    normalized_command = base_command.lower()
    if normalized_command.endswith(".exe"):
        normalized_command = normalized_command[: -len(".exe")]

    if normalized_command not in ALLOWED_COMMANDS:
        raise MCPValidationError(
            f"Command '{command}' is not allowed. Allowed commands: {', '.join(sorted(ALLOWED_COMMANDS))}",
            field="command",
        )
def validate_args_for_local_file_access(args: List[str]) -> None:
    """
    Validate arguments to prevent path traversal and unauthorized file access.

    An argument is rejected when it:
    - is not a string
    - is longer than MAX_ARG_LENGTH characters
    - contains a null byte
    - contains a directory traversal pattern (../, ..\\, /.., \\.., or a
      leading or trailing "..")
    - is an absolute Unix path (starts with "/", except the "/-" prefix)
    - is a Windows drive path (a letter followed by ":")
    - starts with a home directory reference (~/ or ~\\)
    - ends with one of DANGEROUS_EXTENSIONS (case-insensitive)

    The absolute-path, drive-path and home-directory checks are also applied
    to the value part of "-option=value" arguments, so "--config=/etc/passwd"
    is rejected just like a bare "/etc/passwd".

    Plain relative paths such as "./file" or "dir/file" are not rejected.

    Args:
        args: List of command arguments to validate

    Raises:
        MCPValidationError: If any argument contains dangerous patterns
    """
    if not args:
        return

    for i, arg in enumerate(args):
        if not isinstance(arg, str):
            raise MCPValidationError(
                f"Argument at index {i} must be a string, got {type(arg).__name__}",
                field="args",
            )

        # Check for excessively long arguments
        if len(arg) > MAX_ARG_LENGTH:
            raise MCPValidationError(
                f"Argument at index {i} exceeds maximum length of {MAX_ARG_LENGTH} characters",
                field="args",
            )

        # Check for null bytes
        if "\x00" in arg:
            raise MCPValidationError(
                f"Argument at index {i} contains null byte",
                field="args",
            )

        # Check for directory traversal. A bare ".." inside a longer token
        # (e.g. "a..b") is tolerated; ".." next to a path separator or at
        # either end of the argument is not.
        if ".." in arg and (
            arg.startswith("..")
            or arg.endswith("..")
            or any(marker in arg for marker in ("../", "..\\", "/..", "\\.."))
        ):
            raise MCPValidationError(
                f"Argument at index {i} contains directory traversal pattern: {arg[:50]}",
                field="args",
            )

        # The remaining path checks only look at how a string starts, so
        # they are run on the argument itself and, for "-option=value"
        # forms, on the value after the first "=" as well.
        candidates = [arg]
        if arg.startswith("-") and "=" in arg:
            candidates.append(arg.partition("=")[2])

        for candidate in candidates:
            # Absolute Unix paths are rejected, including single-component
            # ones like "/etc". The "/-" prefix is let through so that
            # flag-like tokens such as "/-f" are not mistaken for paths.
            if candidate.startswith("/") and not candidate.startswith("/-"):
                raise MCPValidationError(
                    f"Argument at index {i} contains absolute path: {arg[:50]}",
                    field="args",
                )

            # Check for Windows absolute paths (drive letter + colon)
            if len(candidate) >= 2 and candidate[1] == ":" and candidate[0].isalpha():
                raise MCPValidationError(
                    f"Argument at index {i} contains Windows absolute path: {arg[:50]}",
                    field="args",
                )

            # Check for home directory expansion
            if candidate.startswith(("~/", "~\\")):
                raise MCPValidationError(
                    f"Argument at index {i} contains home directory reference: {arg[:50]}",
                    field="args",
                )

        # Check for dangerous extensions in the argument
        arg_lower = arg.lower()
        for ext in DANGEROUS_EXTENSIONS:
            if arg_lower.endswith(ext):
                raise MCPValidationError(
                    f"Argument at index {i} references potentially dangerous file type: {ext}",
                    field="args",
                )
def validate_command_injection(args: List[str]) -> None:
    """
    Validate arguments to prevent shell command injection.

    Checks for:
    - Command chaining, redirection and substitution patterns
      (&& || ;; >> << $( <( >( )
    - Any remaining shell metacharacter (; & | ` $ ( ) { } [ ] < > and
      line breaks)

    Multi-character patterns are checked before single characters. Every
    pattern contains a metacharacter, so checking characters first would
    always fire first and the more specific pattern message could never be
    produced.

    Non-string items are skipped here; validate_args_for_local_file_access()
    is the validator that reports them.

    Args:
        args: List of command arguments to validate

    Raises:
        MCPValidationError: If any argument contains injection patterns
    """
    if not args:
        return

    for i, arg in enumerate(args):
        if not isinstance(arg, str):
            continue

        # Check for command chaining patterns (most specific message first)
        for pattern in COMMAND_CHAINING_PATTERNS:
            if pattern in arg:
                raise MCPValidationError(
                    f"Argument at index {i} contains command chaining pattern '{pattern}': {arg[:50]}",
                    field="args",
                )

        # Check for shell metacharacters. Iterating in sorted order makes the
        # reported character deterministic; !r keeps line breaks readable
        # ('\n' instead of a raw newline) and is identical to the quoted form
        # for every printable character in the set.
        for char in sorted(SHELL_METACHARACTERS):
            if char in arg:
                raise MCPValidationError(
                    f"Argument at index {i} contains shell metacharacter {char!r}: {arg[:50]}",
                    field="args",
                )
def validate_environment_variables(env: Dict[str, str]) -> None:
    """
    Reject environment settings that could alter how the server process runs.

    For every entry, in this order:
    - the name must be a string
    - the name (compared upper-cased) must not be in DANGEROUS_ENV_VARS
    - the value must be a string
    - the value must not contain a null byte
    - the value must not be longer than ten times MAX_ARG_LENGTH

    An empty or missing mapping is accepted without further checks.

    Args:
        env: Dictionary of environment variables

    Raises:
        MCPValidationError: On the first entry that violates a rule, or if
            ``env`` is a non-empty value that is not a dictionary
    """
    if not env:
        return

    def reject(reason: str) -> None:
        raise MCPValidationError(reason, field="env")

    if not isinstance(env, dict):
        reject(f"Environment variables must be a dictionary, got {type(env).__name__}")

    # Values may be longer than command arguments.
    value_limit = MAX_ARG_LENGTH * 10

    for name, content in env.items():
        if not isinstance(name, str):
            reject(f"Environment variable key must be a string, got {type(name).__name__}")

        if name.upper() in DANGEROUS_ENV_VARS:
            reject(f"Modification of environment variable '{name}' is not allowed for security reasons")

        if not isinstance(content, str):
            reject(
                f"Environment variable value for '{name}' must be a string, got {type(content).__name__}"
            )

        if "\x00" in content:
            reject(f"Environment variable '{name}' contains null byte")

        if len(content) > value_limit:
            reject(f"Environment variable '{name}' value exceeds maximum length")
def validate_url(url: str) -> None:
    """
    Validate URL for SSE/HTTP transport.

    Checks for:
    - Non-empty string without null bytes
    - Parseable URL
    - Allowed schemes (http, https)
    - No credentials in URL
    - A non-empty host name

    Localhost and private-network addresses are NOT rejected by this
    function.

    Args:
        url: The URL to validate

    Raises:
        MCPValidationError: If the URL is invalid or potentially dangerous
    """
    if not url or not isinstance(url, str):
        raise MCPValidationError("URL must be a non-empty string", field="url")

    # Check for null bytes
    if "\x00" in url:
        raise MCPValidationError("URL contains null byte", field="url")

    # Parse the URL; urlparse signals malformed input with ValueError.
    try:
        parsed = urlparse(url)
    except ValueError as e:
        # Keep the original error as the cause for easier debugging.
        raise MCPValidationError(f"Invalid URL format: {e}", field="url") from e

    # Check scheme (sorted so the message is the same on every run)
    if parsed.scheme not in ALLOWED_URL_SCHEMES:
        raise MCPValidationError(
            f"URL scheme '{parsed.scheme}' is not allowed. Allowed schemes: {', '.join(sorted(ALLOWED_URL_SCHEMES))}",
            field="url",
        )

    # Check for credentials in URL (security risk)
    if parsed.username or parsed.password:
        raise MCPValidationError(
            "URL should not contain credentials. Use headers for authentication instead.",
            field="url",
        )

    # Check for valid host. ``hostname`` is used instead of ``netloc``
    # because a netloc such as ":8080" is non-empty yet names no host.
    if not parsed.hostname:
        raise MCPValidationError("URL must have a valid host", field="url")
def validate_headers(headers: Dict[str, str]) -> None:
    """
    Reject HTTP headers that could be used for header injection.

    For every header, in this order:
    - the name must be a string without CR or LF
    - the value must be a string without CR or LF
    - neither name nor value may contain a null byte

    An empty or missing mapping is accepted without further checks.

    Args:
        headers: Dictionary of HTTP headers

    Raises:
        MCPValidationError: On the first header that violates a rule, or if
            ``headers`` is a non-empty value that is not a dictionary
    """
    if not headers:
        return

    def reject(reason: str) -> None:
        raise MCPValidationError(reason, field="headers")

    def has_line_break(text: str) -> bool:
        return "\n" in text or "\r" in text

    if not isinstance(headers, dict):
        reject(f"Headers must be a dictionary, got {type(headers).__name__}")

    for name, content in headers.items():
        if not isinstance(name, str):
            reject(f"Header key must be a string, got {type(name).__name__}")

        # A line break in the name would start a new header line.
        if has_line_break(name):
            reject(
                f"Header name '{name[:20]}' contains newline character (potential HTTP header injection)"
            )

        if not isinstance(content, str):
            reject(f"Header value for '{name}' must be a string, got {type(content).__name__}")

        # Same concern for the value.
        if has_line_break(content):
            reject(
                f"Header value for '{name}' contains newline character (potential HTTP header injection)"
            )

        if "\x00" in name or "\x00" in content:
            reject(f"Header '{name}' contains null byte")

View File

@@ -352,9 +352,9 @@ class TestMCPEndpoint:
request_data = {
"transport": "stdio",
"command": "test_command",
"args": ["arg1", "arg2"],
"env": {"ENV_VAR": "value"},
"command": "node",
"args": ["server.js"],
"env": {"API_KEY": "test123"},
}
response = client.post("/api/mcp/server/metadata", json=request_data)
@@ -362,7 +362,7 @@ class TestMCPEndpoint:
assert response.status_code == 200
response_data = response.json()
assert response_data["transport"] == "stdio"
assert response_data["command"] == "test_command"
assert response_data["command"] == "node"
assert len(response_data["tools"]) == 1
@patch("src.server.app.load_mcp_tools")
@@ -375,7 +375,7 @@ class TestMCPEndpoint:
request_data = {
"transport": "stdio",
"command": "test_command",
"command": "node",
"timeout_seconds": 60,
}
@@ -424,9 +424,9 @@ class TestMCPEndpoint:
request_data = {
"transport": "stdio",
"command": "test_command",
"args": ["arg1", "arg2"],
"env": {"ENV_VAR": "value"},
"command": "node",
"args": ["server.js"],
"env": {"API_KEY": "test123"},
}
response = client.post("/api/mcp/server/metadata", json=request_data)
@@ -444,9 +444,9 @@ class TestMCPEndpoint:
):
request_data = {
"transport": "stdio",
"command": "test_command",
"args": ["arg1", "arg2"],
"env": {"ENV_VAR": "value"},
"command": "node",
"args": ["server.js"],
"env": {"API_KEY": "test123"},
}
response = client.post("/api/mcp/server/metadata", json=request_data)

View File

@@ -163,6 +163,6 @@ async def test_load_mcp_tools_exception_handling(
mock_stdio_client.return_value = MagicMock()
with pytest.raises(HTTPException) as exc:
await mcp_utils.load_mcp_tools(server_type="stdio", command="foo") # Use await
await mcp_utils.load_mcp_tools(server_type="stdio", command="node") # Use await
assert exc.value.status_code == 500
assert "unexpected error" in exc.value.detail

View File

@@ -55,14 +55,14 @@ async def test_load_mcp_tools_stdio_success(
result = await mcp_utils.load_mcp_tools(
server_type="stdio",
command="echo",
args=["foo"],
env={"FOO": "BAR"},
command="node",
args=["server.js"],
env={"API_KEY": "test123"},
timeout_seconds=3,
)
assert result == ["toolA"]
mock_StdioServerParameters.assert_called_once_with(
command="echo", args=["foo"], env={"FOO": "BAR"}
command="node", args=["server.js"], env={"API_KEY": "test123"}
)
mock_stdio_client.assert_called_once_with(params)
mock_get_tools.assert_awaited_once_with(mock_client, 3)
@@ -165,7 +165,7 @@ async def test_load_mcp_tools_unsupported_type():
with pytest.raises(HTTPException) as exc:
await mcp_utils.load_mcp_tools(server_type="unknown")
assert exc.value.status_code == 400
assert "Unsupported server type" in exc.value.detail
assert "Invalid transport type" in exc.value.detail or "Unsupported server type" in exc.value.detail
@pytest.mark.asyncio
@@ -180,6 +180,6 @@ async def test_load_mcp_tools_exception_handling(
mock_stdio_client.return_value = MagicMock()
with pytest.raises(HTTPException) as exc:
await mcp_utils.load_mcp_tools(server_type="stdio", command="foo")
await mcp_utils.load_mcp_tools(server_type="stdio", command="node")
assert exc.value.status_code == 500
assert "unexpected error" in exc.value.detail

View File

@@ -0,0 +1,450 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Unit tests for MCP server configuration validators.
Tests cover:
- Command validation (allowlist)
- Argument validation (path traversal, command injection)
- Environment variable validation
- URL validation
- Header validation
- Full config validation
"""
import pytest
from src.server.mcp_validators import (
ALLOWED_COMMANDS,
MCPValidationError,
validate_args_for_local_file_access,
validate_command,
validate_command_injection,
validate_environment_variables,
validate_headers,
validate_mcp_server_config,
validate_url,
)
class TestValidateCommand:
    """Allowlist behaviour of validate_command."""

    @pytest.mark.parametrize("cmd", sorted(ALLOWED_COMMANDS))
    def test_allowed_commands(self, cmd):
        """Every allowlisted command is accepted."""
        validate_command(cmd)  # Should not raise

    @pytest.mark.parametrize(
        "cmd",
        [
            "/usr/bin/python3",
            "/usr/local/bin/node",
            "C:\\Python\\python.exe",
        ],
    )
    def test_allowed_command_with_path(self, cmd):
        """A command given with a path is judged by its base name."""
        validate_command(cmd)

    def test_disallowed_command(self):
        """A command outside the allowlist is rejected with field info."""
        with pytest.raises(MCPValidationError) as exc_info:
            validate_command("bash")
        failure = exc_info.value
        assert "not allowed" in failure.message
        assert failure.field == "command"

    @pytest.mark.parametrize(
        "cmd", ["rm", "sudo", "chmod", "chown", "curl", "wget", "sh"]
    )
    def test_disallowed_dangerous_commands(self, cmd):
        """Well-known dangerous commands are rejected."""
        with pytest.raises(MCPValidationError):
            validate_command(cmd)

    def test_empty_command(self):
        """The empty string is not a valid command."""
        with pytest.raises(MCPValidationError):
            validate_command("")

    def test_none_command(self):
        """None is not a valid command."""
        with pytest.raises(MCPValidationError):
            validate_command(None)
class TestValidateArgsForLocalFileAccess:
    """Path- and file-oriented checks of validate_args_for_local_file_access."""

    @pytest.mark.parametrize(
        "args",
        [
            ["--help"],
            ["-v", "--verbose"],
            ["package-name"],
            ["--config", "config.json"],
            ["run", "script.py"],
        ],
    )
    def test_safe_args(self, args):
        """Ordinary flags and file names are accepted."""
        validate_args_for_local_file_access(args)  # Should not raise

    @pytest.mark.parametrize(
        "args",
        [
            ["../etc/passwd"],
            ["..\\windows\\system32"],
            ["../../secret"],
            ["foo/../bar/../../../etc/passwd"],
            ["foo/.."],  # ".." at end after path separator
            ["bar\\.."],  # ".." at end after Windows path separator
            ["path/to/foo/.."],  # Longer path ending with ".."
        ],
    )
    def test_directory_traversal(self, args):
        """Directory traversal patterns are rejected."""
        with pytest.raises(MCPValidationError) as exc_info:
            validate_args_for_local_file_access(args)
        assert "traversal" in exc_info.value.message.lower()

    def test_absolute_path_with_dangerous_extension(self):
        """An absolute path with a dangerous extension is rejected."""
        with pytest.raises(MCPValidationError):
            validate_args_for_local_file_access(["/etc/passwd.sh"])

    def test_windows_absolute_path(self):
        """A Windows drive path is rejected."""
        with pytest.raises(MCPValidationError):
            validate_args_for_local_file_access(["C:\\Windows\\system32"])

    @pytest.mark.parametrize("arg", ["~/secrets", "~\\secrets"])
    def test_home_directory_reference(self, arg):
        """Home directory references are rejected with either separator."""
        with pytest.raises(MCPValidationError):
            validate_args_for_local_file_access([arg])

    def test_null_byte(self):
        """A null byte inside an argument is rejected."""
        with pytest.raises(MCPValidationError) as exc_info:
            validate_args_for_local_file_access(["file\x00.txt"])
        assert "null byte" in exc_info.value.message.lower()

    def test_excessively_long_argument(self):
        """An argument longer than the limit is rejected."""
        oversized = "a" * 1001
        with pytest.raises(MCPValidationError) as exc_info:
            validate_args_for_local_file_access([oversized])
        assert "maximum length" in exc_info.value.message.lower()

    @pytest.mark.parametrize(
        "args",
        [
            ["script.sh"],
            ["binary.exe"],
            ["library.dll"],
            ["secret.env"],
            ["key.pem"],
        ],
    )
    def test_dangerous_extensions(self, args):
        """File names with a dangerous extension are rejected."""
        with pytest.raises(MCPValidationError) as exc_info:
            validate_args_for_local_file_access(args)
        assert "dangerous file type" in exc_info.value.message.lower()

    @pytest.mark.parametrize("args", [[], None])
    def test_empty_args(self, args):
        """An empty or missing argument list is accepted."""
        validate_args_for_local_file_access(args)
class TestValidateCommandInjection:
    """Shell-injection checks of validate_command_injection."""

    @pytest.mark.parametrize(
        "args",
        [
            ["--help"],
            ["package-name"],
            ["@scope/package"],
            ["file.json"],
        ],
    )
    def test_safe_args(self, args):
        """Arguments without shell syntax are accepted."""
        validate_command_injection(args)  # Should not raise

    @pytest.mark.parametrize(
        "args",
        [
            ["foo; rm -rf /"],
            ["foo & bar"],
            ["foo | cat /etc/passwd"],
            ["$(whoami)"],
            ["`id`"],
            ["foo > /etc/passwd"],
            ["foo < /etc/passwd"],
            ["${PATH}"],
        ],
    )
    def test_shell_metacharacters(self, args):
        """Arguments containing shell metacharacters are rejected."""
        with pytest.raises(MCPValidationError) as exc_info:
            validate_command_injection(args)
        assert "args" == exc_info.value.field

    @pytest.mark.parametrize(
        "args",
        [
            ["foo && bar"],
            ["foo || bar"],
            ["foo;; bar"],
            ["foo >> output"],
            ["foo << input"],
        ],
    )
    def test_command_chaining(self, args):
        """Command chaining and redirection operators are rejected."""
        with pytest.raises(MCPValidationError):
            validate_command_injection(args)

    def test_backtick_injection(self):
        """Backtick command substitution is rejected."""
        with pytest.raises(MCPValidationError):
            validate_command_injection(["`whoami`"])

    @pytest.mark.parametrize("arg", ["<(cat /etc/passwd)", ">(tee /tmp/out)"])
    def test_process_substitution(self, arg):
        """Process substitution is rejected in both directions."""
        with pytest.raises(MCPValidationError):
            validate_command_injection([arg])
class TestValidateEnvironmentVariables:
    """Checks of validate_environment_variables."""

    def test_safe_env_vars(self):
        """Harmless variables are accepted."""
        harmless = dict(API_KEY="secret123", DEBUG="true", MY_VARIABLE="value")
        validate_environment_variables(harmless)  # Should not raise

    @pytest.mark.parametrize(
        "env",
        [
            {"PATH": "/malicious/path"},
            {"LD_LIBRARY_PATH": "/malicious/lib"},
            {"DYLD_LIBRARY_PATH": "/malicious/lib"},
            {"LD_PRELOAD": "/malicious/lib.so"},
            {"PYTHONPATH": "/malicious/python"},
            {"NODE_PATH": "/malicious/node"},
        ],
    )
    def test_dangerous_env_vars(self, env):
        """Variables that influence program or library loading are rejected."""
        with pytest.raises(MCPValidationError) as exc_info:
            validate_environment_variables(env)
        assert "not allowed" in exc_info.value.message.lower()

    def test_null_byte_in_value(self):
        """A null byte inside a value is rejected."""
        poisoned = {"KEY": "value\x00malicious"}
        with pytest.raises(MCPValidationError):
            validate_environment_variables(poisoned)

    @pytest.mark.parametrize("env", [{}, None])
    def test_empty_env(self, env):
        """An empty or missing mapping is accepted."""
        validate_environment_variables(env)
class TestValidateUrl:
    """Checks of validate_url."""

    @pytest.mark.parametrize(
        "url",
        [
            "http://localhost:3000",
            "https://api.example.com",
            "http://192.168.1.1:8080/api",
            "https://mcp.example.com/sse",
        ],
    )
    def test_valid_urls(self, url):
        """Well-formed http(s) URLs are accepted."""
        validate_url(url)  # Should not raise

    def test_invalid_scheme(self):
        """Schemes other than http and https are rejected."""
        with pytest.raises(MCPValidationError) as exc_info:
            validate_url("ftp://example.com")
        assert "scheme" in exc_info.value.message.lower()

        with pytest.raises(MCPValidationError):
            validate_url("file:///etc/passwd")

    def test_credentials_in_url(self):
        """A URL carrying user:password is rejected."""
        with pytest.raises(MCPValidationError) as exc_info:
            validate_url("https://user:pass@example.com")
        assert "credentials" in exc_info.value.message.lower()

    def test_null_byte_in_url(self):
        """A null byte inside the URL is rejected."""
        with pytest.raises(MCPValidationError):
            validate_url("https://example.com\x00/malicious")

    def test_empty_url(self):
        """The empty string is not a valid URL."""
        with pytest.raises(MCPValidationError):
            validate_url("")

    def test_no_host(self):
        """A URL without a host part is rejected."""
        with pytest.raises(MCPValidationError):
            validate_url("http:///path")
class TestValidateHeaders:
    """Checks of validate_headers."""

    def test_valid_headers(self):
        """Ordinary headers are accepted."""
        ordinary = {
            "Authorization": "Bearer token123",
            "Content-Type": "application/json",
            "X-Custom-Header": "value",
        }
        validate_headers(ordinary)  # Should not raise

    def test_newline_in_header_name(self):
        """A line break in a header name is rejected (HTTP header injection)."""
        with pytest.raises(MCPValidationError) as exc_info:
            validate_headers({"X-Bad\nHeader": "value"})
        assert "newline" in exc_info.value.message.lower()

    def test_newline_in_header_value(self):
        """A line break in a header value is rejected (HTTP header injection)."""
        injected = {"X-Header": "value\r\nX-Injected: malicious"}
        with pytest.raises(MCPValidationError):
            validate_headers(injected)

    def test_null_byte_in_header(self):
        """A null byte in a header is rejected."""
        with pytest.raises(MCPValidationError):
            validate_headers({"X-Header": "value\x00"})

    @pytest.mark.parametrize("headers", [{}, None])
    def test_empty_headers(self, headers):
        """An empty or missing mapping is accepted."""
        validate_headers(headers)
class TestValidateMCPServerConfig:
    """End-to-end checks of validate_mcp_server_config."""

    def test_valid_stdio_config(self):
        """A well-formed stdio configuration is accepted."""
        config = {
            "transport": "stdio",
            "command": "npx",
            "args": ["@modelcontextprotocol/server-filesystem"],
            "env": {"API_KEY": "secret"},
        }
        validate_mcp_server_config(**config)  # Should not raise

    def test_valid_sse_config(self):
        """A well-formed SSE configuration is accepted."""
        config = {
            "transport": "sse",
            "url": "https://api.example.com/sse",
            "headers": {"Authorization": "Bearer token"},
        }
        validate_mcp_server_config(**config)  # Should not raise

    def test_valid_http_config(self):
        """A well-formed streamable_http configuration is accepted."""
        config = {
            "transport": "streamable_http",
            "url": "https://api.example.com/mcp",
        }
        validate_mcp_server_config(**config)  # Should not raise

    def test_invalid_transport(self):
        """An unknown transport type is rejected."""
        with pytest.raises(MCPValidationError) as exc_info:
            validate_mcp_server_config(transport="invalid")
        assert "Invalid transport type" in exc_info.value.message

    def test_combined_validation_errors(self):
        """Failures from several validators end up in one message."""
        config = {
            "transport": "stdio",
            "command": "bash",  # Not allowed
            "args": ["../etc/passwd"],  # Path traversal
            "env": {"PATH": "/malicious"},  # Dangerous env var
        }
        with pytest.raises(MCPValidationError) as exc_info:
            validate_mcp_server_config(**config)
        # All errors should be combined
        combined = exc_info.value.message
        assert "not allowed" in combined
        assert "traversal" in combined.lower()

    def test_non_strict_mode(self):
        """Non-strict mode logs warnings instead of raising."""
        # Should not raise, but would log warnings
        validate_mcp_server_config(transport="stdio", command="bash", strict=False)

    def test_stdio_with_dangerous_args(self):
        """A command injection attempt in stdio args is rejected."""
        config = {
            "transport": "stdio",
            "command": "node",
            "args": ["script.js; rm -rf /"],
        }
        with pytest.raises(MCPValidationError):
            validate_mcp_server_config(**config)

    def test_sse_with_invalid_url(self):
        """An SSE configuration with a disallowed URL scheme is rejected."""
        with pytest.raises(MCPValidationError):
            validate_mcp_server_config(transport="sse", url="ftp://example.com")
class TestMCPServerMetadataRequest:
    """Validation behaviour of the MCPServerMetadataRequest Pydantic model."""

    def test_valid_request(self):
        """A well-formed request is constructed unchanged."""
        from src.server.mcp_request import MCPServerMetadataRequest

        fields = {
            "transport": "stdio",
            "command": "npx",
            "args": ["@modelcontextprotocol/server-filesystem"],
        }
        request = MCPServerMetadataRequest(**fields)
        assert request.transport == "stdio"
        assert request.command == "npx"

    def test_invalid_command_raises_validation_error(self):
        """A command outside the allowlist surfaces as a Pydantic ValidationError."""
        from pydantic import ValidationError

        from src.server.mcp_request import MCPServerMetadataRequest

        with pytest.raises(ValidationError) as exc_info:
            MCPServerMetadataRequest(transport="stdio", command="bash")
        rendered = str(exc_info.value)
        assert "not allowed" in rendered.lower()

    def test_command_injection_raises_validation_error(self):
        """A command injection attempt surfaces as a Pydantic ValidationError."""
        from pydantic import ValidationError

        from src.server.mcp_request import MCPServerMetadataRequest

        fields = {
            "transport": "stdio",
            "command": "node",
            "args": ["script.js; rm -rf /"],
        }
        with pytest.raises(ValidationError):
            MCPServerMetadataRequest(**fields)

    def test_invalid_url_raises_validation_error(self):
        """A disallowed URL scheme surfaces as a Pydantic ValidationError."""
        from pydantic import ValidationError

        from src.server.mcp_request import MCPServerMetadataRequest

        with pytest.raises(ValidationError):
            MCPServerMetadataRequest(transport="sse", url="ftp://example.com")