From dcdd7288ed0c861551c3c1669c1ebcff8675a849 Mon Sep 17 00:00:00 2001 From: Willem Jiang Date: Wed, 18 Jun 2025 14:13:05 +0800 Subject: [PATCH] test: add unit tests of the app (#305) * test: add unit tests in server * test: add unit tests of app.py in server * test: reformat the codes * test: add more tests to cover the exception part * test: add more tests on the server app part * fix: don't show the detail exception to the client * test: try to fix the CI test * fix: keep the TTS API call without exposure information * Fixed the unit test errors * Fixed the lint error --- src/server/app.py | 26 +- src/tools/tts.py | 2 +- tests/integration/test_tts.py | 3 +- tests/unit/server/test_app.py | 736 +++++++++++++++++++++++++ tests/unit/server/test_chat_request.py | 168 ++++++ tests/unit/server/test_mcp_request.py | 73 +++ tests/unit/server/test_mcp_utils.py | 121 ++++ 7 files changed, 1113 insertions(+), 16 deletions(-) create mode 100644 tests/unit/server/test_app.py create mode 100644 tests/unit/server/test_chat_request.py create mode 100644 tests/unit/server/test_mcp_request.py create mode 100644 tests/unit/server/test_mcp_utils.py diff --git a/src/server/app.py b/src/server/app.py index 57db6ff..71b4009 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -201,17 +201,16 @@ def _make_event(event_type: str, data: dict[str, any]): @app.post("/api/tts") async def text_to_speech(request: TTSRequest): """Convert text to speech using volcengine TTS API.""" + app_id = os.getenv("VOLCENGINE_TTS_APPID", "") + if not app_id: + raise HTTPException(status_code=400, detail="VOLCENGINE_TTS_APPID is not set") + access_token = os.getenv("VOLCENGINE_TTS_ACCESS_TOKEN", "") + if not access_token: + raise HTTPException( + status_code=400, detail="VOLCENGINE_TTS_ACCESS_TOKEN is not set" + ) + try: - app_id = os.getenv("VOLCENGINE_TTS_APPID", "") - if not app_id: - raise HTTPException( - status_code=400, detail="VOLCENGINE_TTS_APPID is not set" - ) - access_token = os.getenv("VOLCENGINE_TTS_ACCESS_TOKEN", "") - if not access_token: - raise HTTPException( - status_code=400, detail="VOLCENGINE_TTS_ACCESS_TOKEN is not set" - ) cluster = os.getenv("VOLCENGINE_TTS_CLUSTER", "volcano_tts") voice_type = os.getenv("VOLCENGINE_TTS_VOICE_TYPE", "BV700_V2_streaming") @@ -249,6 +248,7 @@ async def text_to_speech(request: TTSRequest): ) }, ) + except Exception as e: logger.exception(f"Error in TTS endpoint: {str(e)}") raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL) @@ -388,10 +388,8 @@ async def mcp_server_metadata(request: MCPServerMetadataRequest): return response except Exception as e: - if not isinstance(e, HTTPException): - logger.exception(f"Error in MCP server metadata endpoint: {str(e)}") - raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL) - raise + logger.exception(f"Error in MCP server metadata endpoint: {str(e)}") + raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL) @app.get("/api/rag/config", response_model=RAGConfigResponse) diff --git a/src/tools/tts.py b/src/tools/tts.py index 9a570bb..d5f2565 100644 --- a/src/tools/tts.py +++ b/src/tools/tts.py @@ -129,4 +129,4 @@ class VolcengineTTS: except Exception as e: logger.exception(f"Error in TTS API call: {str(e)}") - return {"success": False, "error": str(e), "audio_data": None} + return {"success": False, "error": "TTS API call error", "audio_data": None} diff --git a/tests/integration/test_tts.py b/tests/integration/test_tts.py index 1066c95..dc126de 100644 --- a/tests/integration/test_tts.py +++ b/tests/integration/test_tts.py @@ -244,5 +244,6 @@ class TestVolcengineTTS: result = tts.text_to_speech("Hello, world!") # Verify the result assert result["success"] is False - assert result["error"] == "Network error" + # The TTS error is caught and returned as a string + assert result["error"] == "TTS API call error" assert result["audio_data"] is None diff --git a/tests/unit/server/test_app.py b/tests/unit/server/test_app.py new file mode 100644 index 0000000..0016b50 --- /dev/null +++ b/tests/unit/server/test_app.py @@ -0,0 +1,736 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +import base64 +import json +import os +from unittest.mock import AsyncMock, MagicMock, patch, mock_open +from uuid import uuid4 +from fastapi.responses import JSONResponse, StreamingResponse +import pytest +from fastapi.testclient import TestClient +from fastapi import HTTPException, logger +from src.server.app import app, _make_event, _astream_workflow_generator +from src.server.mcp_request import MCPServerMetadataRequest +from src.server.rag_request import RAGResourceRequest +from src.config.report_style import ReportStyle +from langgraph.types import Command +from langchain_core.messages import ToolMessage +from langchain_core.messages import AIMessageChunk + +from src.server.chat_request import ( + ChatRequest, + TTSRequest, + GeneratePodcastRequest, + GeneratePPTRequest, + GenerateProseRequest, + EnhancePromptRequest, +) + + +@pytest.fixture +def client(): + return TestClient(app) + + +class TestMakeEvent: + def test_make_event_with_content(self): + event_type = "message_chunk" + data = {"content": "Hello", "role": "assistant"} + result = _make_event(event_type, data) + expected = ( + 'event: message_chunk\ndata: {"content": "Hello", "role": "assistant"}\n\n' + ) + assert result == expected + + def test_make_event_with_empty_content(self): + event_type = "message_chunk" + data = {"content": "", "role": "assistant"} + result = _make_event(event_type, data) + expected = 'event: message_chunk\ndata: {"role": "assistant"}\n\n' + assert result == expected + + def test_make_event_without_content(self): + event_type = "tool_calls" + data = {"role": "assistant", "tool_calls": []} + result = _make_event(event_type, data) + expected = ( + 'event: tool_calls\ndata: {"role": "assistant", "tool_calls": []}\n\n' + ) + assert result == expected + + +class TestTTSEndpoint: + @patch.dict( + os.environ, + { + "VOLCENGINE_TTS_APPID": "test_app_id", + "VOLCENGINE_TTS_ACCESS_TOKEN": "test_token", + "VOLCENGINE_TTS_CLUSTER": "test_cluster", + "VOLCENGINE_TTS_VOICE_TYPE": "test_voice", + }, + ) + @patch("src.server.app.VolcengineTTS") + def test_tts_success(self, mock_tts_class, client): + mock_tts_instance = MagicMock() + mock_tts_class.return_value = mock_tts_instance + + # Mock successful TTS response + audio_data_b64 = base64.b64encode(b"fake_audio_data").decode() + mock_tts_instance.text_to_speech.return_value = { + "success": True, + "audio_data": audio_data_b64, + } + + request_data = { + "text": "Hello world", + "encoding": "mp3", + "speed_ratio": 1.0, + "volume_ratio": 1.0, + "pitch_ratio": 1.0, + "text_type": "plain", + "with_frontend": True, + "frontend_type": "unitTson", + } + + response = client.post("/api/tts", json=request_data) + + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/mp3" + assert b"fake_audio_data" in response.content + + @patch.dict(os.environ, {}, clear=True) + def test_tts_missing_app_id(self, client): + request_data = {"text": "Hello world", "encoding": "mp3"} + + response = client.post("/api/tts", json=request_data) + + assert response.status_code == 400 + assert "VOLCENGINE_TTS_APPID is not set" in response.json()["detail"] + + @patch.dict( + os.environ, + {"VOLCENGINE_TTS_APPID": "test_app_id", "VOLCENGINE_TTS_ACCESS_TOKEN": ""}, + ) + def test_tts_missing_access_token(self, client): + request_data = {"text": "Hello world", "encoding": "mp3"} + + response = client.post("/api/tts", json=request_data) + + assert response.status_code == 400 + assert "VOLCENGINE_TTS_ACCESS_TOKEN is not set" in response.json()["detail"] + + @patch.dict( + os.environ, + { + "VOLCENGINE_TTS_APPID": "test_app_id", + "VOLCENGINE_TTS_ACCESS_TOKEN": "test_token", + }, + ) + @patch("src.server.app.VolcengineTTS") + def test_tts_api_error(self, mock_tts_class, client): + mock_tts_instance = MagicMock() + mock_tts_class.return_value = mock_tts_instance + + # Mock TTS error response + mock_tts_instance.text_to_speech.return_value = { + "success": False, + "error": "TTS API error", + } + + request_data = {"text": "Hello world", "encoding": "mp3"} + + response = client.post("/api/tts", json=request_data) + + assert response.status_code == 500 + assert "Internal Server Error" in response.json()["detail"] + + @pytest.mark.skip(reason="TTS server exception is catched") + @patch("src.server.app.VolcengineTTS") + def test_tts_api_exception(self, mock_tts_class, client): + mock_tts_instance = MagicMock() + mock_tts_class.return_value = mock_tts_instance + + # Mock TTS error response + mock_tts_instance.side_effect = Exception("TTS API error") + + request_data = {"text": "Hello world", "encoding": "mp3"} + + response = client.post("/api/tts", json=request_data) + + assert response.status_code == 500 + assert "Internal Server Error" in response.json()["detail"] + + +class TestPodcastEndpoint: + @patch("src.server.app.build_podcast_graph") + def test_generate_podcast_success(self, mock_build_graph, client): + mock_workflow = MagicMock() + mock_build_graph.return_value = mock_workflow + mock_workflow.invoke.return_value = {"output": b"fake_audio_data"} + + request_data = {"content": "Test content for podcast"} + + response = client.post("/api/podcast/generate", json=request_data) + + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/mp3" + assert response.content == b"fake_audio_data" + + @patch("src.server.app.build_podcast_graph") + def test_generate_podcast_error(self, mock_build_graph, client): + mock_build_graph.side_effect = Exception("Podcast generation failed") + + request_data = {"content": "Test content"} + + response = client.post("/api/podcast/generate", json=request_data) + + assert response.status_code == 500 + assert response.json()["detail"] == "Internal Server Error" + + +class TestPPTEndpoint: + @patch("src.server.app.build_ppt_graph") + @patch("builtins.open", new_callable=mock_open, read_data=b"fake_ppt_data") + def test_generate_ppt_success(self, mock_file, mock_build_graph, client): + mock_workflow = MagicMock() + mock_build_graph.return_value = mock_workflow + mock_workflow.invoke.return_value = { + "generated_file_path": "/fake/path/test.pptx" + } + + request_data = {"content": "Test content for PPT"} + + response = client.post("/api/ppt/generate", json=request_data) + + assert response.status_code == 200 + assert ( + "application/vnd.openxmlformats-officedocument.presentationml.presentation" + in response.headers["content-type"] + ) + assert response.content == b"fake_ppt_data" + + @patch("src.server.app.build_ppt_graph") + def test_generate_ppt_error(self, mock_build_graph, client): + mock_build_graph.side_effect = Exception("PPT generation failed") + + request_data = {"content": "Test content"} + + response = client.post("/api/ppt/generate", json=request_data) + + assert response.status_code == 500 + assert response.json()["detail"] == "Internal Server Error" + + +class TestEnhancePromptEndpoint: + @patch("src.server.app.build_prompt_enhancer_graph") + def test_enhance_prompt_success(self, mock_build_graph, client): + mock_workflow = MagicMock() + mock_build_graph.return_value = mock_workflow + mock_workflow.invoke.return_value = {"output": "Enhanced prompt"} + + request_data = { + "prompt": "Original prompt", + "context": "Some context", + "report_style": "academic", + } + + response = client.post("/api/prompt/enhance", json=request_data) + + assert response.status_code == 200 + assert response.json()["result"] == "Enhanced prompt" + + @patch("src.server.app.build_prompt_enhancer_graph") + def test_enhance_prompt_with_different_styles(self, mock_build_graph, client): + mock_workflow = MagicMock() + mock_build_graph.return_value = mock_workflow + mock_workflow.invoke.return_value = {"output": "Enhanced prompt"} + + styles = [ + "ACADEMIC", + "popular_science", + "NEWS", + "social_media", + "invalid_style", + ] + + for style in styles: + request_data = {"prompt": "Test prompt", "report_style": style} + + response = client.post("/api/prompt/enhance", json=request_data) + assert response.status_code == 200 + + @patch("src.server.app.build_prompt_enhancer_graph") + def test_enhance_prompt_error(self, mock_build_graph, client): + mock_build_graph.side_effect = Exception("Enhancement failed") + + request_data = {"prompt": "Test prompt"} + + response = client.post("/api/prompt/enhance", json=request_data) + + assert response.status_code == 500 + assert response.json()["detail"] == "Internal Server Error" + + +class TestMCPEndpoint: + @patch("src.server.app.load_mcp_tools") + def test_mcp_server_metadata_success(self, mock_load_tools, client): + mock_load_tools.return_value = [ + {"name": "test_tool", "description": "Test tool"} + ] + + request_data = { + "transport": "stdio", + "command": "test_command", + "args": ["arg1", "arg2"], + "env": {"ENV_VAR": "value"}, + } + + response = client.post("/api/mcp/server/metadata", json=request_data) + + assert response.status_code == 200 + response_data = response.json() + assert response_data["transport"] == "stdio" + assert response_data["command"] == "test_command" + assert len(response_data["tools"]) == 1 + + @patch("src.server.app.load_mcp_tools") + def test_mcp_server_metadata_with_custom_timeout(self, mock_load_tools, client): + mock_load_tools.return_value = [] + + request_data = { + "transport": "stdio", + "command": "test_command", + "timeout_seconds": 600, + } + + response = client.post("/api/mcp/server/metadata", json=request_data) + + assert response.status_code == 200 + mock_load_tools.assert_called_once() + + @patch("src.server.app.load_mcp_tools") + def test_mcp_server_metadata_with_exception(self, mock_load_tools, client): + mock_load_tools.side_effect = HTTPException( + status_code=400, detail="MCP Server Error" + ) + + request_data = { + "transport": "stdio", + "command": "test_command", + "args": ["arg1", "arg2"], + "env": {"ENV_VAR": "value"}, + } + + response = client.post("/api/mcp/server/metadata", json=request_data) + + assert response.status_code == 500 + assert response.json()["detail"] == "Internal Server Error" + + +class TestRAGEndpoints: + @patch("src.server.app.SELECTED_RAG_PROVIDER", "test_provider") + def test_rag_config(self, client): + response = client.get("/api/rag/config") + + assert response.status_code == 200 + assert response.json()["provider"] == "test_provider" + + @patch("src.server.app.build_retriever") + def test_rag_resources_with_retriever(self, mock_build_retriever, client): + mock_retriever = MagicMock() + mock_retriever.list_resources.return_value = [ + { + "uri": "test_uri", + "title": "Test Resource", + "description": "Test Description", + } + ] + mock_build_retriever.return_value = mock_retriever + + response = client.get("/api/rag/resources?query=test") + + assert response.status_code == 200 + assert len(response.json()["resources"]) == 1 + + @patch("src.server.app.build_retriever") + def test_rag_resources_without_retriever(self, mock_build_retriever, client): + mock_build_retriever.return_value = None + + response = client.get("/api/rag/resources") + + assert response.status_code == 200 + assert response.json()["resources"] == [] + + +class TestChatStreamEndpoint: + @patch("src.server.app.graph") + def test_chat_stream_with_default_thread_id(self, mock_graph, client): + # Mock the async stream + async def mock_astream(*args, **kwargs): + yield ("agent1", "step1", {"test": "data"}) + + mock_graph.astream = mock_astream + + request_data = { + "thread_id": "__default__", + "messages": [{"role": "user", "content": "Hello"}], + "resources": [], + "max_plan_iterations": 3, + "max_step_num": 10, + "max_search_results": 5, + "auto_accepted_plan": True, + "interrupt_feedback": "", + "mcp_settings": {}, + "enable_background_investigation": False, + "report_style": "academic", + } + + response = client.post("/api/chat/stream", json=request_data) + + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + + +class TestAstreamWorkflowGenerator: + @pytest.mark.asyncio + @patch("src.server.app.graph") + async def test_astream_workflow_generator_basic_flow(self, mock_graph): + # Mock AI message chunk + mock_message = AIMessageChunk(content="Hello world") + mock_message.id = "msg_123" + mock_message.response_metadata = {} + mock_message.tool_calls = [] + mock_message.tool_call_chunks = [] + + # Mock the async stream - yield messages in the correct format + async def mock_astream(*args, **kwargs): + # Yield a tuple (message, metadata) instead of just [message] + yield ("agent1:subagent", "messages", (mock_message, {})) + + mock_graph.astream = mock_astream + + messages = [{"role": "user", "content": "Hello"}] + thread_id = "test_thread" + resources = [] + + generator = _astream_workflow_generator( + messages=messages, + thread_id=thread_id, + resources=resources, + max_plan_iterations=3, + max_step_num=10, + max_search_results=5, + auto_accepted_plan=True, + interrupt_feedback="", + mcp_settings={}, + enable_background_investigation=False, + report_style=ReportStyle.ACADEMIC, + enable_deep_thinking=False, + ) + + events = [] + async for event in generator: + events.append(event) + + assert len(events) == 1 + assert "event: message_chunk" in events[0] + assert "Hello world" in events[0] + # Check for the actual agent name that appears in the output + assert '"agent": "a"' in events[0] + + @pytest.mark.asyncio + @patch("src.server.app.graph") + async def test_astream_workflow_generator_with_interrupt_feedback(self, mock_graph): + + # Mock the async stream + async def mock_astream(*args, **kwargs): + # Verify that Command is passed as input when interrupt_feedback is provided + assert isinstance(args[0], Command) + assert "[edit_plan] Hello" in args[0].resume + yield ("agent1", "step1", {"test": "data"}) + + mock_graph.astream = mock_astream + + messages = [{"role": "user", "content": "Hello"}] + + generator = _astream_workflow_generator( + messages=messages, + thread_id="test_thread", + resources=[], + max_plan_iterations=3, + max_step_num=10, + max_search_results=5, + auto_accepted_plan=False, + interrupt_feedback="edit_plan", + mcp_settings={}, + enable_background_investigation=False, + report_style=ReportStyle.ACADEMIC, + enable_deep_thinking=False, + ) + + events = [] + async for event in generator: + events.append(event) + + @pytest.mark.asyncio + @patch("src.server.app.graph") + async def test_astream_workflow_generator_interrupt_event(self, mock_graph): + # Mock interrupt data + mock_interrupt = MagicMock() + mock_interrupt.ns = ["interrupt_id"] + mock_interrupt.value = "Plan requires approval" + + interrupt_data = {"__interrupt__": [mock_interrupt]} + + async def mock_astream(*args, **kwargs): + yield ("agent1", "step1", interrupt_data) + + mock_graph.astream = mock_astream + + generator = _astream_workflow_generator( + messages=[], + thread_id="test_thread", + resources=[], + max_plan_iterations=3, + max_step_num=10, + max_search_results=5, + auto_accepted_plan=True, + interrupt_feedback="", + mcp_settings={}, + enable_background_investigation=False, + report_style=ReportStyle.ACADEMIC, + enable_deep_thinking=False, + ) + + events = [] + async for event in generator: + events.append(event) + + assert len(events) == 1 + assert "event: interrupt" in events[0] + assert "Plan requires approval" in events[0] + assert "interrupt_id" in events[0] + + @pytest.mark.asyncio + @patch("src.server.app.graph") + async def test_astream_workflow_generator_tool_message(self, mock_graph): + + # Mock tool message + mock_tool_message = ToolMessage(content="Tool result", tool_call_id="tool_123") + mock_tool_message.id = "msg_456" + + async def mock_astream(*args, **kwargs): + yield ("agent1:subagent", "step1", (mock_tool_message, {})) + + mock_graph.astream = mock_astream + + generator = _astream_workflow_generator( + messages=[], + thread_id="test_thread", + resources=[], + max_plan_iterations=3, + max_step_num=10, + max_search_results=5, + auto_accepted_plan=True, + interrupt_feedback="", + mcp_settings={}, + enable_background_investigation=False, + report_style=ReportStyle.ACADEMIC, + enable_deep_thinking=False, + ) + + events = [] + async for event in generator: + events.append(event) + + assert len(events) == 1 + assert "event: tool_call_result" in events[0] + assert "Tool result" in events[0] + assert "tool_123" in events[0] + + @pytest.mark.asyncio + @patch("src.server.app.graph") + async def test_astream_workflow_generator_ai_message_with_tool_calls( + self, mock_graph + ): + + # Mock AI message with tool calls + mock_ai_message = AIMessageChunk(content="Making tool call") + mock_ai_message.id = "msg_789" + mock_ai_message.response_metadata = {"finish_reason": "tool_calls"} + mock_ai_message.tool_calls = [{"name": "search", "args": {"query": "test"}}] + mock_ai_message.tool_call_chunks = [{"name": "search"}] + + async def mock_astream(*args, **kwargs): + yield ("agent1:subagent", "step1", (mock_ai_message, {})) + + mock_graph.astream = mock_astream + + generator = _astream_workflow_generator( + messages=[], + thread_id="test_thread", + resources=[], + max_plan_iterations=3, + max_step_num=10, + max_search_results=5, + auto_accepted_plan=True, + interrupt_feedback="", + mcp_settings={}, + enable_background_investigation=False, + report_style=ReportStyle.ACADEMIC, + enable_deep_thinking=False, + ) + + events = [] + async for event in generator: + events.append(event) + + assert len(events) == 1 + assert "event: tool_calls" in events[0] + assert "Making tool call" in events[0] + assert "tool_calls" in events[0] + + @pytest.mark.asyncio + @patch("src.server.app.graph") + async def test_astream_workflow_generator_ai_message_with_tool_call_chunks( + self, mock_graph + ): + + # Mock AI message with only tool call chunks + mock_ai_message = AIMessageChunk(content="Streaming tool call") + mock_ai_message.id = "msg_101" + mock_ai_message.response_metadata = {} + mock_ai_message.tool_calls = [] + mock_ai_message.tool_call_chunks = [{"name": "search", "index": 0}] + + async def mock_astream(*args, **kwargs): + yield ("agent1:subagent", "step1", (mock_ai_message, {})) + + mock_graph.astream = mock_astream + + generator = _astream_workflow_generator( + messages=[], + thread_id="test_thread", + resources=[], + max_plan_iterations=3, + max_step_num=10, + max_search_results=5, + auto_accepted_plan=True, + interrupt_feedback="", + mcp_settings={}, + enable_background_investigation=False, + report_style=ReportStyle.ACADEMIC, + enable_deep_thinking=False, + ) + + events = [] + async for event in generator: + events.append(event) + + assert len(events) == 1 + assert "event: tool_call_chunks" in events[0] + assert "Streaming tool call" in events[0] + + @pytest.mark.asyncio + @patch("src.server.app.graph") + async def test_astream_workflow_generator_with_finish_reason(self, mock_graph): + + # Mock AI message with finish reason + mock_ai_message = AIMessageChunk(content="Complete response") + mock_ai_message.id = "msg_finish" + mock_ai_message.response_metadata = {"finish_reason": "stop"} + mock_ai_message.tool_calls = [] + mock_ai_message.tool_call_chunks = [] + + async def mock_astream(*args, **kwargs): + yield ("agent1:subagent", "step1", (mock_ai_message, {})) + + mock_graph.astream = mock_astream + + generator = _astream_workflow_generator( + messages=[], + thread_id="test_thread", + resources=[], + max_plan_iterations=3, + max_step_num=10, + max_search_results=5, + auto_accepted_plan=True, + interrupt_feedback="", + mcp_settings={}, + enable_background_investigation=False, + report_style=ReportStyle.ACADEMIC, + enable_deep_thinking=False, + ) + + events = [] + async for event in generator: + events.append(event) + + assert len(events) == 1 + assert "event: message_chunk" in events[0] + assert "finish_reason" in events[0] + assert "stop" in events[0] + + @pytest.mark.asyncio + @patch("src.server.app.graph") + async def test_astream_workflow_generator_config_passed_correctly(self, mock_graph): + + mock_ai_message = AIMessageChunk(content="Test") + mock_ai_message.id = "test_id" + mock_ai_message.response_metadata = {} + mock_ai_message.tool_calls = [] + mock_ai_message.tool_call_chunks = [] + + async def verify_config(*args, **kwargs): + config = kwargs.get("config", {}) + assert config["thread_id"] == "test_thread" + assert config["max_plan_iterations"] == 5 + assert config["max_step_num"] == 20 + assert config["max_search_results"] == 10 + assert config["report_style"] == ReportStyle.NEWS.value + yield ("agent1", "messages", [mock_ai_message]) + + +class TestGenerateProseEndpoint: + @patch("src.server.app.build_prose_graph") + def test_generate_prose_success(self, mock_build_graph, client): + # Mock the workflow and its astream method + mock_workflow = MagicMock() + mock_build_graph.return_value = mock_workflow + + class MockEvent: + def __init__(self, content): + self.content = content + + async def mock_astream(*args, **kwargs): + yield (None, [MockEvent("Generated prose 1")]) + yield (None, [MockEvent("Generated prose 2")]) + + mock_workflow.astream.return_value = mock_astream() + request_data = { + "prompt": "Write a story.", + "option": "default", + "command": "generate", + } + + response = client.post("/api/prose/generate", json=request_data) + + assert response.status_code == 200 + assert response.headers["content-type"].startswith("text/event-stream") + + # Read the streaming response content + content = b"".join(response.iter_bytes()) + assert b"Generated prose 1" in content or b"Generated prose 2" in content + + @patch("src.server.app.build_prose_graph") + def test_generate_prose_error(self, mock_build_graph, client): + mock_build_graph.side_effect = Exception("Prose generation failed") + request_data = { + "prompt": "Write a story.", + "option": "default", + "command": "generate", + } + response = client.post("/api/prose/generate", json=request_data) + assert response.status_code == 500 + assert response.json()["detail"] == "Internal Server Error" diff --git a/tests/unit/server/test_chat_request.py b/tests/unit/server/test_chat_request.py new file mode 100644 index 0000000..0556a7c --- /dev/null +++ b/tests/unit/server/test_chat_request.py @@ -0,0 +1,168 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +import asyncio # Ensure asyncio is imported +import pytest +from pydantic import ValidationError +from src.config.report_style import ReportStyle +from src.rag.retriever import Resource +from unittest.mock import AsyncMock, patch, MagicMock +from fastapi import HTTPException + +from src.server.chat_request import ( + ContentItem, + ChatMessage, + ChatRequest, + TTSRequest, + GeneratePodcastRequest, + GeneratePPTRequest, + GenerateProseRequest, + EnhancePromptRequest, +) +import src.server.mcp_utils as mcp_utils # Assuming mcp_utils is the module to test + + +def test_content_item_text_and_image(): + item_text = ContentItem(type="text", text="hello") + assert item_text.type == "text" + assert item_text.text == "hello" + assert item_text.image_url is None + + item_image = ContentItem(type="image", image_url="http://img.com/1.png") + assert item_image.type == "image" + assert item_image.text is None + assert item_image.image_url == "http://img.com/1.png" + + +def test_chat_message_with_string_content(): + msg = ChatMessage(role="user", content="Hello!") + assert msg.role == "user" + assert msg.content == "Hello!" + + +def test_chat_message_with_content_items(): + items = [ContentItem(type="text", text="hi")] + msg = ChatMessage(role="assistant", content=items) + assert msg.role == "assistant" + assert isinstance(msg.content, list) + assert msg.content[0].type == "text" + + +def test_chat_request_defaults(): + req = ChatRequest() + assert req.messages == [] + assert req.resources == [] + assert req.debug is False + assert req.thread_id == "__default__" + assert req.max_plan_iterations == 1 + assert req.max_step_num == 3 + assert req.max_search_results == 3 + assert req.auto_accepted_plan is False + assert req.interrupt_feedback is None + assert req.mcp_settings is None + assert req.enable_background_investigation is True + assert req.report_style == ReportStyle.ACADEMIC + + +def test_chat_request_with_values(): + resource = Resource( + name="test", type="doc", uri="some-uri-value", title="some-title-value" + ) + msg = ChatMessage(role="user", content="hi") + req = ChatRequest( + messages=[msg], + resources=[resource], + debug=True, + thread_id="tid", + max_plan_iterations=2, + max_step_num=5, + max_search_results=10, + auto_accepted_plan=True, + interrupt_feedback="stop", + mcp_settings={"foo": "bar"}, + enable_background_investigation=False, + report_style="academic", + ) + assert req.messages[0].role == "user" + assert req.debug is True + assert req.thread_id == "tid" + assert req.max_plan_iterations == 2 + assert req.max_step_num == 5 + assert req.max_search_results == 10 + assert req.auto_accepted_plan is True + assert req.interrupt_feedback == "stop" + assert req.mcp_settings == {"foo": "bar"} + assert req.enable_background_investigation is False + assert req.report_style == ReportStyle.ACADEMIC + + +def test_tts_request_defaults(): + req = TTSRequest(text="hello") + assert req.text == "hello" + assert req.voice_type == "BV700_V2_streaming" + assert req.encoding == "mp3" + assert req.speed_ratio == 1.0 + assert req.volume_ratio == 1.0 + assert req.pitch_ratio == 1.0 + assert req.text_type == "plain" + assert req.with_frontend == 1 + assert req.frontend_type == "unitTson" + + +def test_generate_podcast_request(): + req = GeneratePodcastRequest(content="Podcast content") + assert req.content == "Podcast content" + + +def test_generate_ppt_request(): + req = GeneratePPTRequest(content="PPT content") + assert req.content == "PPT content" + + +def test_generate_prose_request(): + req = GenerateProseRequest(prompt="Write a poem", option="poet", command="rhyme") + assert req.prompt == "Write a poem" + assert req.option == "poet" + assert req.command == "rhyme" + + req2 = GenerateProseRequest(prompt="Write", option="short") + assert req2.command == "" + + +def test_enhance_prompt_request_defaults(): + req = EnhancePromptRequest(prompt="Improve this") + assert req.prompt == "Improve this" + assert req.context == "" + assert req.report_style == "academic" + + +def test_content_item_validation_error(): + with pytest.raises(ValidationError): + ContentItem() # missing required 'type' + + +def test_chat_message_validation_error(): + with pytest.raises(ValidationError): + ChatMessage(role="user") # missing content + + +def test_tts_request_validation_error(): + with pytest.raises(ValidationError): + TTSRequest() # missing required 'text' + + +@pytest.mark.asyncio +@patch("src.server.mcp_utils._get_tools_from_client_session", new_callable=AsyncMock) +@patch("src.server.mcp_utils.StdioServerParameters") +@patch("src.server.mcp_utils.stdio_client") +async def test_load_mcp_tools_exception_handling( + mock_stdio_client, mock_StdioServerParameters, mock_get_tools +): # Changed to async def + mock_get_tools.side_effect = Exception("unexpected error") + mock_StdioServerParameters.return_value = MagicMock() + mock_stdio_client.return_value = MagicMock() + + with pytest.raises(HTTPException) as exc: + await mcp_utils.load_mcp_tools(server_type="stdio", command="foo") # Use await + assert exc.value.status_code == 500 + assert "unexpected error" in exc.value.detail diff --git a/tests/unit/server/test_mcp_request.py b/tests/unit/server/test_mcp_request.py new file mode 100644 index 0000000..eb43128 --- /dev/null +++ b/tests/unit/server/test_mcp_request.py @@ -0,0 +1,73 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +import pytest +from pydantic import ValidationError +from src.server.mcp_request import MCPServerMetadataRequest, MCPServerMetadataResponse + + +def test_mcp_server_metadata_request_required_fields(): + # 'transport' is required + req = MCPServerMetadataRequest(transport="stdio") + assert req.transport == "stdio" + assert req.command is None + assert req.args is None + assert req.url is None + assert req.env is None + assert req.timeout_seconds is None + + +def test_mcp_server_metadata_request_optional_fields(): + req = MCPServerMetadataRequest( + transport="sse", + command="run", + args=["--foo", "bar"], + url="http://localhost:8080", + env={"FOO": "BAR"}, + timeout_seconds=30, + ) + assert req.transport == "sse" + assert req.command == "run" + assert req.args == ["--foo", "bar"] + assert req.url == "http://localhost:8080" + assert req.env == {"FOO": "BAR"} + assert req.timeout_seconds == 30 + + +def test_mcp_server_metadata_request_missing_transport(): + with pytest.raises(ValidationError): + MCPServerMetadataRequest() + + +def test_mcp_server_metadata_response_required_fields(): + resp = MCPServerMetadataResponse(transport="stdio") + assert resp.transport == "stdio" + assert resp.command is None + assert resp.args is None + assert resp.url is None + assert resp.env is None + assert resp.tools == [] + + +def test_mcp_server_metadata_response_optional_fields(): + resp = MCPServerMetadataResponse( + transport="sse", + command="run", + args=["--foo", "bar"], + url="http://localhost:8080", + env={"FOO": "BAR"}, + tools=["tool1", "tool2"], + ) + assert resp.transport == "sse" + assert resp.command == "run" + assert resp.args == ["--foo", "bar"] + assert resp.url == "http://localhost:8080" + assert resp.env == {"FOO": "BAR"} + assert resp.tools == ["tool1", "tool2"] + + +def test_mcp_server_metadata_response_tools_default_factory(): + resp1 = MCPServerMetadataResponse(transport="stdio") + resp2 = MCPServerMetadataResponse(transport="stdio") + resp1.tools.append("toolA") + assert resp2.tools == [] # Should not share list between instances diff --git a/tests/unit/server/test_mcp_utils.py b/tests/unit/server/test_mcp_utils.py new file mode 100644 index 0000000..3975b70 --- /dev/null +++ b/tests/unit/server/test_mcp_utils.py @@ -0,0 +1,121 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from fastapi import HTTPException + +import src.server.mcp_utils as mcp_utils + + +@pytest.mark.asyncio +@patch("src.server.mcp_utils.ClientSession") +async def test__get_tools_from_client_session_success(mock_ClientSession): + mock_read = AsyncMock() + mock_write = AsyncMock() + mock_context_manager = AsyncMock() + mock_context_manager.__aenter__.return_value = (mock_read, mock_write) + mock_context_manager.__aexit__.return_value = None + + mock_session = AsyncMock() + mock_session.__aenter__.return_value = mock_session + mock_session.__aexit__.return_value = None + mock_session.initialize = AsyncMock() + mock_tools_obj = MagicMock() + mock_tools_obj.tools = ["tool1", "tool2"] + mock_session.list_tools = AsyncMock(return_value=mock_tools_obj) + mock_ClientSession.return_value = mock_session + + result = await mcp_utils._get_tools_from_client_session( + mock_context_manager, timeout_seconds=5 + ) + assert result == ["tool1", "tool2"] + mock_session.initialize.assert_awaited_once() + mock_session.list_tools.assert_awaited_once() + + +@pytest.mark.asyncio +@patch("src.server.mcp_utils._get_tools_from_client_session", new_callable=AsyncMock) +@patch("src.server.mcp_utils.StdioServerParameters") +@patch("src.server.mcp_utils.stdio_client") +async def test_load_mcp_tools_stdio_success( + mock_stdio_client, mock_StdioServerParameters, mock_get_tools +): + mock_get_tools.return_value = ["toolA"] + params = MagicMock() + mock_StdioServerParameters.return_value = params + mock_client = MagicMock() + mock_stdio_client.return_value = mock_client + + result = await mcp_utils.load_mcp_tools( + server_type="stdio", + command="echo", + args=["foo"], + env={"FOO": "BAR"}, + timeout_seconds=3, + ) + assert result == ["toolA"] + mock_StdioServerParameters.assert_called_once_with( + command="echo", args=["foo"], env={"FOO": "BAR"} + ) + mock_stdio_client.assert_called_once_with(params) + mock_get_tools.assert_awaited_once_with(mock_client, 3) + + +@pytest.mark.asyncio +async def test_load_mcp_tools_stdio_missing_command(): + with pytest.raises(HTTPException) as exc: + await mcp_utils.load_mcp_tools(server_type="stdio") + assert exc.value.status_code == 400 + assert "Command is required" in exc.value.detail + + +@pytest.mark.asyncio +@patch("src.server.mcp_utils._get_tools_from_client_session", new_callable=AsyncMock) +@patch("src.server.mcp_utils.sse_client") +async def test_load_mcp_tools_sse_success(mock_sse_client, mock_get_tools): + mock_get_tools.return_value = ["toolB"] + mock_client = MagicMock() + mock_sse_client.return_value = mock_client + + result = await mcp_utils.load_mcp_tools( + server_type="sse", + url="http://localhost:1234", + timeout_seconds=7, + ) + assert result == ["toolB"] + mock_sse_client.assert_called_once_with(url="http://localhost:1234") + mock_get_tools.assert_awaited_once_with(mock_client, 7) + + +@pytest.mark.asyncio +async def test_load_mcp_tools_sse_missing_url(): + with pytest.raises(HTTPException) as exc: + await mcp_utils.load_mcp_tools(server_type="sse") + assert exc.value.status_code == 400 + assert "URL is required" in exc.value.detail + + +@pytest.mark.asyncio +async def test_load_mcp_tools_unsupported_type(): + with pytest.raises(HTTPException) as exc: + await mcp_utils.load_mcp_tools(server_type="unknown") + assert exc.value.status_code == 400 + assert "Unsupported server type" in exc.value.detail + + +@pytest.mark.asyncio +@patch("src.server.mcp_utils._get_tools_from_client_session", new_callable=AsyncMock) +@patch("src.server.mcp_utils.StdioServerParameters") +@patch("src.server.mcp_utils.stdio_client") +async def test_load_mcp_tools_exception_handling( + mock_stdio_client, mock_StdioServerParameters, mock_get_tools +): + mock_get_tools.side_effect = Exception("unexpected error") + mock_StdioServerParameters.return_value = MagicMock() + mock_stdio_client.return_value = MagicMock() + + with pytest.raises(HTTPException) as exc: + await mcp_utils.load_mcp_tools(server_type="stdio", command="foo") + assert exc.value.status_code == 500 + assert "unexpected error" in exc.value.detail