@model_validator(mode="after")
def validate_security(self) -> "MCPServerMetadataRequest":
    """Reject MCP server configurations that fail the security checks.

    Runs after field validation and delegates to
    ``validate_mcp_server_config`` so the request model and the standalone
    validator apply exactly the same rules instead of keeping two copies of
    the orchestration logic in sync by hand.

    Returns:
        The model instance itself, unchanged.

    Raises:
        ValueError: If one or more checks fail. The message is the list of
            individual failures joined with "; ", exactly as produced by
            ``validate_mcp_server_config`` in strict mode.
    """
    # Imported locally because the module-level import list only pulls in
    # the individual validators, not the orchestrating entry point.
    from src.server.mcp_validators import validate_mcp_server_config

    try:
        validate_mcp_server_config(
            transport=self.transport,
            command=self.command,
            args=self.args,
            url=self.url,
            env=self.env,
            headers=self.headers,
            strict=True,
        )
    except MCPValidationError as e:
        # Re-raise as ValueError, the exception type this validator has
        # always surfaced to the model layer.
        raise ValueError(e.message) from e

    return self
class MCPValidationError(Exception):
    """Error signalling that an MCP server configuration was rejected."""

    def __init__(self, message: str, field: Optional[str] = None):
        # The message doubles as the exception's string representation.
        super().__init__(message)
        # Name of the offending configuration field, when known.
        self.field = field
        # Human-readable description of the failure.
        self.message = message
".dylib", + ".bat", + ".cmd", + ".ps1", + ".sh", + ".bash", + ".zsh", + ".env", + ".pem", + ".key", + ".crt", + ".p12", + ".pfx", +]) + +# Command chaining patterns +COMMAND_CHAINING_PATTERNS = [ + "&&", + "||", + ";;", + ">>", + "<<", + "$(", + "<(", + ">(", +] + +# Maximum argument length to prevent buffer overflow attacks +MAX_ARG_LENGTH = 1000 + +# Allowed URL schemes for SSE/HTTP transports +ALLOWED_URL_SCHEMES = frozenset(["http", "https"]) + + +def validate_mcp_server_config( + transport: str, + command: Optional[str] = None, + args: Optional[List[str]] = None, + url: Optional[str] = None, + env: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + strict: bool = True, +) -> None: + """ + Validate MCP server configuration for security issues. + + This is the main entry point for MCP server validation. It orchestrates + all security checks based on the transport type. + + Args: + transport: The type of MCP connection (stdio, sse, streamable_http) + command: The command to execute (for stdio transport) + args: Command arguments (for stdio transport) + url: The URL of the server (for sse/streamable_http transport) + env: Environment variables (for stdio transport) + headers: HTTP headers (for sse/streamable_http transport) + strict: If True, raise exceptions; if False, log warnings only + + Raises: + MCPValidationError: If validation fails in strict mode + """ + errors: List[str] = [] + + # Validate transport type + valid_transports = {"stdio", "sse", "streamable_http"} + if transport not in valid_transports: + errors.append(f"Invalid transport type: {transport}. 
def validate_command(command: str) -> None:
    """
    Validate the command against an allowlist of safe executables.

    Only the final path component is compared (after lowercasing and
    dropping a trailing ".exe"), so "/usr/bin/python3" and
    "C:\\Python\\PYTHON.EXE" are both treated as their base command.
    NOTE: the directory part of a full path is not inspected, so any path
    whose file name is on the allowlist passes this check.

    Args:
        command: The command to validate

    Raises:
        MCPValidationError: If the command is empty, not a string, or its
            base name is not in the allowlist
    """
    if not command or not isinstance(command, str):
        raise MCPValidationError("Command must be a non-empty string", field="command")

    # Extract the base command (handle both "/" and "\" separated paths)
    # e.g., "/usr/bin/python3" -> "python3"
    base_command = command.split("/")[-1].split("\\")[-1]

    # Lowercase BEFORE stripping the suffix so that "PYTHON.EXE" or
    # "Node.Exe" on case-insensitive filesystems is recognised as well.
    normalized_command = base_command.lower()
    if normalized_command.endswith(".exe"):
        normalized_command = normalized_command[: -len(".exe")]

    if normalized_command not in ALLOWED_COMMANDS:
        raise MCPValidationError(
            f"Command '{command}' is not allowed. Allowed commands: {', '.join(sorted(ALLOWED_COMMANDS))}",
            field="command",
        )
def validate_command_injection(args: List[str]) -> None:
    """
    Validate arguments to prevent shell command injection.

    Checks, in this order, for:
    - Command chaining / substitution patterns (&& || ;; >> << $( <( >()
    - Single shell metacharacters (; & | ` $ ( ) { } [ ] < > and newlines)

    The multi-character patterns are tested first: each of them contains a
    single metacharacter, so testing the single characters first would make
    the more specific "command chaining pattern" error unreachable.

    Non-string entries are skipped here; in the validation flows visible in
    this module they are reported by validate_args_for_local_file_access.

    Args:
        args: List of command arguments to validate

    Raises:
        MCPValidationError: If any argument contains injection patterns
    """
    if not args:
        return

    for i, arg in enumerate(args):
        if not isinstance(arg, str):
            continue

        # Check for command chaining patterns (most specific message first)
        for pattern in COMMAND_CHAINING_PATTERNS:
            if pattern in arg:
                raise MCPValidationError(
                    f"Argument at index {i} contains command chaining pattern '{pattern}': {arg[:50]}",
                    field="args",
                )

        # Check for shell metacharacters. Iterate in sorted order so the
        # character named in the error does not depend on set hash order.
        for char in sorted(SHELL_METACHARACTERS):
            if char in arg:
                raise MCPValidationError(
                    f"Argument at index {i} contains shell metacharacter '{char}': {arg[:50]}",
                    field="args",
                )
def validate_url(url: str) -> None:
    """
    Validate URL for SSE/HTTP transport.

    Checks for:
    - Non-empty string input without null bytes
    - Parseable URL format
    - Allowed schemes (http, https)
    - No credentials embedded in the URL
    - A non-empty host name (a bare port such as "http://:8080" is rejected)

    NOTE: localhost and private-network addresses are NOT blocked by this
    function.

    Args:
        url: The URL to validate

    Raises:
        MCPValidationError: If the URL is invalid or potentially dangerous
    """
    if not url or not isinstance(url, str):
        raise MCPValidationError("URL must be a non-empty string", field="url")

    # Check for null bytes
    if "\x00" in url:
        raise MCPValidationError("URL contains null byte", field="url")

    # Parse the URL; urlparse reports malformed input (e.g. an unbalanced
    # IPv6 bracket) with ValueError.
    try:
        parsed = urlparse(url)
        hostname = parsed.hostname
    except ValueError as e:
        raise MCPValidationError(f"Invalid URL format: {e}", field="url") from e

    # Check scheme
    if parsed.scheme not in ALLOWED_URL_SCHEMES:
        raise MCPValidationError(
            f"URL scheme '{parsed.scheme}' is not allowed. Allowed schemes: {', '.join(sorted(ALLOWED_URL_SCHEMES))}",
            field="url",
        )

    # Check for credentials in URL (security risk)
    if parsed.username or parsed.password:
        raise MCPValidationError(
            "URL should not contain credentials. Use headers for authentication instead.",
            field="url",
        )

    # Check for valid host. `hostname` (not `netloc`) is used because netloc
    # is non-empty for host-less URLs like "http://:8080".
    if not hostname:
        raise MCPValidationError("URL must have a valid host", field="url")
+ + Args: + headers: Dictionary of HTTP headers + + Raises: + MCPValidationError: If any header contains dangerous patterns + """ + if not headers: + return + + if not isinstance(headers, dict): + raise MCPValidationError( + f"Headers must be a dictionary, got {type(headers).__name__}", + field="headers", + ) + + for key, value in headers.items(): + # Validate key + if not isinstance(key, str): + raise MCPValidationError( + f"Header key must be a string, got {type(key).__name__}", + field="headers", + ) + + # Check for newlines in header name (HTTP header injection) + if "\n" in key or "\r" in key: + raise MCPValidationError( + f"Header name '{key[:20]}' contains newline character (potential HTTP header injection)", + field="headers", + ) + + # Validate value + if not isinstance(value, str): + raise MCPValidationError( + f"Header value for '{key}' must be a string, got {type(value).__name__}", + field="headers", + ) + + # Check for newlines in header value (HTTP header injection) + if "\n" in value or "\r" in value: + raise MCPValidationError( + f"Header value for '{key}' contains newline character (potential HTTP header injection)", + field="headers", + ) + + # Check for null bytes + if "\x00" in key or "\x00" in value: + raise MCPValidationError( + f"Header '{key}' contains null byte", + field="headers", + ) diff --git a/tests/unit/server/test_app.py b/tests/unit/server/test_app.py index a3c796b..ccc4260 100644 --- a/tests/unit/server/test_app.py +++ b/tests/unit/server/test_app.py @@ -352,9 +352,9 @@ class TestMCPEndpoint: request_data = { "transport": "stdio", - "command": "test_command", - "args": ["arg1", "arg2"], - "env": {"ENV_VAR": "value"}, + "command": "node", + "args": ["server.js"], + "env": {"API_KEY": "test123"}, } response = client.post("/api/mcp/server/metadata", json=request_data) @@ -362,7 +362,7 @@ class TestMCPEndpoint: assert response.status_code == 200 response_data = response.json() assert response_data["transport"] == "stdio" - assert 
response_data["command"] == "test_command" + assert response_data["command"] == "node" assert len(response_data["tools"]) == 1 @patch("src.server.app.load_mcp_tools") @@ -375,7 +375,7 @@ class TestMCPEndpoint: request_data = { "transport": "stdio", - "command": "test_command", + "command": "node", "timeout_seconds": 60, } @@ -424,9 +424,9 @@ class TestMCPEndpoint: request_data = { "transport": "stdio", - "command": "test_command", - "args": ["arg1", "arg2"], - "env": {"ENV_VAR": "value"}, + "command": "node", + "args": ["server.js"], + "env": {"API_KEY": "test123"}, } response = client.post("/api/mcp/server/metadata", json=request_data) @@ -444,9 +444,9 @@ class TestMCPEndpoint: ): request_data = { "transport": "stdio", - "command": "test_command", - "args": ["arg1", "arg2"], - "env": {"ENV_VAR": "value"}, + "command": "node", + "args": ["server.js"], + "env": {"API_KEY": "test123"}, } response = client.post("/api/mcp/server/metadata", json=request_data) diff --git a/tests/unit/server/test_chat_request.py b/tests/unit/server/test_chat_request.py index 69bce54..130712d 100644 --- a/tests/unit/server/test_chat_request.py +++ b/tests/unit/server/test_chat_request.py @@ -163,6 +163,6 @@ async def test_load_mcp_tools_exception_handling( mock_stdio_client.return_value = MagicMock() with pytest.raises(HTTPException) as exc: - await mcp_utils.load_mcp_tools(server_type="stdio", command="foo") # Use await + await mcp_utils.load_mcp_tools(server_type="stdio", command="node") # Use await assert exc.value.status_code == 500 assert "unexpected error" in exc.value.detail diff --git a/tests/unit/server/test_mcp_utils.py b/tests/unit/server/test_mcp_utils.py index 8d5ad16..12e1867 100644 --- a/tests/unit/server/test_mcp_utils.py +++ b/tests/unit/server/test_mcp_utils.py @@ -55,14 +55,14 @@ async def test_load_mcp_tools_stdio_success( result = await mcp_utils.load_mcp_tools( server_type="stdio", - command="echo", - args=["foo"], - env={"FOO": "BAR"}, + command="node", + 
args=["server.js"], + env={"API_KEY": "test123"}, timeout_seconds=3, ) assert result == ["toolA"] mock_StdioServerParameters.assert_called_once_with( - command="echo", args=["foo"], env={"FOO": "BAR"} + command="node", args=["server.js"], env={"API_KEY": "test123"} ) mock_stdio_client.assert_called_once_with(params) mock_get_tools.assert_awaited_once_with(mock_client, 3) @@ -165,7 +165,7 @@ async def test_load_mcp_tools_unsupported_type(): with pytest.raises(HTTPException) as exc: await mcp_utils.load_mcp_tools(server_type="unknown") assert exc.value.status_code == 400 - assert "Unsupported server type" in exc.value.detail + assert "Invalid transport type" in exc.value.detail or "Unsupported server type" in exc.value.detail @pytest.mark.asyncio @@ -180,6 +180,6 @@ async def test_load_mcp_tools_exception_handling( mock_stdio_client.return_value = MagicMock() with pytest.raises(HTTPException) as exc: - await mcp_utils.load_mcp_tools(server_type="stdio", command="foo") + await mcp_utils.load_mcp_tools(server_type="stdio", command="node") assert exc.value.status_code == 500 assert "unexpected error" in exc.value.detail diff --git a/tests/unit/server/test_mcp_validators.py b/tests/unit/server/test_mcp_validators.py new file mode 100644 index 0000000..8e11da1 --- /dev/null +++ b/tests/unit/server/test_mcp_validators.py @@ -0,0 +1,450 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +""" +Unit tests for MCP server configuration validators. 
+ +Tests cover: +- Command validation (allowlist) +- Argument validation (path traversal, command injection) +- Environment variable validation +- URL validation +- Header validation +- Full config validation +""" + +import pytest + +from src.server.mcp_validators import ( + ALLOWED_COMMANDS, + MCPValidationError, + validate_args_for_local_file_access, + validate_command, + validate_command_injection, + validate_environment_variables, + validate_headers, + validate_mcp_server_config, + validate_url, +) + + +class TestValidateCommand: + """Tests for validate_command function.""" + + def test_allowed_commands(self): + """Test that all allowed commands pass validation.""" + for cmd in ALLOWED_COMMANDS: + validate_command(cmd) # Should not raise + + def test_allowed_command_with_path(self): + """Test that commands with paths are validated by base name.""" + validate_command("/usr/bin/python3") + validate_command("/usr/local/bin/node") + validate_command("C:\\Python\\python.exe") + + def test_disallowed_command(self): + """Test that disallowed commands raise an error.""" + with pytest.raises(MCPValidationError) as exc_info: + validate_command("bash") + assert "not allowed" in exc_info.value.message + assert exc_info.value.field == "command" + + def test_disallowed_dangerous_commands(self): + """Test that dangerous commands are rejected.""" + dangerous_commands = ["rm", "sudo", "chmod", "chown", "curl", "wget", "sh"] + for cmd in dangerous_commands: + with pytest.raises(MCPValidationError): + validate_command(cmd) + + def test_empty_command(self): + """Test that empty command raises an error.""" + with pytest.raises(MCPValidationError): + validate_command("") + + def test_none_command(self): + """Test that None command raises an error.""" + with pytest.raises(MCPValidationError): + validate_command(None) + + +class TestValidateArgsForLocalFileAccess: + """Tests for validate_args_for_local_file_access function.""" + + def test_safe_args(self): + """Test that safe 
arguments pass validation.""" + safe_args = [ + ["--help"], + ["-v", "--verbose"], + ["package-name"], + ["--config", "config.json"], + ["run", "script.py"], + ] + for args in safe_args: + validate_args_for_local_file_access(args) # Should not raise + + def test_directory_traversal(self): + """Test that directory traversal patterns are rejected.""" + traversal_patterns = [ + ["../etc/passwd"], + ["..\\windows\\system32"], + ["../../secret"], + ["foo/../bar/../../../etc/passwd"], + ["foo/.."], # ".." at end after path separator + ["bar\\.."], # ".." at end after Windows path separator + ["path/to/foo/.."], # Longer path ending with ".." + ] + for args in traversal_patterns: + with pytest.raises(MCPValidationError) as exc_info: + validate_args_for_local_file_access(args) + assert "traversal" in exc_info.value.message.lower() + + def test_absolute_path_with_dangerous_extension(self): + """Test that absolute paths with dangerous extensions are rejected.""" + with pytest.raises(MCPValidationError): + validate_args_for_local_file_access(["/etc/passwd.sh"]) + + def test_windows_absolute_path(self): + """Test that Windows absolute paths are rejected.""" + with pytest.raises(MCPValidationError): + validate_args_for_local_file_access(["C:\\Windows\\system32"]) + + def test_home_directory_reference(self): + """Test that home directory references are rejected.""" + with pytest.raises(MCPValidationError): + validate_args_for_local_file_access(["~/secrets"]) + + with pytest.raises(MCPValidationError): + validate_args_for_local_file_access(["~\\secrets"]) + + def test_null_byte(self): + """Test that null bytes in arguments are rejected.""" + with pytest.raises(MCPValidationError) as exc_info: + validate_args_for_local_file_access(["file\x00.txt"]) + assert "null byte" in exc_info.value.message.lower() + + def test_excessively_long_argument(self): + """Test that excessively long arguments are rejected.""" + with pytest.raises(MCPValidationError) as exc_info: + 
validate_args_for_local_file_access(["a" * 1001]) + assert "maximum length" in exc_info.value.message.lower() + + def test_dangerous_extensions(self): + """Test that dangerous file extensions are rejected.""" + dangerous_files = [ + ["script.sh"], + ["binary.exe"], + ["library.dll"], + ["secret.env"], + ["key.pem"], + ] + for args in dangerous_files: + with pytest.raises(MCPValidationError) as exc_info: + validate_args_for_local_file_access(args) + assert "dangerous file type" in exc_info.value.message.lower() + + def test_empty_args(self): + """Test that empty args list passes validation.""" + validate_args_for_local_file_access([]) + validate_args_for_local_file_access(None) + + +class TestValidateCommandInjection: + """Tests for validate_command_injection function.""" + + def test_safe_args(self): + """Test that safe arguments pass validation.""" + safe_args = [ + ["--help"], + ["package-name"], + ["@scope/package"], + ["file.json"], + ] + for args in safe_args: + validate_command_injection(args) # Should not raise + + def test_shell_metacharacters(self): + """Test that shell metacharacters are rejected.""" + metachar_args = [ + ["foo; rm -rf /"], + ["foo & bar"], + ["foo | cat /etc/passwd"], + ["$(whoami)"], + ["`id`"], + ["foo > /etc/passwd"], + ["foo < /etc/passwd"], + ["${PATH}"], + ] + for args in metachar_args: + with pytest.raises(MCPValidationError) as exc_info: + validate_command_injection(args) + assert "args" == exc_info.value.field + + def test_command_chaining(self): + """Test that command chaining patterns are rejected.""" + chaining_args = [ + ["foo && bar"], + ["foo || bar"], + ["foo;; bar"], + ["foo >> output"], + ["foo << input"], + ] + for args in chaining_args: + with pytest.raises(MCPValidationError): + validate_command_injection(args) + + def test_backtick_injection(self): + """Test that backtick command substitution is rejected.""" + with pytest.raises(MCPValidationError): + validate_command_injection(["`whoami`"]) + + def 
test_process_substitution(self): + """Test that process substitution is rejected.""" + with pytest.raises(MCPValidationError): + validate_command_injection(["<(cat /etc/passwd)"]) + + with pytest.raises(MCPValidationError): + validate_command_injection([">(tee /tmp/out)"]) + + +class TestValidateEnvironmentVariables: + """Tests for validate_environment_variables function.""" + + def test_safe_env_vars(self): + """Test that safe environment variables pass validation.""" + safe_env = { + "API_KEY": "secret123", + "DEBUG": "true", + "MY_VARIABLE": "value", + } + validate_environment_variables(safe_env) # Should not raise + + def test_dangerous_env_vars(self): + """Test that dangerous environment variables are rejected.""" + dangerous_vars = [ + {"PATH": "/malicious/path"}, + {"LD_LIBRARY_PATH": "/malicious/lib"}, + {"DYLD_LIBRARY_PATH": "/malicious/lib"}, + {"LD_PRELOAD": "/malicious/lib.so"}, + {"PYTHONPATH": "/malicious/python"}, + {"NODE_PATH": "/malicious/node"}, + ] + for env in dangerous_vars: + with pytest.raises(MCPValidationError) as exc_info: + validate_environment_variables(env) + assert "not allowed" in exc_info.value.message.lower() + + def test_null_byte_in_value(self): + """Test that null bytes in values are rejected.""" + with pytest.raises(MCPValidationError): + validate_environment_variables({"KEY": "value\x00malicious"}) + + def test_empty_env(self): + """Test that empty env dict passes validation.""" + validate_environment_variables({}) + validate_environment_variables(None) + + +class TestValidateUrl: + """Tests for validate_url function.""" + + def test_valid_urls(self): + """Test that valid URLs pass validation.""" + valid_urls = [ + "http://localhost:3000", + "https://api.example.com", + "http://192.168.1.1:8080/api", + "https://mcp.example.com/sse", + ] + for url in valid_urls: + validate_url(url) # Should not raise + + def test_invalid_scheme(self): + """Test that invalid URL schemes are rejected.""" + with pytest.raises(MCPValidationError) 
as exc_info: + validate_url("ftp://example.com") + assert "scheme" in exc_info.value.message.lower() + + with pytest.raises(MCPValidationError): + validate_url("file:///etc/passwd") + + def test_credentials_in_url(self): + """Test that URLs with credentials are rejected.""" + with pytest.raises(MCPValidationError) as exc_info: + validate_url("https://user:pass@example.com") + assert "credentials" in exc_info.value.message.lower() + + def test_null_byte_in_url(self): + """Test that null bytes in URL are rejected.""" + with pytest.raises(MCPValidationError): + validate_url("https://example.com\x00/malicious") + + def test_empty_url(self): + """Test that empty URL raises an error.""" + with pytest.raises(MCPValidationError): + validate_url("") + + def test_no_host(self): + """Test that URL without host raises an error.""" + with pytest.raises(MCPValidationError): + validate_url("http:///path") + + +class TestValidateHeaders: + """Tests for validate_headers function.""" + + def test_valid_headers(self): + """Test that valid headers pass validation.""" + valid_headers = { + "Authorization": "Bearer token123", + "Content-Type": "application/json", + "X-Custom-Header": "value", + } + validate_headers(valid_headers) # Should not raise + + def test_newline_in_header_name(self): + """Test that newlines in header names are rejected (HTTP header injection).""" + with pytest.raises(MCPValidationError) as exc_info: + validate_headers({"X-Bad\nHeader": "value"}) + assert "newline" in exc_info.value.message.lower() + + def test_newline_in_header_value(self): + """Test that newlines in header values are rejected (HTTP header injection).""" + with pytest.raises(MCPValidationError): + validate_headers({"X-Header": "value\r\nX-Injected: malicious"}) + + def test_null_byte_in_header(self): + """Test that null bytes in headers are rejected.""" + with pytest.raises(MCPValidationError): + validate_headers({"X-Header": "value\x00"}) + + def test_empty_headers(self): + """Test that empty 
class TestValidateMCPServerConfig:
    """Tests for the main validate_mcp_server_config function."""

    def test_valid_stdio_config(self):
        """Test valid stdio configuration."""
        validate_mcp_server_config(
            transport="stdio",
            command="npx",
            args=["@modelcontextprotocol/server-filesystem"],
            env={"API_KEY": "secret"},
        )  # Should not raise

    def test_valid_sse_config(self):
        """Test valid SSE configuration."""
        validate_mcp_server_config(
            transport="sse",
            url="https://api.example.com/sse",
            headers={"Authorization": "Bearer token"},
        )  # Should not raise

    def test_valid_http_config(self):
        """Test valid streamable_http configuration."""
        validate_mcp_server_config(
            transport="streamable_http",
            url="https://api.example.com/mcp",
        )  # Should not raise

    def test_invalid_transport(self):
        """Test that invalid transport type raises an error."""
        with pytest.raises(MCPValidationError) as exc_info:
            validate_mcp_server_config(transport="invalid")
        assert "Invalid transport type" in exc_info.value.message

    def test_combined_validation_errors(self):
        """Test that multiple validation errors are combined."""
        with pytest.raises(MCPValidationError) as exc_info:
            validate_mcp_server_config(
                transport="stdio",
                command="bash",  # Not allowed
                args=["../etc/passwd"],  # Path traversal
                env={"PATH": "/malicious"},  # Dangerous env var
            )
        message = exc_info.value.message
        # Each of the three failures must appear in the combined message.
        # The command assertion is specific on purpose: a bare "not allowed"
        # would also be satisfied by the environment-variable error alone.
        assert "Command 'bash' is not allowed" in message
        assert "traversal" in message.lower()
        assert "environment variable" in message

    def test_non_strict_mode(self, caplog):
        """Test that non-strict mode logs warnings instead of raising."""
        with caplog.at_level("WARNING", logger="src.server.mcp_validators"):
            # Should not raise in non-strict mode
            validate_mcp_server_config(
                transport="stdio",
                command="bash",
                strict=False,
            )
        # The rejected command must be reported through the log instead.
        assert "MCP configuration validation warnings" in caplog.text
        assert "not allowed" in caplog.text

    def test_stdio_with_dangerous_args(self):
        """Test stdio config with command injection attempt."""
        with pytest.raises(MCPValidationError):
            validate_mcp_server_config(
                transport="stdio",
                command="node",
                args=["script.js; rm -rf /"],
            )

    def test_sse_with_invalid_url(self):
        """Test SSE config with invalid URL."""
        with pytest.raises(MCPValidationError):
            validate_mcp_server_config(
                transport="sse",
                url="ftp://example.com",
            )