mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-02 22:02:13 +08:00
* fix: strip <think> tags from LLM output to prevent thinking text leakage (#781) Some models (e.g. DeepSeek-R1, QwQ via ollama) embed reasoning in content using <think>...</think> tags instead of the separate reasoning_content field. This causes thinking text to leak into both streamed messages and the final report. Fix at two layers: - server/app.py: strip <think> tags in _create_event_stream_message so ALL streamed content is filtered (coordinator, planner, etc.) - graph/nodes.py: strip <think> tags in reporter_node before storing final_report (which is not streamed through the event layer) The regex uses a fast-path check ("<think>" in content) to avoid unnecessary regex calls on normal content. * refactor: add defensive check for think tag stripping and add reporter_node tests (#781) - Add isinstance and fast-path check in reporter_node before regex, consistent with app.py - Add TestReporterNodeThinkTagStripping with 5 test cases covering various scenarios * chore: re-trigger review
1734 lines
62 KiB
Python
1734 lines
62 KiB
Python
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
|
|
import asyncio
|
|
import base64
|
|
import os
|
|
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
|
|
|
|
import pytest
|
|
from fastapi import HTTPException
|
|
from fastapi.testclient import TestClient
|
|
from langchain_core.messages import AIMessageChunk, ToolMessage
|
|
from langgraph.types import Command
|
|
|
|
from src.config.report_style import ReportStyle
|
|
from src.server.app import (
|
|
_astream_workflow_generator,
|
|
_create_event_stream_message,
|
|
_create_interrupt_event,
|
|
_make_event,
|
|
_stream_graph_events,
|
|
app,
|
|
)
|
|
|
|
|
|
@pytest.fixture
def client():
    """Provide a FastAPI TestClient bound to the application under test."""
    return TestClient(app)
|
|
|
|
|
|
class TestMakeEvent:
    """Unit tests for the SSE event serializer `_make_event`."""

    def test_make_event_with_content(self):
        # Non-empty content is kept in the serialized payload.
        payload = {"content": "Hello", "role": "assistant"}
        assert _make_event("message_chunk", payload) == (
            'event: message_chunk\ndata: {"content": "Hello", "role": "assistant"}\n\n'
        )

    def test_make_event_with_empty_content(self):
        # An empty "content" value is dropped before serialization.
        payload = {"content": "", "role": "assistant"}
        assert (
            _make_event("message_chunk", payload)
            == 'event: message_chunk\ndata: {"role": "assistant"}\n\n'
        )

    def test_make_event_without_content(self):
        # Payloads with no "content" key pass through unchanged.
        payload = {"role": "assistant", "tool_calls": []}
        assert _make_event("tool_calls", payload) == (
            'event: tool_calls\ndata: {"role": "assistant", "tool_calls": []}\n\n'
        )
|
|
|
|
|
|
class TestStreamGraphEventsCancellation:
    """Tests for graceful handling of asyncio.CancelledError in _stream_graph_events."""

    @pytest.mark.asyncio
    async def test_cancelled_error_does_not_propagate(self):
        """When the stream is cancelled, the generator should end gracefully
        instead of re-raising CancelledError (fixes issue #847)."""

        async def _cancelling_stream(*args, **kwargs):
            yield ("agent", None, {"some": "data"})
            raise asyncio.CancelledError()

        fake_graph = MagicMock()
        fake_graph.astream = _cancelling_stream

        # Consuming the generator must NOT raise CancelledError.
        collected = [
            event
            async for event in _stream_graph_events(
                fake_graph, {"input": "test"}, {}, "test-thread-id"
            )
        ]

        # Exactly one terminal event with reason='cancelled' is expected.
        assert sum('"reason": "cancelled"' in event for event in collected) == 1

    @pytest.mark.asyncio
    async def test_cancelled_error_yields_cancelled_reason(self):
        """The final event should carry reason='cancelled' so the client
        can distinguish cancellation from real errors."""

        async def _cancelling_stream(*args, **kwargs):
            raise asyncio.CancelledError()
            yield  # make this an async generator # noqa: E501

        fake_graph = MagicMock()
        fake_graph.astream = _cancelling_stream

        collected = [
            event
            async for event in _stream_graph_events(
                fake_graph, {"input": "test"}, {}, "test-thread-id"
            )
        ]

        assert len(collected) == 1
        only_event = collected[0]
        assert '"reason": "cancelled"' in only_event
        assert '"error": "Stream cancelled"' in only_event
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_astream_workflow_generator_preserves_clarification_history():
    """Multi-turn clarification answers must be folded into the workflow input."""
    messages = [
        {"role": "user", "content": "Research on renewable energy"},
        {
            "role": "assistant",
            "content": "What type of renewable energy would you like to know about?",
        },
        {"role": "user", "content": "Solar and wind energy"},
        {
            "role": "assistant",
            "content": "Please tell me the research dimensions you focus on, such as technological development or market applications.",
        },
        {"role": "user", "content": "Technological development"},
        {
            "role": "assistant",
            "content": "Please specify the time range you want to focus on, such as current status or future trends.",
        },
        {"role": "user", "content": "Current status and future trends"},
    ]

    captured = {}

    def _capture_and_return_empty_stream(*args, **kwargs):
        # Record the positional input/config handed to _stream_graph_events,
        # then hand back an async iterator that is immediately exhausted.
        captured["workflow_input"] = args[1]
        captured["workflow_config"] = args[2]

        class _ExhaustedAiter:
            def __aiter__(self):
                return self

            async def __anext__(self):
                raise StopAsyncIteration

        return _ExhaustedAiter()

    with (
        patch("src.server.app._process_initial_messages"),
        patch(
            "src.server.app._stream_graph_events",
            side_effect=_capture_and_return_empty_stream,
        ),
    ):
        generator = _astream_workflow_generator(
            messages=messages,
            thread_id="clarification-thread",
            resources=[],
            max_plan_iterations=1,
            max_step_num=1,
            max_search_results=5,
            auto_accepted_plan=True,
            interrupt_feedback="",
            mcp_settings={},
            enable_background_investigation=True,
            enable_web_search=True,
            report_style=ReportStyle.ACADEMIC,
            enable_deep_thinking=False,
            enable_clarification=True,
            max_clarification_rounds=3,
        )

        with pytest.raises(StopAsyncIteration):
            await generator.__anext__()

    workflow_input = captured["workflow_input"]
    # Only the user turns make up the clarification history.
    assert workflow_input["clarification_history"] == [
        "Research on renewable energy",
        "Solar and wind energy",
        "Technological development",
        "Current status and future trends",
    ]
    assert (
        workflow_input["clarified_research_topic"]
        == "Research on renewable energy - Solar and wind energy, Technological development, Current status and future trends"
    )
|
|
|
|
|
|
class TestTTSEndpoint:
    """Tests for the /api/tts endpoint backed by VolcengineTTS."""

    @patch.dict(
        os.environ,
        {
            "VOLCENGINE_TTS_APPID": "test_app_id",
            "VOLCENGINE_TTS_ACCESS_TOKEN": "test_token",
            "VOLCENGINE_TTS_CLUSTER": "test_cluster",
            "VOLCENGINE_TTS_VOICE_TYPE": "test_voice",
        },
    )
    @patch("src.server.app.VolcengineTTS")
    def test_tts_success(self, mock_tts_class, client):
        """A successful synthesis returns the decoded audio bytes as audio/mp3."""
        mock_tts_instance = MagicMock()
        mock_tts_class.return_value = mock_tts_instance

        # Mock successful TTS response; the provider returns base64-encoded audio.
        audio_data_b64 = base64.b64encode(b"fake_audio_data").decode()
        mock_tts_instance.text_to_speech.return_value = {
            "success": True,
            "audio_data": audio_data_b64,
        }

        request_data = {
            "text": "Hello world",
            "encoding": "mp3",
            "speed_ratio": 1.0,
            "volume_ratio": 1.0,
            "pitch_ratio": 1.0,
            "text_type": "plain",
            "with_frontend": True,
            "frontend_type": "unitTson",
        }

        response = client.post("/api/tts", json=request_data)

        assert response.status_code == 200
        assert response.headers["content-type"] == "audio/mp3"
        assert b"fake_audio_data" in response.content

    @patch.dict(os.environ, {}, clear=True)
    def test_tts_missing_app_id(self, client):
        """A missing VOLCENGINE_TTS_APPID yields a 400 with a clear message."""
        request_data = {"text": "Hello world", "encoding": "mp3"}

        response = client.post("/api/tts", json=request_data)

        assert response.status_code == 400
        assert "VOLCENGINE_TTS_APPID is not set" in response.json()["detail"]

    @patch.dict(
        os.environ,
        {"VOLCENGINE_TTS_APPID": "test_app_id", "VOLCENGINE_TTS_ACCESS_TOKEN": ""},
    )
    def test_tts_missing_access_token(self, client):
        """An empty VOLCENGINE_TTS_ACCESS_TOKEN yields a 400 with a clear message."""
        request_data = {"text": "Hello world", "encoding": "mp3"}

        response = client.post("/api/tts", json=request_data)

        assert response.status_code == 400
        assert "VOLCENGINE_TTS_ACCESS_TOKEN is not set" in response.json()["detail"]

    @patch.dict(
        os.environ,
        {
            "VOLCENGINE_TTS_APPID": "test_app_id",
            "VOLCENGINE_TTS_ACCESS_TOKEN": "test_token",
        },
    )
    @patch("src.server.app.VolcengineTTS")
    def test_tts_api_error(self, mock_tts_class, client):
        """A provider-level failure is surfaced as a generic 500."""
        mock_tts_instance = MagicMock()
        mock_tts_class.return_value = mock_tts_instance

        # Mock TTS error response
        mock_tts_instance.text_to_speech.return_value = {
            "success": False,
            "error": "TTS API error",
        }

        request_data = {"text": "Hello world", "encoding": "mp3"}

        response = client.post("/api/tts", json=request_data)

        assert response.status_code == 500
        assert "Internal Server Error" in response.json()["detail"]

    @pytest.mark.skip(reason="TTS server exception is caught")
    @patch("src.server.app.VolcengineTTS")
    def test_tts_api_exception(self, mock_tts_class, client):
        """An exception raised by the TTS client is surfaced as a generic 500."""
        mock_tts_instance = MagicMock()
        mock_tts_class.return_value = mock_tts_instance

        # Fix: raise from text_to_speech — the method the endpoint actually
        # calls. The original set side_effect on the instance itself, which is
        # never invoked as a callable, so no exception was ever raised and the
        # test was vacuous.
        mock_tts_instance.text_to_speech.side_effect = Exception("TTS API error")

        request_data = {"text": "Hello world", "encoding": "mp3"}

        response = client.post("/api/tts", json=request_data)

        assert response.status_code == 500
        assert "Internal Server Error" in response.json()["detail"]
|
|
|
|
|
|
class TestPodcastEndpoint:
    """Tests for the /api/podcast/generate endpoint."""

    @patch("src.server.app.build_podcast_graph")
    def test_generate_podcast_success(self, mock_build_graph, client):
        # The workflow's output bytes are returned verbatim as audio/mp3.
        workflow = MagicMock()
        workflow.invoke.return_value = {"output": b"fake_audio_data"}
        mock_build_graph.return_value = workflow

        response = client.post(
            "/api/podcast/generate", json={"content": "Test content for podcast"}
        )

        assert response.status_code == 200
        assert response.headers["content-type"] == "audio/mp3"
        assert response.content == b"fake_audio_data"

    @patch("src.server.app.build_podcast_graph")
    def test_generate_podcast_error(self, mock_build_graph, client):
        # Any failure while building the workflow maps to a generic 500.
        mock_build_graph.side_effect = Exception("Podcast generation failed")

        response = client.post(
            "/api/podcast/generate", json={"content": "Test content"}
        )

        assert response.status_code == 500
        assert response.json()["detail"] == "Internal Server Error"
|
|
|
|
|
|
class TestPPTEndpoint:
    """Tests for the /api/ppt/generate endpoint."""

    @patch("src.server.app.build_ppt_graph")
    @patch("builtins.open", new_callable=mock_open, read_data=b"fake_ppt_data")
    def test_generate_ppt_success(self, mock_file, mock_build_graph, client):
        # The endpoint reads the generated file from disk and streams its bytes.
        workflow = MagicMock()
        workflow.invoke.return_value = {"generated_file_path": "/fake/path/test.pptx"}
        mock_build_graph.return_value = workflow

        response = client.post(
            "/api/ppt/generate", json={"content": "Test content for PPT"}
        )

        assert response.status_code == 200
        assert (
            "application/vnd.openxmlformats-officedocument.presentationml.presentation"
            in response.headers["content-type"]
        )
        assert response.content == b"fake_ppt_data"

    @patch("src.server.app.build_ppt_graph")
    def test_generate_ppt_error(self, mock_build_graph, client):
        # Any failure while building the workflow maps to a generic 500.
        mock_build_graph.side_effect = Exception("PPT generation failed")

        response = client.post("/api/ppt/generate", json={"content": "Test content"})

        assert response.status_code == 500
        assert response.json()["detail"] == "Internal Server Error"
|
|
|
|
|
|
class TestEnhancePromptEndpoint:
    """Tests for the /api/prompt/enhance endpoint."""

    @patch("src.server.app.build_prompt_enhancer_graph")
    def test_enhance_prompt_success(self, mock_build_graph, client):
        workflow = MagicMock()
        workflow.invoke.return_value = {"output": "Enhanced prompt"}
        mock_build_graph.return_value = workflow

        response = client.post(
            "/api/prompt/enhance",
            json={
                "prompt": "Original prompt",
                "context": "Some context",
                "report_style": "academic",
            },
        )

        assert response.status_code == 200
        assert response.json()["result"] == "Enhanced prompt"

    @patch("src.server.app.build_prompt_enhancer_graph")
    def test_enhance_prompt_with_different_styles(self, mock_build_graph, client):
        # Mixed-case and unknown style names must not break the endpoint.
        workflow = MagicMock()
        workflow.invoke.return_value = {"output": "Enhanced prompt"}
        mock_build_graph.return_value = workflow

        for style in (
            "ACADEMIC",
            "popular_science",
            "NEWS",
            "social_media",
            "invalid_style",
        ):
            response = client.post(
                "/api/prompt/enhance",
                json={"prompt": "Test prompt", "report_style": style},
            )
            assert response.status_code == 200

    @patch("src.server.app.build_prompt_enhancer_graph")
    def test_enhance_prompt_error(self, mock_build_graph, client):
        mock_build_graph.side_effect = Exception("Enhancement failed")

        response = client.post("/api/prompt/enhance", json={"prompt": "Test prompt"})

        assert response.status_code == 500
        assert response.json()["detail"] == "Internal Server Error"
|
|
|
|
|
|
class TestMCPEndpoint:
    """Tests for the /api/mcp/server/metadata endpoint."""

    @patch("src.server.app.load_mcp_tools")
    @patch.dict(os.environ, {"ENABLE_MCP_SERVER_CONFIGURATION": "true"})
    def test_mcp_server_metadata_success(self, mock_load_tools, client):
        mock_load_tools.return_value = [
            {"name": "test_tool", "description": "Test tool"}
        ]

        response = client.post(
            "/api/mcp/server/metadata",
            json={
                "transport": "stdio",
                "command": "node",
                "args": ["server.js"],
                "env": {"API_KEY": "test123"},
            },
        )

        assert response.status_code == 200
        body = response.json()
        assert body["transport"] == "stdio"
        assert body["command"] == "node"
        assert len(body["tools"]) == 1

    @patch("src.server.app.load_mcp_tools")
    @patch.dict(os.environ, {"ENABLE_MCP_SERVER_CONFIGURATION": "true"})
    def test_mcp_server_metadata_with_custom_timeout(self, mock_load_tools, client):
        mock_load_tools.return_value = []

        response = client.post(
            "/api/mcp/server/metadata",
            json={"transport": "stdio", "command": "node", "timeout_seconds": 60},
        )

        assert response.status_code == 200
        mock_load_tools.assert_called_once()
        # Verify timeout_seconds is forwarded to load_mcp_tools.
        assert mock_load_tools.call_args[1]["timeout_seconds"] == 60

    @patch("src.server.app.load_mcp_tools")
    @patch.dict(os.environ, {"ENABLE_MCP_SERVER_CONFIGURATION": "true"})
    def test_mcp_server_metadata_with_sse_read_timeout(self, mock_load_tools, client):
        """Test that sse_read_timeout is passed to load_mcp_tools."""
        mock_load_tools.return_value = []

        response = client.post(
            "/api/mcp/server/metadata",
            json={
                "transport": "sse",
                "url": "http://localhost:3000/sse",
                "timeout_seconds": 30,
                "sse_read_timeout": 15,
            },
        )

        assert response.status_code == 200
        mock_load_tools.assert_called_once()
        # Both timeout knobs must reach load_mcp_tools as keyword arguments.
        forwarded = mock_load_tools.call_args[1]
        assert forwarded["timeout_seconds"] == 30
        assert forwarded["sse_read_timeout"] == 15

    @patch("src.server.app.load_mcp_tools")
    @patch.dict(os.environ, {"ENABLE_MCP_SERVER_CONFIGURATION": "true"})
    def test_mcp_server_metadata_with_exception(self, mock_load_tools, client):
        # Even an HTTPException from the loader is masked as a generic 500.
        mock_load_tools.side_effect = HTTPException(
            status_code=400, detail="MCP Server Error"
        )

        response = client.post(
            "/api/mcp/server/metadata",
            json={
                "transport": "stdio",
                "command": "node",
                "args": ["server.js"],
                "env": {"API_KEY": "test123"},
            },
        )

        assert response.status_code == 500
        assert response.json()["detail"] == "Internal Server Error"

    @patch("src.server.app.load_mcp_tools")
    @patch.dict(os.environ, {"ENABLE_MCP_SERVER_CONFIGURATION": ""})
    def test_mcp_server_metadata_without_enable_configuration(
        self, mock_load_tools, client
    ):
        response = client.post(
            "/api/mcp/server/metadata",
            json={
                "transport": "stdio",
                "command": "node",
                "args": ["server.js"],
                "env": {"API_KEY": "test123"},
            },
        )

        assert response.status_code == 403
        assert (
            response.json()["detail"]
            == "MCP server configuration is disabled. Set ENABLE_MCP_SERVER_CONFIGURATION=true to enable MCP features."
        )
|
|
|
|
|
|
class TestRAGEndpoints:
    """Tests for the RAG config, resource-listing, and upload endpoints."""

    @patch("src.server.app.SELECTED_RAG_PROVIDER", "test_provider")
    def test_rag_config(self, client):
        response = client.get("/api/rag/config")

        assert response.status_code == 200
        assert response.json()["provider"] == "test_provider"

    @patch("src.server.app.build_retriever")
    def test_rag_resources_with_retriever(self, mock_build_retriever, client):
        retriever = MagicMock()
        retriever.list_resources.return_value = [
            {
                "uri": "test_uri",
                "title": "Test Resource",
                "description": "Test Description",
            }
        ]
        mock_build_retriever.return_value = retriever

        response = client.get("/api/rag/resources?query=test")

        assert response.status_code == 200
        assert len(response.json()["resources"]) == 1

    @patch("src.server.app.build_retriever")
    def test_rag_resources_without_retriever(self, mock_build_retriever, client):
        # Without a configured retriever the endpoint returns an empty list.
        mock_build_retriever.return_value = None

        response = client.get("/api/rag/resources")

        assert response.status_code == 200
        assert response.json()["resources"] == []

    @patch("src.server.app.build_retriever")
    def test_upload_rag_resource_success(self, mock_build_retriever, client):
        retriever = MagicMock()
        retriever.ingest_file.return_value = {
            "uri": "milvus://test/file.md",
            "title": "Test File",
            "description": "Uploaded file",
        }
        mock_build_retriever.return_value = retriever

        response = client.post(
            "/api/rag/upload",
            files={"file": ("test.md", b"# Test content", "text/markdown")},
        )

        assert response.status_code == 200
        payload = response.json()
        assert payload["title"] == "Test File"
        assert payload["uri"] == "milvus://test/file.md"
        retriever.ingest_file.assert_called_once()

    @patch("src.server.app.build_retriever")
    def test_upload_rag_resource_no_retriever(self, mock_build_retriever, client):
        mock_build_retriever.return_value = None

        response = client.post(
            "/api/rag/upload",
            files={"file": ("test.md", b"# Test content", "text/markdown")},
        )

        assert response.status_code == 500
        assert "RAG provider not configured" in response.json()["detail"]

    @patch("src.server.app.build_retriever")
    def test_upload_rag_resource_not_implemented(self, mock_build_retriever, client):
        # Providers that cannot ingest files surface a 501.
        retriever = MagicMock()
        retriever.ingest_file.side_effect = NotImplementedError
        mock_build_retriever.return_value = retriever

        response = client.post(
            "/api/rag/upload",
            files={"file": ("test.md", b"# Test content", "text/markdown")},
        )

        assert response.status_code == 501
        assert "Upload not supported" in response.json()["detail"]

    @patch("src.server.app.build_retriever")
    def test_upload_rag_resource_value_error(self, mock_build_retriever, client):
        # Invalid content (e.g. bad encoding) maps to a 400.
        retriever = MagicMock()
        retriever.ingest_file.side_effect = ValueError("File is not valid UTF-8")
        mock_build_retriever.return_value = retriever

        response = client.post(
            "/api/rag/upload",
            files={"file": ("test.txt", b"\x80\x81\x82", "text/plain")},
        )

        assert response.status_code == 400
        assert "Invalid RAG resource" in response.json()["detail"]

    @patch("src.server.app.build_retriever")
    def test_upload_rag_resource_runtime_error(self, mock_build_retriever, client):
        # Backend ingestion failures map to a 500.
        retriever = MagicMock()
        retriever.ingest_file.side_effect = RuntimeError("Failed to insert into Milvus")
        mock_build_retriever.return_value = retriever

        response = client.post(
            "/api/rag/upload",
            files={"file": ("test.md", b"# Test content", "text/markdown")},
        )

        assert response.status_code == 500
        assert "Failed to ingest RAG resource" in response.json()["detail"]

    def test_upload_rag_resource_invalid_file_type(self, client):
        response = client.post(
            "/api/rag/upload",
            files={"file": ("test.exe", b"binary content", "application/octet-stream")},
        )

        assert response.status_code == 400
        assert "Invalid file type" in response.json()["detail"]

    def test_upload_rag_resource_empty_file(self, client):
        response = client.post(
            "/api/rag/upload",
            files={"file": ("test.md", b"", "text/markdown")},
        )

        assert response.status_code == 400
        assert "empty file" in response.json()["detail"]

    @patch("src.server.app.MAX_UPLOAD_SIZE_BYTES", 10)
    def test_upload_rag_resource_file_too_large(self, client):
        response = client.post(
            "/api/rag/upload",
            files={"file": ("test.md", b"x" * 100, "text/markdown")},
        )

        assert response.status_code == 413
        assert "File too large" in response.json()["detail"]

    @patch("src.server.app.build_retriever")
    def test_upload_rag_resource_path_traversal_sanitized(
        self, mock_build_retriever, client
    ):
        retriever = MagicMock()
        retriever.ingest_file.return_value = {
            "uri": "milvus://test/file.md",
            "title": "Test File",
            "description": "Uploaded file",
        }
        mock_build_retriever.return_value = retriever

        response = client.post(
            "/api/rag/upload",
            files={"file": ("../../../etc/passwd.md", b"# Test", "text/markdown")},
        )

        assert response.status_code == 200
        # Verify the filename was sanitized (only basename used).
        retriever.ingest_file.assert_called_once()
        assert retriever.ingest_file.call_args[0][1] == "passwd.md"
|
|
|
|
|
|
class TestChatStreamEndpoint:
    """Tests for the /api/chat/stream SSE endpoint."""

    # MCP settings payload reused by the MCP-related tests below.
    _GITHUB_TRENDING_MCP = {
        "servers": {
            "mcp-github-trending": {
                "transport": "stdio",
                "command": "uvx",
                "args": ["mcp-github-trending"],
                "env": {"MCP_SERVER_ID": "mcp-github-trending"},
                "enabled_tools": ["get_github_trending_repositories"],
                "add_to_agents": ["researcher"],
            }
        }
    }

    @staticmethod
    def _request_payload(mcp_settings):
        # Common chat-stream request body; only mcp_settings varies per test.
        return {
            "thread_id": "__default__",
            "messages": [{"role": "user", "content": "Hello"}],
            "resources": [],
            "max_plan_iterations": 3,
            "max_step_num": 10,
            "max_search_results": 5,
            "auto_accepted_plan": True,
            "interrupt_feedback": "",
            "mcp_settings": mcp_settings,
            "enable_background_investigation": False,
            "report_style": "academic",
        }

    @staticmethod
    def _install_single_event_astream(mock_graph):
        # Replace graph.astream with a stub that emits one event and ends.
        async def _astream(*args, **kwargs):
            yield ("agent1", "step1", {"test": "data"})

        mock_graph.astream = _astream

    @patch("src.server.app.graph")
    def test_chat_stream_with_default_thread_id(self, mock_graph, client):
        self._install_single_event_astream(mock_graph)

        response = client.post("/api/chat/stream", json=self._request_payload({}))

        assert response.status_code == 200
        assert response.headers["content-type"] == "text/event-stream; charset=utf-8"

    @patch("src.server.app.graph")
    def test_chat_stream_with_mcp_settings(self, mock_graph, client):
        self._install_single_event_astream(mock_graph)

        response = client.post(
            "/api/chat/stream",
            json=self._request_payload(self._GITHUB_TRENDING_MCP),
        )

        # MCP configuration is rejected unless explicitly enabled by env var.
        assert response.status_code == 403
        assert (
            response.json()["detail"]
            == "MCP server configuration is disabled. Set ENABLE_MCP_SERVER_CONFIGURATION=true to enable MCP features."
        )

    @patch("src.server.app.graph")
    @patch.dict(os.environ, {"ENABLE_MCP_SERVER_CONFIGURATION": "true"})
    def test_chat_stream_with_mcp_settings_enabled(self, mock_graph, client):
        self._install_single_event_astream(mock_graph)

        response = client.post(
            "/api/chat/stream",
            json=self._request_payload(self._GITHUB_TRENDING_MCP),
        )

        assert response.status_code == 200
        assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
|
|
|
|
|
|
class TestAstreamWorkflowGenerator:
    """Tests for _astream_workflow_generator's translation of graph output
    into SSE events."""

    @staticmethod
    def _make_generator(messages, **overrides):
        # All tests share these arguments; individual tests override as needed.
        params = {
            "messages": messages,
            "thread_id": "test_thread",
            "resources": [],
            "max_plan_iterations": 3,
            "max_step_num": 10,
            "max_search_results": 5,
            "auto_accepted_plan": True,
            "interrupt_feedback": "",
            "mcp_settings": {},
            "enable_background_investigation": False,
            "enable_web_search": True,
            "report_style": ReportStyle.ACADEMIC,
            "enable_deep_thinking": False,
            "enable_clarification": False,
            "max_clarification_rounds": 3,
        }
        params.update(overrides)
        return _astream_workflow_generator(**params)

    @staticmethod
    async def _collect(generator):
        # Drain the async generator into a plain list of SSE event strings.
        return [event async for event in generator]

    @pytest.mark.asyncio
    @patch("src.server.app.graph")
    async def test_astream_workflow_generator_basic_flow(self, mock_graph):
        # An AI message chunk streamed by the graph becomes a message_chunk event.
        chunk = AIMessageChunk(content="Hello world")
        chunk.id = "msg_123"
        chunk.response_metadata = {}
        chunk.tool_calls = []
        chunk.tool_call_chunks = []

        async def fake_astream(*args, **kwargs):
            # Yield a (message, metadata) tuple, as the real graph does.
            yield ("agent1:subagent", "messages", (chunk, {}))

        mock_graph.astream = fake_astream

        events = await self._collect(
            self._make_generator([{"role": "user", "content": "Hello"}])
        )

        assert len(events) == 1
        assert "event: message_chunk" in events[0]
        assert "Hello world" in events[0]
        # Check for the actual agent name that appears in the output
        assert '"agent": "a"' in events[0]

    @pytest.mark.asyncio
    @patch("src.server.app.graph")
    async def test_astream_workflow_generator_with_interrupt_feedback(self, mock_graph):
        async def fake_astream(*args, **kwargs):
            # When interrupt_feedback is set, the workflow input must be a
            # Command whose resume string embeds the feedback option.
            assert isinstance(args[0], Command)
            assert "[edit_plan] Hello" in args[0].resume
            yield ("agent1", "step1", {"test": "data"})

        mock_graph.astream = fake_astream

        await self._collect(
            self._make_generator(
                [{"role": "user", "content": "Hello"}],
                auto_accepted_plan=False,
                interrupt_feedback="edit_plan",
            )
        )

    @pytest.mark.asyncio
    @patch("src.server.app.graph")
    async def test_astream_workflow_generator_interrupt_event(self, mock_graph):
        # Mock interrupt data with the new 'id' attribute (LangGraph 1.0+).
        interrupt = MagicMock()
        interrupt.id = "interrupt_id"
        interrupt.value = "Plan requires approval"

        async def fake_astream(*args, **kwargs):
            yield ("agent1", "step1", {"__interrupt__": [interrupt]})

        mock_graph.astream = fake_astream

        events = await self._collect(self._make_generator([]))

        assert len(events) == 1
        assert "event: interrupt" in events[0]
        assert "Plan requires approval" in events[0]
        assert "interrupt_id" in events[0]

    @pytest.mark.asyncio
    @patch("src.server.app.graph")
    async def test_astream_workflow_generator_tool_message(self, mock_graph):
        # Tool messages become tool_call_result events.
        tool_message = ToolMessage(content="Tool result", tool_call_id="tool_123")
        tool_message.id = "msg_456"

        async def fake_astream(*args, **kwargs):
            yield ("agent1:subagent", "step1", (tool_message, {}))

        mock_graph.astream = fake_astream

        events = await self._collect(self._make_generator([]))

        assert len(events) == 1
        assert "event: tool_call_result" in events[0]
        assert "Tool result" in events[0]
        assert "tool_123" in events[0]

    @pytest.mark.asyncio
    @patch("src.server.app.graph")
    async def test_astream_workflow_generator_ai_message_with_tool_calls(
        self, mock_graph
    ):
        # A chunk that finishes with tool_calls becomes a tool_calls event.
        chunk = AIMessageChunk(content="Making tool call")
        chunk.id = "msg_789"
        chunk.response_metadata = {"finish_reason": "tool_calls"}
        chunk.tool_calls = [{"name": "search", "args": {"query": "test"}}]
        chunk.tool_call_chunks = [{"name": "search"}]

        async def fake_astream(*args, **kwargs):
            yield ("agent1:subagent", "step1", (chunk, {}))

        mock_graph.astream = fake_astream

        events = await self._collect(self._make_generator([]))

        assert len(events) == 1
        assert "event: tool_calls" in events[0]
        assert "Making tool call" in events[0]
        assert "tool_calls" in events[0]
|
|
|
|
@pytest.mark.asyncio
|
|
@patch("src.server.app.graph")
|
|
async def test_astream_workflow_generator_ai_message_with_tool_call_chunks(
|
|
self, mock_graph
|
|
):
|
|
# Mock AI message with only tool call chunks
|
|
mock_ai_message = AIMessageChunk(content="Streaming tool call")
|
|
mock_ai_message.id = "msg_101"
|
|
mock_ai_message.response_metadata = {}
|
|
mock_ai_message.tool_calls = []
|
|
mock_ai_message.tool_call_chunks = [{"name": "search", "index": 0}]
|
|
|
|
async def mock_astream(*args, **kwargs):
|
|
yield ("agent1:subagent", "step1", (mock_ai_message, {}))
|
|
|
|
mock_graph.astream = mock_astream
|
|
|
|
generator = _astream_workflow_generator(
|
|
messages=[],
|
|
thread_id="test_thread",
|
|
resources=[],
|
|
max_plan_iterations=3,
|
|
max_step_num=10,
|
|
max_search_results=5,
|
|
auto_accepted_plan=True,
|
|
interrupt_feedback="",
|
|
mcp_settings={},
|
|
enable_background_investigation=False,
|
|
enable_web_search=True,
|
|
report_style=ReportStyle.ACADEMIC,
|
|
enable_deep_thinking=False,
|
|
enable_clarification=False,
|
|
max_clarification_rounds=3,
|
|
)
|
|
|
|
events = []
|
|
async for event in generator:
|
|
events.append(event)
|
|
|
|
assert len(events) == 1
|
|
assert "event: tool_call_chunks" in events[0]
|
|
assert "Streaming tool call" in events[0]
|
|
|
|
@pytest.mark.asyncio
|
|
@patch("src.server.app.graph")
|
|
async def test_astream_workflow_generator_with_finish_reason(self, mock_graph):
|
|
# Mock AI message with finish reason
|
|
mock_ai_message = AIMessageChunk(content="Complete response")
|
|
mock_ai_message.id = "msg_finish"
|
|
mock_ai_message.response_metadata = {"finish_reason": "stop"}
|
|
mock_ai_message.tool_calls = []
|
|
mock_ai_message.tool_call_chunks = []
|
|
|
|
async def mock_astream(*args, **kwargs):
|
|
yield ("agent1:subagent", "step1", (mock_ai_message, {}))
|
|
|
|
mock_graph.astream = mock_astream
|
|
|
|
generator = _astream_workflow_generator(
|
|
messages=[],
|
|
thread_id="test_thread",
|
|
resources=[],
|
|
max_plan_iterations=3,
|
|
max_step_num=10,
|
|
max_search_results=5,
|
|
auto_accepted_plan=True,
|
|
interrupt_feedback="",
|
|
mcp_settings={},
|
|
enable_background_investigation=False,
|
|
enable_web_search=True,
|
|
report_style=ReportStyle.ACADEMIC,
|
|
enable_deep_thinking=False,
|
|
enable_clarification=False,
|
|
max_clarification_rounds=3,
|
|
)
|
|
|
|
events = []
|
|
async for event in generator:
|
|
events.append(event)
|
|
|
|
assert len(events) == 1
|
|
assert "event: message_chunk" in events[0]
|
|
assert "finish_reason" in events[0]
|
|
assert "stop" in events[0]
|
|
|
|
@pytest.mark.asyncio
|
|
@patch("src.server.app.graph")
|
|
async def test_astream_workflow_generator_config_passed_correctly(self, mock_graph):
|
|
mock_ai_message = AIMessageChunk(content="Test")
|
|
mock_ai_message.id = "test_id"
|
|
mock_ai_message.response_metadata = {}
|
|
mock_ai_message.tool_calls = []
|
|
mock_ai_message.tool_call_chunks = []
|
|
|
|
async def verify_config(*args, **kwargs):
|
|
config = kwargs.get("config", {})
|
|
assert config["thread_id"] == "test_thread"
|
|
assert config["max_plan_iterations"] == 5
|
|
assert config["max_step_num"] == 20
|
|
assert config["max_search_results"] == 10
|
|
assert config["report_style"] == ReportStyle.NEWS.value
|
|
yield ("agent1", "messages", [mock_ai_message])
|
|
|
|
|
|
class TestGenerateProseEndpoint:
    """Tests for the /api/prose/generate streaming endpoint."""

    @patch("src.server.app.build_prose_graph")
    def test_generate_prose_success(self, mock_build_graph, client):
        """A successful request streams generated prose as server-sent events."""

        class FakeChunk:
            def __init__(self, content):
                self.content = content

        async def fake_astream(*args, **kwargs):
            for text in ("Generated prose 1", "Generated prose 2"):
                yield (None, [FakeChunk(text)])

        workflow = MagicMock()
        workflow.astream.return_value = fake_astream()
        mock_build_graph.return_value = workflow

        payload = {
            "prompt": "Write a story.",
            "option": "default",
            "command": "generate",
        }
        response = client.post("/api/prose/generate", json=payload)

        assert response.status_code == 200
        assert response.headers["content-type"].startswith("text/event-stream")

        # Drain the streaming body and check the prose chunks made it through.
        body = b"".join(response.iter_bytes())
        assert b"Generated prose 1" in body or b"Generated prose 2" in body

    @patch("src.server.app.build_prose_graph")
    def test_generate_prose_error(self, mock_build_graph, client):
        """A failure while building the prose graph yields a 500 response."""
        mock_build_graph.side_effect = Exception("Prose generation failed")
        payload = {
            "prompt": "Write a story.",
            "option": "default",
            "command": "generate",
        }
        response = client.post("/api/prose/generate", json=payload)
        assert response.status_code == 500
        assert response.json()["detail"] == "Internal Server Error"
|
|
|
|
|
|
class TestCreateInterruptEvent:
    """Tests for _create_interrupt_event function (Issue #730 fix)."""

    @staticmethod
    def _mock_interrupt(interrupt_id, value):
        """Build a MagicMock standing in for a LangGraph Interrupt object."""
        interrupt = MagicMock()
        interrupt.id = interrupt_id
        interrupt.value = value
        return interrupt

    def test_create_interrupt_event_with_id_attribute(self):
        """Test that _create_interrupt_event works with LangGraph 1.0+ Interrupt objects that have 'id' attribute."""
        interrupt = self._mock_interrupt("interrupt-123", "Please review the research plan")

        sse = _create_interrupt_event("thread-456", {"__interrupt__": [interrupt]})

        # The result must be a properly formatted SSE interrupt event.
        assert "event: interrupt\n" in sse
        assert '"thread_id": "thread-456"' in sse
        assert '"id": "interrupt-123"' in sse
        assert '"content": "Please review the research plan"' in sse
        assert '"finish_reason": "interrupt"' in sse
        assert '"role": "assistant"' in sse

    def test_create_interrupt_event_fallback_to_thread_id(self):
        """Test that _create_interrupt_event falls back to thread_id when 'id' attribute is None."""
        interrupt = self._mock_interrupt(None, "Plan review needed")

        sse = _create_interrupt_event("thread-789", {"__interrupt__": [interrupt]})

        assert '"id": "thread-789"' in sse
        assert '"thread_id": "thread-789"' in sse
        assert '"content": "Plan review needed"' in sse

    def test_create_interrupt_event_without_id_attribute(self):
        """Test that _create_interrupt_event handles objects without 'id' attribute (backward compatibility)."""

        class MockInterrupt:
            pass

        interrupt = MockInterrupt()
        interrupt.value = "Waiting for approval"

        sse = _create_interrupt_event("thread-abc", {"__interrupt__": [interrupt]})

        # With no 'id' attribute at all, the thread_id is used instead.
        assert '"id": "thread-abc"' in sse
        assert '"content": "Waiting for approval"' in sse

    def test_create_interrupt_event_options(self):
        """Test that _create_interrupt_event includes correct options."""
        interrupt = self._mock_interrupt("int-001", "Review plan")

        sse = _create_interrupt_event("thread-xyz", {"__interrupt__": [interrupt]})

        assert '"options":' in sse
        assert '"text": "Edit plan"' in sse
        assert '"value": "edit_plan"' in sse
        assert '"text": "Start research"' in sse
        assert '"value": "accepted"' in sse

    def test_create_interrupt_event_with_complex_value(self):
        """Test that _create_interrupt_event handles complex content values."""
        interrupt = self._mock_interrupt(
            "int-complex", {"plan": "Research AI", "steps": ["step1", "step2"]}
        )

        sse = _create_interrupt_event("thread-complex", {"__interrupt__": [interrupt]})

        # A non-string value should be serialized into the event payload.
        assert '"id": "int-complex"' in sse
        assert "Research AI" in sse or "plan" in sse
|
|
|
|
|
|
class TestLifespanFunction:
    """Tests for the lifespan function and global connection pool management (Issue #778).

    These tests verify correct initialization, error handling, and cleanup behavior
    for PostgreSQL and MongoDB global connection pools.
    """

    @pytest.mark.asyncio
    @patch.dict(os.environ, {"LANGGRAPH_CHECKPOINT_SAVER": "false"})
    async def test_lifespan_skips_initialization_when_checkpoint_not_configured(self):
        """Verify no pool initialization when LANGGRAPH_CHECKPOINT_SAVER=False."""
        from src.server.app import lifespan

        mock_app = MagicMock()

        # If the saver is disabled, lifespan should never construct a pool.
        with patch("src.server.app.AsyncConnectionPool") as mock_pg_pool:
            async with lifespan(mock_app):
                pass

        mock_pg_pool.assert_not_called()

    @pytest.mark.asyncio
    @patch.dict(
        os.environ,
        {"LANGGRAPH_CHECKPOINT_SAVER": "true", "LANGGRAPH_CHECKPOINT_DB_URL": ""},
    )
    async def test_lifespan_skips_initialization_when_url_empty(self):
        """Verify no pool initialization when checkpoint URL is empty."""
        from src.server.app import lifespan

        mock_app = MagicMock()

        with patch("src.server.app.AsyncConnectionPool") as mock_pg_pool:
            async with lifespan(mock_app):
                pass

        mock_pg_pool.assert_not_called()

    @pytest.mark.asyncio
    @patch.dict(
        os.environ,
        {
            "LANGGRAPH_CHECKPOINT_SAVER": "true",
            "LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
            "PG_POOL_MIN_SIZE": "2",
            "PG_POOL_MAX_SIZE": "10",
            "PG_POOL_TIMEOUT": "30",
        },
    )
    async def test_lifespan_postgresql_pool_initialization_success(self):
        """Test successful PostgreSQL connection pool initialization."""
        from src.server.app import lifespan

        mock_app = MagicMock()
        mock_pool = MagicMock()
        # open/close are awaited by lifespan, so they must be AsyncMocks.
        mock_pool.open = AsyncMock()
        mock_pool.close = AsyncMock()

        mock_checkpointer = MagicMock()
        mock_checkpointer.setup = AsyncMock()

        with (
            patch("src.server.app.AsyncConnectionPool", return_value=mock_pool),
            patch("src.server.app.AsyncPostgresSaver", return_value=mock_checkpointer),
        ):
            async with lifespan(mock_app):
                pass

        # Startup opens the pool and runs the saver's schema setup;
        # shutdown closes the pool exactly once.
        mock_pool.open.assert_called_once()
        mock_checkpointer.setup.assert_called_once()
        mock_pool.close.assert_called_once()

    @pytest.mark.asyncio
    @patch.dict(
        os.environ,
        {
            "LANGGRAPH_CHECKPOINT_SAVER": "true",
            "LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
        },
    )
    async def test_lifespan_postgresql_pool_initialization_failure(self):
        """Verify RuntimeError raised when PostgreSQL pool initialization fails."""
        from src.server.app import lifespan

        mock_app = MagicMock()
        mock_pool = MagicMock()
        # Simulate the pool failing to open (e.g. database unreachable).
        mock_pool.open = AsyncMock(
            side_effect=Exception("Connection refused")
        )

        with patch("src.server.app.AsyncConnectionPool", return_value=mock_pool):
            with pytest.raises(RuntimeError) as exc_info:
                async with lifespan(mock_app):
                    pass

        assert "PostgreSQL" in str(exc_info.value) or "initialization failed" in str(exc_info.value)

    @pytest.mark.asyncio
    @patch.dict(
        os.environ,
        {
            "LANGGRAPH_CHECKPOINT_SAVER": "true",
            "LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
            "MONGO_MIN_POOL_SIZE": "2",
            "MONGO_MAX_POOL_SIZE": "10",
        },
    )
    async def test_lifespan_mongodb_pool_initialization_success(self):
        """Test successful MongoDB connection pool initialization."""
        from src.server.app import lifespan

        mock_app = MagicMock()
        mock_client = MagicMock()
        mock_client.close = MagicMock()

        mock_checkpointer = MagicMock()
        mock_checkpointer.setup = AsyncMock()

        # Create a mock motor module
        # (motor is imported lazily inside lifespan, so patching sys.modules
        # intercepts the import itself).
        mock_motor_asyncio = MagicMock()
        mock_motor_asyncio.AsyncIOMotorClient = MagicMock(return_value=mock_client)

        with (
            patch.dict("sys.modules", {"motor": MagicMock(), "motor.motor_asyncio": mock_motor_asyncio}),
            patch("src.server.app.AsyncMongoDBSaver", return_value=mock_checkpointer),
        ):
            async with lifespan(mock_app):
                pass

        mock_checkpointer.setup.assert_called_once()
        mock_client.close.assert_called_once()

    @pytest.mark.asyncio
    @patch.dict(
        os.environ,
        {
            "LANGGRAPH_CHECKPOINT_SAVER": "true",
            "LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
        },
    )
    async def test_lifespan_mongodb_import_error(self):
        """Verify RuntimeError when motor package is missing."""
        from src.server.app import lifespan

        mock_app = MagicMock()

        # Mapping a module name to None in sys.modules makes `import motor`
        # raise ImportError, simulating the package being absent.
        with patch.dict("sys.modules", {"motor": None, "motor.motor_asyncio": None}):
            with pytest.raises(RuntimeError) as exc_info:
                async with lifespan(mock_app):
                    pass

        assert "motor" in str(exc_info.value).lower() or "MongoDB" in str(exc_info.value)

    @pytest.mark.asyncio
    @patch.dict(
        os.environ,
        {
            "LANGGRAPH_CHECKPOINT_SAVER": "true",
            "LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
        },
    )
    async def test_lifespan_mongodb_connection_failure(self):
        """Verify RuntimeError on MongoDB connection failure."""
        from src.server.app import lifespan

        mock_app = MagicMock()

        # Create a mock motor module that raises an exception
        mock_motor_asyncio = MagicMock()
        mock_motor_asyncio.AsyncIOMotorClient = MagicMock(
            side_effect=Exception("Connection refused")
        )

        with patch.dict("sys.modules", {"motor": MagicMock(), "motor.motor_asyncio": mock_motor_asyncio}):
            with pytest.raises(RuntimeError) as exc_info:
                async with lifespan(mock_app):
                    pass

        assert "MongoDB" in str(exc_info.value) or "initialized" in str(exc_info.value)

    @pytest.mark.asyncio
    @patch.dict(
        os.environ,
        {
            "LANGGRAPH_CHECKPOINT_SAVER": "true",
            "LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
        },
    )
    async def test_lifespan_postgresql_cleanup_on_shutdown(self):
        """Verify PostgreSQL pool.close() is called during shutdown."""
        from src.server.app import lifespan

        mock_app = MagicMock()
        mock_pool = MagicMock()
        mock_pool.open = AsyncMock()
        mock_pool.close = AsyncMock()

        mock_checkpointer = MagicMock()
        mock_checkpointer.setup = AsyncMock()

        with (
            patch("src.server.app.AsyncConnectionPool", return_value=mock_pool),
            patch("src.server.app.AsyncPostgresSaver", return_value=mock_checkpointer),
        ):
            async with lifespan(mock_app):
                # Verify pool is open during app lifetime
                mock_pool.open.assert_called_once()

        # Verify pool is closed after context exit
        mock_pool.close.assert_called_once()

    @pytest.mark.asyncio
    @patch.dict(
        os.environ,
        {
            "LANGGRAPH_CHECKPOINT_SAVER": "true",
            "LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
        },
    )
    async def test_lifespan_mongodb_cleanup_on_shutdown(self):
        """Verify MongoDB client.close() is called during shutdown."""
        from src.server.app import lifespan

        mock_app = MagicMock()
        mock_client = MagicMock()
        mock_client.close = MagicMock()

        mock_checkpointer = MagicMock()
        mock_checkpointer.setup = AsyncMock()

        # Create a mock motor module
        mock_motor_asyncio = MagicMock()
        mock_motor_asyncio.AsyncIOMotorClient = MagicMock(return_value=mock_client)

        with (
            patch.dict("sys.modules", {"motor": MagicMock(), "motor.motor_asyncio": mock_motor_asyncio}),
            patch("src.server.app.AsyncMongoDBSaver", return_value=mock_checkpointer),
        ):
            async with lifespan(mock_app):
                pass

        # Verify client is closed after context exit
        mock_client.close.assert_called_once()
|
|
|
|
|
|
class TestGlobalConnectionPoolUsage:
    """Tests for _astream_workflow_generator using global connection pools (Issue #778).

    These tests verify that the workflow generator correctly uses global pools
    when available and falls back to per-request connections when not.
    """

    @pytest.mark.asyncio
    @patch.dict(
        os.environ,
        {
            "LANGGRAPH_CHECKPOINT_SAVER": "true",
            "LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
        },
    )
    @patch("src.server.app.graph")
    async def test_astream_uses_global_postgresql_pool_when_available(self, mock_graph):
        """Verify global _pg_checkpointer is used when available."""
        mock_checkpointer = MagicMock()

        async def mock_astream(*args, **kwargs):
            yield ("agent1", "step1", {"test": "data"})

        mock_graph.astream = mock_astream

        # Install a non-None global checkpointer/pool so the generator takes
        # the shared-pool path instead of creating a per-request connection.
        with (
            patch("src.server.app._pg_checkpointer", mock_checkpointer),
            patch("src.server.app._pg_pool", MagicMock()),
            patch("src.server.app._process_initial_messages"),
            patch("src.server.app._stream_graph_events") as mock_stream,
        ):
            mock_stream.return_value = self._empty_async_gen()

            generator = _astream_workflow_generator(
                messages=[{"role": "user", "content": "Hello"}],
                thread_id="test_thread",
                resources=[],
                max_plan_iterations=3,
                max_step_num=10,
                max_search_results=5,
                auto_accepted_plan=True,
                interrupt_feedback="",
                mcp_settings={},
                enable_background_investigation=False,
                enable_web_search=True,
                report_style=ReportStyle.ACADEMIC,
                enable_deep_thinking=False,
                enable_clarification=False,
                max_clarification_rounds=3,
            )

            async for _ in generator:
                pass

            # Verify global checkpointer was assigned to graph
            assert mock_graph.checkpointer == mock_checkpointer

    @pytest.mark.asyncio
    @patch.dict(
        os.environ,
        {
            "LANGGRAPH_CHECKPOINT_SAVER": "true",
            "LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
        },
    )
    @patch("src.server.app.graph")
    async def test_astream_falls_back_to_per_request_postgresql(self, mock_graph):
        """Verify fallback to per-request connection when _pg_checkpointer is None."""
        mock_pool_instance = MagicMock()
        mock_checkpointer = MagicMock()
        mock_checkpointer.setup = AsyncMock()

        async def mock_astream(*args, **kwargs):
            yield ("agent1", "step1", {"test": "data"})

        mock_graph.astream = mock_astream

        # Globals forced to None: the generator must build its own pool.
        with (
            patch("src.server.app._pg_checkpointer", None),
            patch("src.server.app._pg_pool", None),
            patch("src.server.app._process_initial_messages"),
            patch("src.server.app.AsyncConnectionPool") as mock_pool_class,
            patch("src.server.app.AsyncPostgresSaver", return_value=mock_checkpointer),
            patch("src.server.app._stream_graph_events") as mock_stream,
        ):
            # The per-request pool is used as an async context manager.
            mock_pool_class.return_value.__aenter__ = AsyncMock(return_value=mock_pool_instance)
            mock_pool_class.return_value.__aexit__ = AsyncMock()
            mock_stream.return_value = self._empty_async_gen()

            generator = _astream_workflow_generator(
                messages=[{"role": "user", "content": "Hello"}],
                thread_id="test_thread",
                resources=[],
                max_plan_iterations=3,
                max_step_num=10,
                max_search_results=5,
                auto_accepted_plan=True,
                interrupt_feedback="",
                mcp_settings={},
                enable_background_investigation=False,
                enable_web_search=True,
                report_style=ReportStyle.ACADEMIC,
                enable_deep_thinking=False,
                enable_clarification=False,
                max_clarification_rounds=3,
            )

            async for _ in generator:
                pass

            # Verify per-request connection pool was created
            mock_pool_class.assert_called_once()

    @pytest.mark.asyncio
    @patch.dict(
        os.environ,
        {
            "LANGGRAPH_CHECKPOINT_SAVER": "true",
            "LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
        },
    )
    @patch("src.server.app.graph")
    async def test_astream_uses_global_mongodb_pool_when_available(self, mock_graph):
        """Verify global _mongo_checkpointer is used when available."""
        mock_checkpointer = MagicMock()

        async def mock_astream(*args, **kwargs):
            yield ("agent1", "step1", {"test": "data"})

        mock_graph.astream = mock_astream

        with (
            patch("src.server.app._mongo_checkpointer", mock_checkpointer),
            patch("src.server.app._mongo_client", MagicMock()),
            patch("src.server.app._process_initial_messages"),
            patch("src.server.app._stream_graph_events") as mock_stream,
        ):
            mock_stream.return_value = self._empty_async_gen()

            generator = _astream_workflow_generator(
                messages=[{"role": "user", "content": "Hello"}],
                thread_id="test_thread",
                resources=[],
                max_plan_iterations=3,
                max_step_num=10,
                max_search_results=5,
                auto_accepted_plan=True,
                interrupt_feedback="",
                mcp_settings={},
                enable_background_investigation=False,
                enable_web_search=True,
                report_style=ReportStyle.ACADEMIC,
                enable_deep_thinking=False,
                enable_clarification=False,
                max_clarification_rounds=3,
            )

            async for _ in generator:
                pass

            # Verify global checkpointer was assigned to graph
            assert mock_graph.checkpointer == mock_checkpointer

    @pytest.mark.asyncio
    @patch.dict(
        os.environ,
        {
            "LANGGRAPH_CHECKPOINT_SAVER": "true",
            "LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
        },
    )
    @patch("src.server.app.graph")
    async def test_astream_falls_back_to_per_request_mongodb(self, mock_graph):
        """Verify fallback to per-request connection when _mongo_checkpointer is None."""
        mock_checkpointer = MagicMock()

        async def mock_astream(*args, **kwargs):
            yield ("agent1", "step1", {"test": "data"})

        mock_graph.astream = mock_astream

        with (
            patch("src.server.app._mongo_checkpointer", None),
            patch("src.server.app._mongo_client", None),
            patch("src.server.app._process_initial_messages"),
            patch("src.server.app.AsyncMongoDBSaver") as mock_saver_class,
            patch("src.server.app._stream_graph_events") as mock_stream,
        ):
            # from_conn_string(...) returns an async context manager yielding
            # the checkpointer.
            mock_saver_class.from_conn_string.return_value.__aenter__ = AsyncMock(
                return_value=mock_checkpointer
            )
            mock_saver_class.from_conn_string.return_value.__aexit__ = AsyncMock()
            mock_stream.return_value = self._empty_async_gen()

            generator = _astream_workflow_generator(
                messages=[{"role": "user", "content": "Hello"}],
                thread_id="test_thread",
                resources=[],
                max_plan_iterations=3,
                max_step_num=10,
                max_search_results=5,
                auto_accepted_plan=True,
                interrupt_feedback="",
                mcp_settings={},
                enable_background_investigation=False,
                enable_web_search=True,
                report_style=ReportStyle.ACADEMIC,
                enable_deep_thinking=False,
                enable_clarification=False,
                max_clarification_rounds=3,
            )

            async for _ in generator:
                pass

            # Verify per-request MongoDB saver was created
            mock_saver_class.from_conn_string.assert_called_once()

    async def _empty_async_gen(self):
        """Helper to create an empty async generator."""
        # The unreachable `yield` makes this a generator function that
        # produces no items.
        if False:
            yield
|
|
|
|
|
|
class TestCreateEventStreamMessageThinkTagStripping:
    """Tests for stripping <think> tags from streamed content (#781).

    Some models (e.g. DeepSeek-R1, QwQ via ollama) embed reasoning in
    content using <think>...</think> tags instead of the separate
    reasoning_content field.
    """

    def _make_mock_chunk(self, content):
        # Minimal AIMessageChunk with the attributes the event builder reads.
        chunk = AIMessageChunk(content=content)
        chunk.id = "msg_test"
        chunk.response_metadata = {}
        return chunk

    def test_strips_think_tag_at_beginning(self):
        # A leading <think> block must be removed while the report body survives.
        chunk = self._make_mock_chunk(
            "<think>\nLet me analyze...\n</think>\n\n# Report\n\nContent here."
        )
        result = _create_event_stream_message(chunk, {}, "thread-1", "reporter")
        assert "<think>" not in result["content"]
        assert "# Report" in result["content"]
        assert "Content here." in result["content"]

    def test_strips_multiple_think_blocks(self):
        # Every <think>...</think> span is stripped, not just the first one.
        chunk = self._make_mock_chunk(
            "<think>First thought</think>\nParagraph 1.\n<think>Second thought</think>\nParagraph 2."
        )
        result = _create_event_stream_message(chunk, {}, "thread-1", "coordinator")
        assert "<think>" not in result["content"]
        assert "Paragraph 1." in result["content"]
        assert "Paragraph 2." in result["content"]

    def test_preserves_content_without_think_tags(self):
        # The fast path: content with no tags passes through unchanged.
        chunk = self._make_mock_chunk("Normal content without think tags.")
        result = _create_event_stream_message(chunk, {}, "thread-1", "planner")
        assert result["content"] == "Normal content without think tags."

    def test_empty_content_after_stripping(self):
        # Content that is nothing but a think block yields no leaked tags.
        chunk = self._make_mock_chunk("<think>Only thinking, no real content</think>")
        result = _create_event_stream_message(chunk, {}, "thread-1", "reporter")
        assert "<think>" not in result["content"]

    def test_preserves_reasoning_content_field(self):
        # The separate reasoning_content field must survive untouched;
        # only inline <think> tags in `content` are stripped.
        chunk = self._make_mock_chunk("Actual content")
        chunk.additional_kwargs["reasoning_content"] = "This is reasoning"
        result = _create_event_stream_message(chunk, {}, "thread-1", "planner")
        assert result["content"] == "Actual content"
        assert result["reasoning_content"] == "This is reasoning"
|