diff --git a/.env.example b/.env.example index ef359e4..4b655ff 100644 --- a/.env.example +++ b/.env.example @@ -12,6 +12,11 @@ AGENT_RECURSION_LIMIT=30 # Example: ALLOWED_ORIGINS=http://localhost:3000,http://example.com ALLOWED_ORIGINS=http://localhost:3000 +# Enable or disable MCP server configuration, the default is false. +# Please enable this feature before securing your front-end and back-end in an internal environment. +# Otherwise, you system could be compromised. +ENABLE_MCP_SERVER_CONFIGURATION=false + # Search Engine, Supported values: tavily (recommended), duckduckgo, brave_search, arxiv SEARCH_API=tavily TAVILY_API_KEY=tvly-xxx diff --git a/docs/mcp_integrations.md b/docs/mcp_integrations.md index f54cd9f..76bb773 100644 --- a/docs/mcp_integrations.md +++ b/docs/mcp_integrations.md @@ -1,5 +1,8 @@ # MCP Integrations +This feature is diabled by default. You can enable it by setting the environment ENABLE_MCP_SERVER_CONFIGURATION +Please enable this feature before securing your frond-end and back-end in an internal environment.q + ## Example of MCP Server Configuration ```json diff --git a/src/server/app.py b/src/server/app.py index 7b07758..b457e3c 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -71,6 +71,20 @@ graph = build_graph_with_memory() @app.post("/api/chat/stream") async def chat_stream(request: ChatRequest): + # Check if MCP server configuration is enabled + mcp_enabled = os.getenv("ENABLE_MCP_SERVER_CONFIGURATION", "false").lower() in [ + "true", + "1", + "yes", + ] + + # Validate MCP settings if provided + if request.mcp_settings and not mcp_enabled: + raise HTTPException( + status_code=403, + detail="MCP server configuration is disabled. Set ENABLE_MCP_SERVER_CONFIGURATION=true to enable MCP features.", + ) + thread_id = request.thread_id if thread_id == "__default__": thread_id = str(uuid4()) @@ -84,7 +98,7 @@ async def chat_stream(request: ChatRequest): request.max_search_results, request.auto_accepted_plan, request.interrupt_feedback, - request.mcp_settings, + request.mcp_settings if mcp_enabled else {}, request.enable_background_investigation, request.report_style, request.enable_deep_thinking, @@ -363,6 +377,17 @@ async def enhance_prompt(request: EnhancePromptRequest): @app.post("/api/mcp/server/metadata", response_model=MCPServerMetadataResponse) async def mcp_server_metadata(request: MCPServerMetadataRequest): """Get information about an MCP server.""" + # Check if MCP server configuration is enabled + if os.getenv("ENABLE_MCP_SERVER_CONFIGURATION", "false").lower() not in [ + "true", + "1", + "yes", + ]: + raise HTTPException( + status_code=403, + detail="MCP server configuration is disabled. Set ENABLE_MCP_SERVER_CONFIGURATION=true to enable.", + ) + try: # Set default timeout with a longer value for this endpoint timeout = 300 # Default to 300 seconds for this endpoint diff --git a/tests/unit/server/test_app.py b/tests/unit/server/test_app.py index f6eb503..e72cb48 100644 --- a/tests/unit/server/test_app.py +++ b/tests/unit/server/test_app.py @@ -260,6 +260,10 @@ class TestEnhancePromptEndpoint: class TestMCPEndpoint: @patch("src.server.app.load_mcp_tools") + @patch.dict( + os.environ, + {"ENABLE_MCP_SERVER_CONFIGURATION": "true"}, + ) def test_mcp_server_metadata_success(self, mock_load_tools, client): mock_load_tools.return_value = [ {"name": "test_tool", "description": "Test tool"} @@ -281,6 +285,10 @@ class TestMCPEndpoint: assert len(response_data["tools"]) == 1 @patch("src.server.app.load_mcp_tools") + @patch.dict( + os.environ, + {"ENABLE_MCP_SERVER_CONFIGURATION": "true"}, + ) def test_mcp_server_metadata_with_custom_timeout(self, mock_load_tools, client): mock_load_tools.return_value = [] @@ -296,6 +304,10 @@ class TestMCPEndpoint: mock_load_tools.assert_called_once() @patch("src.server.app.load_mcp_tools") + @patch.dict( + os.environ, + {"ENABLE_MCP_SERVER_CONFIGURATION": "true"}, + ) def test_mcp_server_metadata_with_exception(self, mock_load_tools, client): mock_load_tools.side_effect = HTTPException( status_code=400, detail="MCP Server Error" @@ -313,6 +325,30 @@ class TestMCPEndpoint: assert response.status_code == 500 assert response.json()["detail"] == "Internal Server Error" + @patch("src.server.app.load_mcp_tools") + @patch.dict( + os.environ, + {"ENABLE_MCP_SERVER_CONFIGURATION": ""}, + ) + def test_mcp_server_metadata_without_enable_configuration( + self, mock_load_tools, client + ): + + request_data = { + "transport": "stdio", + "command": "test_command", + "args": ["arg1", "arg2"], + "env": {"ENV_VAR": "value"}, + } + + response = client.post("/api/mcp/server/metadata", json=request_data) + + assert response.status_code == 403 + assert ( + response.json()["detail"] + == "MCP server configuration is disabled. Set ENABLE_MCP_SERVER_CONFIGURATION=true to enable." + ) + class TestRAGEndpoints: @patch("src.server.app.SELECTED_RAG_PROVIDER", "test_provider")