mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
feat(server): add MCP server configuration validation (#830)
* feat(server): add MCP server configuration validation Add comprehensive validation for MCP server configurations, inspired by Flowise's validateMCPServerConfig implementation. MCPServerConfig checks implemented: - Command allowlist validation (node, npx, python, docker, uvx, etc.) - Path traversal prevention (blocks ../, absolute paths, ~/) - Shell command injection prevention (blocks ; & | ` $ etc.) - Dangerous environment variable blocking (PATH, LD_PRELOAD, etc.) - URL validation for SSE/HTTP transports (scheme, credentials) - HTTP header injection prevention (blocks newlines) * fix the unit test error of test_chat_request * Added the related path cases as reviewer commented * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -352,9 +352,9 @@ class TestMCPEndpoint:
|
||||
|
||||
request_data = {
|
||||
"transport": "stdio",
|
||||
"command": "test_command",
|
||||
"args": ["arg1", "arg2"],
|
||||
"env": {"ENV_VAR": "value"},
|
||||
"command": "node",
|
||||
"args": ["server.js"],
|
||||
"env": {"API_KEY": "test123"},
|
||||
}
|
||||
|
||||
response = client.post("/api/mcp/server/metadata", json=request_data)
|
||||
@@ -362,7 +362,7 @@ class TestMCPEndpoint:
|
||||
assert response.status_code == 200
|
||||
response_data = response.json()
|
||||
assert response_data["transport"] == "stdio"
|
||||
assert response_data["command"] == "test_command"
|
||||
assert response_data["command"] == "node"
|
||||
assert len(response_data["tools"]) == 1
|
||||
|
||||
@patch("src.server.app.load_mcp_tools")
|
||||
@@ -375,7 +375,7 @@ class TestMCPEndpoint:
|
||||
|
||||
request_data = {
|
||||
"transport": "stdio",
|
||||
"command": "test_command",
|
||||
"command": "node",
|
||||
"timeout_seconds": 60,
|
||||
}
|
||||
|
||||
@@ -424,9 +424,9 @@ class TestMCPEndpoint:
|
||||
|
||||
request_data = {
|
||||
"transport": "stdio",
|
||||
"command": "test_command",
|
||||
"args": ["arg1", "arg2"],
|
||||
"env": {"ENV_VAR": "value"},
|
||||
"command": "node",
|
||||
"args": ["server.js"],
|
||||
"env": {"API_KEY": "test123"},
|
||||
}
|
||||
|
||||
response = client.post("/api/mcp/server/metadata", json=request_data)
|
||||
@@ -444,9 +444,9 @@ class TestMCPEndpoint:
|
||||
):
|
||||
request_data = {
|
||||
"transport": "stdio",
|
||||
"command": "test_command",
|
||||
"args": ["arg1", "arg2"],
|
||||
"env": {"ENV_VAR": "value"},
|
||||
"command": "node",
|
||||
"args": ["server.js"],
|
||||
"env": {"API_KEY": "test123"},
|
||||
}
|
||||
|
||||
response = client.post("/api/mcp/server/metadata", json=request_data)
|
||||
|
||||
@@ -163,6 +163,6 @@ async def test_load_mcp_tools_exception_handling(
|
||||
mock_stdio_client.return_value = MagicMock()
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await mcp_utils.load_mcp_tools(server_type="stdio", command="foo") # Use await
|
||||
await mcp_utils.load_mcp_tools(server_type="stdio", command="node") # Use await
|
||||
assert exc.value.status_code == 500
|
||||
assert "unexpected error" in exc.value.detail
|
||||
|
||||
@@ -55,14 +55,14 @@ async def test_load_mcp_tools_stdio_success(
|
||||
|
||||
result = await mcp_utils.load_mcp_tools(
|
||||
server_type="stdio",
|
||||
command="echo",
|
||||
args=["foo"],
|
||||
env={"FOO": "BAR"},
|
||||
command="node",
|
||||
args=["server.js"],
|
||||
env={"API_KEY": "test123"},
|
||||
timeout_seconds=3,
|
||||
)
|
||||
assert result == ["toolA"]
|
||||
mock_StdioServerParameters.assert_called_once_with(
|
||||
command="echo", args=["foo"], env={"FOO": "BAR"}
|
||||
command="node", args=["server.js"], env={"API_KEY": "test123"}
|
||||
)
|
||||
mock_stdio_client.assert_called_once_with(params)
|
||||
mock_get_tools.assert_awaited_once_with(mock_client, 3)
|
||||
@@ -165,7 +165,7 @@ async def test_load_mcp_tools_unsupported_type():
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await mcp_utils.load_mcp_tools(server_type="unknown")
|
||||
assert exc.value.status_code == 400
|
||||
assert "Unsupported server type" in exc.value.detail
|
||||
assert "Invalid transport type" in exc.value.detail or "Unsupported server type" in exc.value.detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -180,6 +180,6 @@ async def test_load_mcp_tools_exception_handling(
|
||||
mock_stdio_client.return_value = MagicMock()
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await mcp_utils.load_mcp_tools(server_type="stdio", command="foo")
|
||||
await mcp_utils.load_mcp_tools(server_type="stdio", command="node")
|
||||
assert exc.value.status_code == 500
|
||||
assert "unexpected error" in exc.value.detail
|
||||
|
||||
450
tests/unit/server/test_mcp_validators.py
Normal file
450
tests/unit/server/test_mcp_validators.py
Normal file
@@ -0,0 +1,450 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for MCP server configuration validators.
|
||||
|
||||
Tests cover:
|
||||
- Command validation (allowlist)
|
||||
- Argument validation (path traversal, command injection)
|
||||
- Environment variable validation
|
||||
- URL validation
|
||||
- Header validation
|
||||
- Full config validation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.server.mcp_validators import (
|
||||
ALLOWED_COMMANDS,
|
||||
MCPValidationError,
|
||||
validate_args_for_local_file_access,
|
||||
validate_command,
|
||||
validate_command_injection,
|
||||
validate_environment_variables,
|
||||
validate_headers,
|
||||
validate_mcp_server_config,
|
||||
validate_url,
|
||||
)
|
||||
|
||||
|
||||
class TestValidateCommand:
|
||||
"""Tests for validate_command function."""
|
||||
|
||||
def test_allowed_commands(self):
|
||||
"""Test that all allowed commands pass validation."""
|
||||
for cmd in ALLOWED_COMMANDS:
|
||||
validate_command(cmd) # Should not raise
|
||||
|
||||
def test_allowed_command_with_path(self):
|
||||
"""Test that commands with paths are validated by base name."""
|
||||
validate_command("/usr/bin/python3")
|
||||
validate_command("/usr/local/bin/node")
|
||||
validate_command("C:\\Python\\python.exe")
|
||||
|
||||
def test_disallowed_command(self):
|
||||
"""Test that disallowed commands raise an error."""
|
||||
with pytest.raises(MCPValidationError) as exc_info:
|
||||
validate_command("bash")
|
||||
assert "not allowed" in exc_info.value.message
|
||||
assert exc_info.value.field == "command"
|
||||
|
||||
def test_disallowed_dangerous_commands(self):
|
||||
"""Test that dangerous commands are rejected."""
|
||||
dangerous_commands = ["rm", "sudo", "chmod", "chown", "curl", "wget", "sh"]
|
||||
for cmd in dangerous_commands:
|
||||
with pytest.raises(MCPValidationError):
|
||||
validate_command(cmd)
|
||||
|
||||
def test_empty_command(self):
|
||||
"""Test that empty command raises an error."""
|
||||
with pytest.raises(MCPValidationError):
|
||||
validate_command("")
|
||||
|
||||
def test_none_command(self):
|
||||
"""Test that None command raises an error."""
|
||||
with pytest.raises(MCPValidationError):
|
||||
validate_command(None)
|
||||
|
||||
|
||||
class TestValidateArgsForLocalFileAccess:
|
||||
"""Tests for validate_args_for_local_file_access function."""
|
||||
|
||||
def test_safe_args(self):
|
||||
"""Test that safe arguments pass validation."""
|
||||
safe_args = [
|
||||
["--help"],
|
||||
["-v", "--verbose"],
|
||||
["package-name"],
|
||||
["--config", "config.json"],
|
||||
["run", "script.py"],
|
||||
]
|
||||
for args in safe_args:
|
||||
validate_args_for_local_file_access(args) # Should not raise
|
||||
|
||||
def test_directory_traversal(self):
|
||||
"""Test that directory traversal patterns are rejected."""
|
||||
traversal_patterns = [
|
||||
["../etc/passwd"],
|
||||
["..\\windows\\system32"],
|
||||
["../../secret"],
|
||||
["foo/../bar/../../../etc/passwd"],
|
||||
["foo/.."], # ".." at end after path separator
|
||||
["bar\\.."], # ".." at end after Windows path separator
|
||||
["path/to/foo/.."], # Longer path ending with ".."
|
||||
]
|
||||
for args in traversal_patterns:
|
||||
with pytest.raises(MCPValidationError) as exc_info:
|
||||
validate_args_for_local_file_access(args)
|
||||
assert "traversal" in exc_info.value.message.lower()
|
||||
|
||||
def test_absolute_path_with_dangerous_extension(self):
|
||||
"""Test that absolute paths with dangerous extensions are rejected."""
|
||||
with pytest.raises(MCPValidationError):
|
||||
validate_args_for_local_file_access(["/etc/passwd.sh"])
|
||||
|
||||
def test_windows_absolute_path(self):
|
||||
"""Test that Windows absolute paths are rejected."""
|
||||
with pytest.raises(MCPValidationError):
|
||||
validate_args_for_local_file_access(["C:\\Windows\\system32"])
|
||||
|
||||
def test_home_directory_reference(self):
|
||||
"""Test that home directory references are rejected."""
|
||||
with pytest.raises(MCPValidationError):
|
||||
validate_args_for_local_file_access(["~/secrets"])
|
||||
|
||||
with pytest.raises(MCPValidationError):
|
||||
validate_args_for_local_file_access(["~\\secrets"])
|
||||
|
||||
def test_null_byte(self):
|
||||
"""Test that null bytes in arguments are rejected."""
|
||||
with pytest.raises(MCPValidationError) as exc_info:
|
||||
validate_args_for_local_file_access(["file\x00.txt"])
|
||||
assert "null byte" in exc_info.value.message.lower()
|
||||
|
||||
def test_excessively_long_argument(self):
|
||||
"""Test that excessively long arguments are rejected."""
|
||||
with pytest.raises(MCPValidationError) as exc_info:
|
||||
validate_args_for_local_file_access(["a" * 1001])
|
||||
assert "maximum length" in exc_info.value.message.lower()
|
||||
|
||||
def test_dangerous_extensions(self):
|
||||
"""Test that dangerous file extensions are rejected."""
|
||||
dangerous_files = [
|
||||
["script.sh"],
|
||||
["binary.exe"],
|
||||
["library.dll"],
|
||||
["secret.env"],
|
||||
["key.pem"],
|
||||
]
|
||||
for args in dangerous_files:
|
||||
with pytest.raises(MCPValidationError) as exc_info:
|
||||
validate_args_for_local_file_access(args)
|
||||
assert "dangerous file type" in exc_info.value.message.lower()
|
||||
|
||||
def test_empty_args(self):
|
||||
"""Test that empty args list passes validation."""
|
||||
validate_args_for_local_file_access([])
|
||||
validate_args_for_local_file_access(None)
|
||||
|
||||
|
||||
class TestValidateCommandInjection:
|
||||
"""Tests for validate_command_injection function."""
|
||||
|
||||
def test_safe_args(self):
|
||||
"""Test that safe arguments pass validation."""
|
||||
safe_args = [
|
||||
["--help"],
|
||||
["package-name"],
|
||||
["@scope/package"],
|
||||
["file.json"],
|
||||
]
|
||||
for args in safe_args:
|
||||
validate_command_injection(args) # Should not raise
|
||||
|
||||
def test_shell_metacharacters(self):
|
||||
"""Test that shell metacharacters are rejected."""
|
||||
metachar_args = [
|
||||
["foo; rm -rf /"],
|
||||
["foo & bar"],
|
||||
["foo | cat /etc/passwd"],
|
||||
["$(whoami)"],
|
||||
["`id`"],
|
||||
["foo > /etc/passwd"],
|
||||
["foo < /etc/passwd"],
|
||||
["${PATH}"],
|
||||
]
|
||||
for args in metachar_args:
|
||||
with pytest.raises(MCPValidationError) as exc_info:
|
||||
validate_command_injection(args)
|
||||
assert "args" == exc_info.value.field
|
||||
|
||||
def test_command_chaining(self):
|
||||
"""Test that command chaining patterns are rejected."""
|
||||
chaining_args = [
|
||||
["foo && bar"],
|
||||
["foo || bar"],
|
||||
["foo;; bar"],
|
||||
["foo >> output"],
|
||||
["foo << input"],
|
||||
]
|
||||
for args in chaining_args:
|
||||
with pytest.raises(MCPValidationError):
|
||||
validate_command_injection(args)
|
||||
|
||||
def test_backtick_injection(self):
|
||||
"""Test that backtick command substitution is rejected."""
|
||||
with pytest.raises(MCPValidationError):
|
||||
validate_command_injection(["`whoami`"])
|
||||
|
||||
def test_process_substitution(self):
|
||||
"""Test that process substitution is rejected."""
|
||||
with pytest.raises(MCPValidationError):
|
||||
validate_command_injection(["<(cat /etc/passwd)"])
|
||||
|
||||
with pytest.raises(MCPValidationError):
|
||||
validate_command_injection([">(tee /tmp/out)"])
|
||||
|
||||
|
||||
class TestValidateEnvironmentVariables:
|
||||
"""Tests for validate_environment_variables function."""
|
||||
|
||||
def test_safe_env_vars(self):
|
||||
"""Test that safe environment variables pass validation."""
|
||||
safe_env = {
|
||||
"API_KEY": "secret123",
|
||||
"DEBUG": "true",
|
||||
"MY_VARIABLE": "value",
|
||||
}
|
||||
validate_environment_variables(safe_env) # Should not raise
|
||||
|
||||
def test_dangerous_env_vars(self):
|
||||
"""Test that dangerous environment variables are rejected."""
|
||||
dangerous_vars = [
|
||||
{"PATH": "/malicious/path"},
|
||||
{"LD_LIBRARY_PATH": "/malicious/lib"},
|
||||
{"DYLD_LIBRARY_PATH": "/malicious/lib"},
|
||||
{"LD_PRELOAD": "/malicious/lib.so"},
|
||||
{"PYTHONPATH": "/malicious/python"},
|
||||
{"NODE_PATH": "/malicious/node"},
|
||||
]
|
||||
for env in dangerous_vars:
|
||||
with pytest.raises(MCPValidationError) as exc_info:
|
||||
validate_environment_variables(env)
|
||||
assert "not allowed" in exc_info.value.message.lower()
|
||||
|
||||
def test_null_byte_in_value(self):
|
||||
"""Test that null bytes in values are rejected."""
|
||||
with pytest.raises(MCPValidationError):
|
||||
validate_environment_variables({"KEY": "value\x00malicious"})
|
||||
|
||||
def test_empty_env(self):
|
||||
"""Test that empty env dict passes validation."""
|
||||
validate_environment_variables({})
|
||||
validate_environment_variables(None)
|
||||
|
||||
|
||||
class TestValidateUrl:
|
||||
"""Tests for validate_url function."""
|
||||
|
||||
def test_valid_urls(self):
|
||||
"""Test that valid URLs pass validation."""
|
||||
valid_urls = [
|
||||
"http://localhost:3000",
|
||||
"https://api.example.com",
|
||||
"http://192.168.1.1:8080/api",
|
||||
"https://mcp.example.com/sse",
|
||||
]
|
||||
for url in valid_urls:
|
||||
validate_url(url) # Should not raise
|
||||
|
||||
def test_invalid_scheme(self):
|
||||
"""Test that invalid URL schemes are rejected."""
|
||||
with pytest.raises(MCPValidationError) as exc_info:
|
||||
validate_url("ftp://example.com")
|
||||
assert "scheme" in exc_info.value.message.lower()
|
||||
|
||||
with pytest.raises(MCPValidationError):
|
||||
validate_url("file:///etc/passwd")
|
||||
|
||||
def test_credentials_in_url(self):
|
||||
"""Test that URLs with credentials are rejected."""
|
||||
with pytest.raises(MCPValidationError) as exc_info:
|
||||
validate_url("https://user:pass@example.com")
|
||||
assert "credentials" in exc_info.value.message.lower()
|
||||
|
||||
def test_null_byte_in_url(self):
|
||||
"""Test that null bytes in URL are rejected."""
|
||||
with pytest.raises(MCPValidationError):
|
||||
validate_url("https://example.com\x00/malicious")
|
||||
|
||||
def test_empty_url(self):
|
||||
"""Test that empty URL raises an error."""
|
||||
with pytest.raises(MCPValidationError):
|
||||
validate_url("")
|
||||
|
||||
def test_no_host(self):
|
||||
"""Test that URL without host raises an error."""
|
||||
with pytest.raises(MCPValidationError):
|
||||
validate_url("http:///path")
|
||||
|
||||
|
||||
class TestValidateHeaders:
|
||||
"""Tests for validate_headers function."""
|
||||
|
||||
def test_valid_headers(self):
|
||||
"""Test that valid headers pass validation."""
|
||||
valid_headers = {
|
||||
"Authorization": "Bearer token123",
|
||||
"Content-Type": "application/json",
|
||||
"X-Custom-Header": "value",
|
||||
}
|
||||
validate_headers(valid_headers) # Should not raise
|
||||
|
||||
def test_newline_in_header_name(self):
|
||||
"""Test that newlines in header names are rejected (HTTP header injection)."""
|
||||
with pytest.raises(MCPValidationError) as exc_info:
|
||||
validate_headers({"X-Bad\nHeader": "value"})
|
||||
assert "newline" in exc_info.value.message.lower()
|
||||
|
||||
def test_newline_in_header_value(self):
|
||||
"""Test that newlines in header values are rejected (HTTP header injection)."""
|
||||
with pytest.raises(MCPValidationError):
|
||||
validate_headers({"X-Header": "value\r\nX-Injected: malicious"})
|
||||
|
||||
def test_null_byte_in_header(self):
|
||||
"""Test that null bytes in headers are rejected."""
|
||||
with pytest.raises(MCPValidationError):
|
||||
validate_headers({"X-Header": "value\x00"})
|
||||
|
||||
def test_empty_headers(self):
|
||||
"""Test that empty headers dict passes validation."""
|
||||
validate_headers({})
|
||||
validate_headers(None)
|
||||
|
||||
|
||||
class TestValidateMCPServerConfig:
|
||||
"""Tests for the main validate_mcp_server_config function."""
|
||||
|
||||
def test_valid_stdio_config(self):
|
||||
"""Test valid stdio configuration."""
|
||||
validate_mcp_server_config(
|
||||
transport="stdio",
|
||||
command="npx",
|
||||
args=["@modelcontextprotocol/server-filesystem"],
|
||||
env={"API_KEY": "secret"},
|
||||
) # Should not raise
|
||||
|
||||
def test_valid_sse_config(self):
|
||||
"""Test valid SSE configuration."""
|
||||
validate_mcp_server_config(
|
||||
transport="sse",
|
||||
url="https://api.example.com/sse",
|
||||
headers={"Authorization": "Bearer token"},
|
||||
) # Should not raise
|
||||
|
||||
def test_valid_http_config(self):
|
||||
"""Test valid streamable_http configuration."""
|
||||
validate_mcp_server_config(
|
||||
transport="streamable_http",
|
||||
url="https://api.example.com/mcp",
|
||||
) # Should not raise
|
||||
|
||||
def test_invalid_transport(self):
|
||||
"""Test that invalid transport type raises an error."""
|
||||
with pytest.raises(MCPValidationError) as exc_info:
|
||||
validate_mcp_server_config(transport="invalid")
|
||||
assert "Invalid transport type" in exc_info.value.message
|
||||
|
||||
def test_combined_validation_errors(self):
|
||||
"""Test that multiple validation errors are combined."""
|
||||
with pytest.raises(MCPValidationError) as exc_info:
|
||||
validate_mcp_server_config(
|
||||
transport="stdio",
|
||||
command="bash", # Not allowed
|
||||
args=["../etc/passwd"], # Path traversal
|
||||
env={"PATH": "/malicious"}, # Dangerous env var
|
||||
)
|
||||
# All errors should be combined
|
||||
assert "not allowed" in exc_info.value.message
|
||||
assert "traversal" in exc_info.value.message.lower()
|
||||
|
||||
def test_non_strict_mode(self):
|
||||
"""Test that non-strict mode logs warnings instead of raising."""
|
||||
# Should not raise, but would log warnings
|
||||
validate_mcp_server_config(
|
||||
transport="stdio",
|
||||
command="bash",
|
||||
strict=False,
|
||||
)
|
||||
|
||||
def test_stdio_with_dangerous_args(self):
|
||||
"""Test stdio config with command injection attempt."""
|
||||
with pytest.raises(MCPValidationError):
|
||||
validate_mcp_server_config(
|
||||
transport="stdio",
|
||||
command="node",
|
||||
args=["script.js; rm -rf /"],
|
||||
)
|
||||
|
||||
def test_sse_with_invalid_url(self):
|
||||
"""Test SSE config with invalid URL."""
|
||||
with pytest.raises(MCPValidationError):
|
||||
validate_mcp_server_config(
|
||||
transport="sse",
|
||||
url="ftp://example.com",
|
||||
)
|
||||
|
||||
|
||||
class TestMCPServerMetadataRequest:
|
||||
"""Tests for Pydantic model validation."""
|
||||
|
||||
def test_valid_request(self):
|
||||
"""Test that valid request passes validation."""
|
||||
from src.server.mcp_request import MCPServerMetadataRequest
|
||||
|
||||
request = MCPServerMetadataRequest(
|
||||
transport="stdio",
|
||||
command="npx",
|
||||
args=["@modelcontextprotocol/server-filesystem"],
|
||||
)
|
||||
assert request.transport == "stdio"
|
||||
assert request.command == "npx"
|
||||
|
||||
def test_invalid_command_raises_validation_error(self):
|
||||
"""Test that invalid command raises Pydantic ValidationError."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
from src.server.mcp_request import MCPServerMetadataRequest
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
MCPServerMetadataRequest(
|
||||
transport="stdio",
|
||||
command="bash",
|
||||
)
|
||||
assert "not allowed" in str(exc_info.value).lower()
|
||||
|
||||
def test_command_injection_raises_validation_error(self):
|
||||
"""Test that command injection raises Pydantic ValidationError."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
from src.server.mcp_request import MCPServerMetadataRequest
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
MCPServerMetadataRequest(
|
||||
transport="stdio",
|
||||
command="node",
|
||||
args=["script.js; rm -rf /"],
|
||||
)
|
||||
|
||||
def test_invalid_url_raises_validation_error(self):
|
||||
"""Test that invalid URL raises Pydantic ValidationError."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
from src.server.mcp_request import MCPServerMetadataRequest
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
MCPServerMetadataRequest(
|
||||
transport="sse",
|
||||
url="ftp://example.com",
|
||||
)
|
||||
Reference in New Issue
Block a user