mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
* test: add unit tests for global connection pool (Issue #778) - Add TestLifespanFunction class with 9 tests for lifespan management: - PostgreSQL/MongoDB pool initialization success/failure - Cleanup on shutdown - Skip initialization when not configured - Add TestGlobalConnectionPoolUsage class with 4 tests: - Using global pools when available - Fallback to per-request connections - Fix missing dict_row import in app.py (bug from PR #757) * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
|
||||
import base64
|
||||
import os
|
||||
from unittest.mock import MagicMock, mock_open, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
@@ -1129,3 +1130,472 @@ class TestCreateInterruptEvent:
|
||||
# Verify complex value is included (will be serialized as JSON)
|
||||
assert '"id": "int-complex"' in result
|
||||
assert "Research AI" in result or "plan" in result
|
||||
|
||||
|
||||
class TestLifespanFunction:
|
||||
"""Tests for the lifespan function and global connection pool management (Issue #778).
|
||||
|
||||
These tests verify correct initialization, error handling, and cleanup behavior
|
||||
for PostgreSQL and MongoDB global connection pools.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.dict(os.environ, {"LANGGRAPH_CHECKPOINT_SAVER": "false"})
|
||||
async def test_lifespan_skips_initialization_when_checkpoint_not_configured(self):
|
||||
"""Verify no pool initialization when LANGGRAPH_CHECKPOINT_SAVER=False."""
|
||||
from src.server.app import lifespan
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with patch("src.server.app.AsyncConnectionPool") as mock_pg_pool:
|
||||
async with lifespan(mock_app):
|
||||
pass
|
||||
|
||||
mock_pg_pool.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{"LANGGRAPH_CHECKPOINT_SAVER": "true", "LANGGRAPH_CHECKPOINT_DB_URL": ""},
|
||||
)
|
||||
async def test_lifespan_skips_initialization_when_url_empty(self):
|
||||
"""Verify no pool initialization when checkpoint URL is empty."""
|
||||
from src.server.app import lifespan
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with patch("src.server.app.AsyncConnectionPool") as mock_pg_pool:
|
||||
async with lifespan(mock_app):
|
||||
pass
|
||||
|
||||
mock_pg_pool.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"LANGGRAPH_CHECKPOINT_SAVER": "true",
|
||||
"LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
|
||||
"PG_POOL_MIN_SIZE": "2",
|
||||
"PG_POOL_MAX_SIZE": "10",
|
||||
"PG_POOL_TIMEOUT": "30",
|
||||
},
|
||||
)
|
||||
async def test_lifespan_postgresql_pool_initialization_success(self):
|
||||
"""Test successful PostgreSQL connection pool initialization."""
|
||||
from src.server.app import lifespan
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.open = AsyncMock()
|
||||
mock_pool.close = AsyncMock()
|
||||
|
||||
mock_checkpointer = MagicMock()
|
||||
mock_checkpointer.setup = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("src.server.app.AsyncConnectionPool", return_value=mock_pool),
|
||||
patch("src.server.app.AsyncPostgresSaver", return_value=mock_checkpointer),
|
||||
):
|
||||
async with lifespan(mock_app):
|
||||
pass
|
||||
|
||||
mock_pool.open.assert_called_once()
|
||||
mock_checkpointer.setup.assert_called_once()
|
||||
mock_pool.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"LANGGRAPH_CHECKPOINT_SAVER": "true",
|
||||
"LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
|
||||
},
|
||||
)
|
||||
async def test_lifespan_postgresql_pool_initialization_failure(self):
|
||||
"""Verify RuntimeError raised when PostgreSQL pool initialization fails."""
|
||||
from src.server.app import lifespan
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.open = AsyncMock(
|
||||
side_effect=Exception("Connection refused")
|
||||
)
|
||||
|
||||
with patch("src.server.app.AsyncConnectionPool", return_value=mock_pool):
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
async with lifespan(mock_app):
|
||||
pass
|
||||
|
||||
assert "PostgreSQL" in str(exc_info.value) or "initialization failed" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"LANGGRAPH_CHECKPOINT_SAVER": "true",
|
||||
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
|
||||
"MONGO_MIN_POOL_SIZE": "2",
|
||||
"MONGO_MAX_POOL_SIZE": "10",
|
||||
},
|
||||
)
|
||||
async def test_lifespan_mongodb_pool_initialization_success(self):
|
||||
"""Test successful MongoDB connection pool initialization."""
|
||||
from src.server.app import lifespan
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.close = MagicMock()
|
||||
|
||||
mock_checkpointer = MagicMock()
|
||||
mock_checkpointer.setup = AsyncMock()
|
||||
|
||||
# Create a mock motor module
|
||||
mock_motor_asyncio = MagicMock()
|
||||
mock_motor_asyncio.AsyncIOMotorClient = MagicMock(return_value=mock_client)
|
||||
|
||||
with (
|
||||
patch.dict("sys.modules", {"motor": MagicMock(), "motor.motor_asyncio": mock_motor_asyncio}),
|
||||
patch("src.server.app.AsyncMongoDBSaver", return_value=mock_checkpointer),
|
||||
):
|
||||
async with lifespan(mock_app):
|
||||
pass
|
||||
|
||||
mock_checkpointer.setup.assert_called_once()
|
||||
mock_client.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"LANGGRAPH_CHECKPOINT_SAVER": "true",
|
||||
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
|
||||
},
|
||||
)
|
||||
async def test_lifespan_mongodb_import_error(self):
|
||||
"""Verify RuntimeError when motor package is missing."""
|
||||
from src.server.app import lifespan
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with patch.dict("sys.modules", {"motor": None, "motor.motor_asyncio": None}):
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
async with lifespan(mock_app):
|
||||
pass
|
||||
|
||||
assert "motor" in str(exc_info.value).lower() or "MongoDB" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"LANGGRAPH_CHECKPOINT_SAVER": "true",
|
||||
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
|
||||
},
|
||||
)
|
||||
async def test_lifespan_mongodb_connection_failure(self):
|
||||
"""Verify RuntimeError on MongoDB connection failure."""
|
||||
from src.server.app import lifespan
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
# Create a mock motor module that raises an exception
|
||||
mock_motor_asyncio = MagicMock()
|
||||
mock_motor_asyncio.AsyncIOMotorClient = MagicMock(
|
||||
side_effect=Exception("Connection refused")
|
||||
)
|
||||
|
||||
with patch.dict("sys.modules", {"motor": MagicMock(), "motor.motor_asyncio": mock_motor_asyncio}):
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
async with lifespan(mock_app):
|
||||
pass
|
||||
|
||||
assert "MongoDB" in str(exc_info.value) or "initialized" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"LANGGRAPH_CHECKPOINT_SAVER": "true",
|
||||
"LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
|
||||
},
|
||||
)
|
||||
async def test_lifespan_postgresql_cleanup_on_shutdown(self):
|
||||
"""Verify PostgreSQL pool.close() is called during shutdown."""
|
||||
from src.server.app import lifespan
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.open = AsyncMock()
|
||||
mock_pool.close = AsyncMock()
|
||||
|
||||
mock_checkpointer = MagicMock()
|
||||
mock_checkpointer.setup = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("src.server.app.AsyncConnectionPool", return_value=mock_pool),
|
||||
patch("src.server.app.AsyncPostgresSaver", return_value=mock_checkpointer),
|
||||
):
|
||||
async with lifespan(mock_app):
|
||||
# Verify pool is open during app lifetime
|
||||
mock_pool.open.assert_called_once()
|
||||
|
||||
# Verify pool is closed after context exit
|
||||
mock_pool.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"LANGGRAPH_CHECKPOINT_SAVER": "true",
|
||||
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
|
||||
},
|
||||
)
|
||||
async def test_lifespan_mongodb_cleanup_on_shutdown(self):
|
||||
"""Verify MongoDB client.close() is called during shutdown."""
|
||||
from src.server.app import lifespan
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.close = MagicMock()
|
||||
|
||||
mock_checkpointer = MagicMock()
|
||||
mock_checkpointer.setup = AsyncMock()
|
||||
|
||||
# Create a mock motor module
|
||||
mock_motor_asyncio = MagicMock()
|
||||
mock_motor_asyncio.AsyncIOMotorClient = MagicMock(return_value=mock_client)
|
||||
|
||||
with (
|
||||
patch.dict("sys.modules", {"motor": MagicMock(), "motor.motor_asyncio": mock_motor_asyncio}),
|
||||
patch("src.server.app.AsyncMongoDBSaver", return_value=mock_checkpointer),
|
||||
):
|
||||
async with lifespan(mock_app):
|
||||
pass
|
||||
|
||||
# Verify client is closed after context exit
|
||||
mock_client.close.assert_called_once()
|
||||
|
||||
|
||||
class TestGlobalConnectionPoolUsage:
|
||||
"""Tests for _astream_workflow_generator using global connection pools (Issue #778).
|
||||
|
||||
These tests verify that the workflow generator correctly uses global pools
|
||||
when available and falls back to per-request connections when not.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"LANGGRAPH_CHECKPOINT_SAVER": "true",
|
||||
"LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
|
||||
},
|
||||
)
|
||||
@patch("src.server.app.graph")
|
||||
async def test_astream_uses_global_postgresql_pool_when_available(self, mock_graph):
|
||||
"""Verify global _pg_checkpointer is used when available."""
|
||||
mock_checkpointer = MagicMock()
|
||||
|
||||
async def mock_astream(*args, **kwargs):
|
||||
yield ("agent1", "step1", {"test": "data"})
|
||||
|
||||
mock_graph.astream = mock_astream
|
||||
|
||||
with (
|
||||
patch("src.server.app._pg_checkpointer", mock_checkpointer),
|
||||
patch("src.server.app._pg_pool", MagicMock()),
|
||||
patch("src.server.app._process_initial_messages"),
|
||||
patch("src.server.app._stream_graph_events") as mock_stream,
|
||||
):
|
||||
mock_stream.return_value = self._empty_async_gen()
|
||||
|
||||
generator = _astream_workflow_generator(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
thread_id="test_thread",
|
||||
resources=[],
|
||||
max_plan_iterations=3,
|
||||
max_step_num=10,
|
||||
max_search_results=5,
|
||||
auto_accepted_plan=True,
|
||||
interrupt_feedback="",
|
||||
mcp_settings={},
|
||||
enable_background_investigation=False,
|
||||
enable_web_search=True,
|
||||
report_style=ReportStyle.ACADEMIC,
|
||||
enable_deep_thinking=False,
|
||||
enable_clarification=False,
|
||||
max_clarification_rounds=3,
|
||||
)
|
||||
|
||||
async for _ in generator:
|
||||
pass
|
||||
|
||||
# Verify global checkpointer was assigned to graph
|
||||
assert mock_graph.checkpointer == mock_checkpointer
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"LANGGRAPH_CHECKPOINT_SAVER": "true",
|
||||
"LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
|
||||
},
|
||||
)
|
||||
@patch("src.server.app.graph")
|
||||
async def test_astream_falls_back_to_per_request_postgresql(self, mock_graph):
|
||||
"""Verify fallback to per-request connection when _pg_checkpointer is None."""
|
||||
mock_pool_instance = MagicMock()
|
||||
mock_checkpointer = MagicMock()
|
||||
mock_checkpointer.setup = AsyncMock()
|
||||
|
||||
async def mock_astream(*args, **kwargs):
|
||||
yield ("agent1", "step1", {"test": "data"})
|
||||
|
||||
mock_graph.astream = mock_astream
|
||||
|
||||
with (
|
||||
patch("src.server.app._pg_checkpointer", None),
|
||||
patch("src.server.app._pg_pool", None),
|
||||
patch("src.server.app._process_initial_messages"),
|
||||
patch("src.server.app.AsyncConnectionPool") as mock_pool_class,
|
||||
patch("src.server.app.AsyncPostgresSaver", return_value=mock_checkpointer),
|
||||
patch("src.server.app._stream_graph_events") as mock_stream,
|
||||
):
|
||||
mock_pool_class.return_value.__aenter__ = AsyncMock(return_value=mock_pool_instance)
|
||||
mock_pool_class.return_value.__aexit__ = AsyncMock()
|
||||
mock_stream.return_value = self._empty_async_gen()
|
||||
|
||||
generator = _astream_workflow_generator(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
thread_id="test_thread",
|
||||
resources=[],
|
||||
max_plan_iterations=3,
|
||||
max_step_num=10,
|
||||
max_search_results=5,
|
||||
auto_accepted_plan=True,
|
||||
interrupt_feedback="",
|
||||
mcp_settings={},
|
||||
enable_background_investigation=False,
|
||||
enable_web_search=True,
|
||||
report_style=ReportStyle.ACADEMIC,
|
||||
enable_deep_thinking=False,
|
||||
enable_clarification=False,
|
||||
max_clarification_rounds=3,
|
||||
)
|
||||
|
||||
async for _ in generator:
|
||||
pass
|
||||
|
||||
# Verify per-request connection pool was created
|
||||
mock_pool_class.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"LANGGRAPH_CHECKPOINT_SAVER": "true",
|
||||
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
|
||||
},
|
||||
)
|
||||
@patch("src.server.app.graph")
|
||||
async def test_astream_uses_global_mongodb_pool_when_available(self, mock_graph):
|
||||
"""Verify global _mongo_checkpointer is used when available."""
|
||||
mock_checkpointer = MagicMock()
|
||||
|
||||
async def mock_astream(*args, **kwargs):
|
||||
yield ("agent1", "step1", {"test": "data"})
|
||||
|
||||
mock_graph.astream = mock_astream
|
||||
|
||||
with (
|
||||
patch("src.server.app._mongo_checkpointer", mock_checkpointer),
|
||||
patch("src.server.app._mongo_client", MagicMock()),
|
||||
patch("src.server.app._process_initial_messages"),
|
||||
patch("src.server.app._stream_graph_events") as mock_stream,
|
||||
):
|
||||
mock_stream.return_value = self._empty_async_gen()
|
||||
|
||||
generator = _astream_workflow_generator(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
thread_id="test_thread",
|
||||
resources=[],
|
||||
max_plan_iterations=3,
|
||||
max_step_num=10,
|
||||
max_search_results=5,
|
||||
auto_accepted_plan=True,
|
||||
interrupt_feedback="",
|
||||
mcp_settings={},
|
||||
enable_background_investigation=False,
|
||||
enable_web_search=True,
|
||||
report_style=ReportStyle.ACADEMIC,
|
||||
enable_deep_thinking=False,
|
||||
enable_clarification=False,
|
||||
max_clarification_rounds=3,
|
||||
)
|
||||
|
||||
async for _ in generator:
|
||||
pass
|
||||
|
||||
# Verify global checkpointer was assigned to graph
|
||||
assert mock_graph.checkpointer == mock_checkpointer
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"LANGGRAPH_CHECKPOINT_SAVER": "true",
|
||||
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
|
||||
},
|
||||
)
|
||||
@patch("src.server.app.graph")
|
||||
async def test_astream_falls_back_to_per_request_mongodb(self, mock_graph):
|
||||
"""Verify fallback to per-request connection when _mongo_checkpointer is None."""
|
||||
mock_checkpointer = MagicMock()
|
||||
|
||||
async def mock_astream(*args, **kwargs):
|
||||
yield ("agent1", "step1", {"test": "data"})
|
||||
|
||||
mock_graph.astream = mock_astream
|
||||
|
||||
with (
|
||||
patch("src.server.app._mongo_checkpointer", None),
|
||||
patch("src.server.app._mongo_client", None),
|
||||
patch("src.server.app._process_initial_messages"),
|
||||
patch("src.server.app.AsyncMongoDBSaver") as mock_saver_class,
|
||||
patch("src.server.app._stream_graph_events") as mock_stream,
|
||||
):
|
||||
mock_saver_class.from_conn_string.return_value.__aenter__ = AsyncMock(
|
||||
return_value=mock_checkpointer
|
||||
)
|
||||
mock_saver_class.from_conn_string.return_value.__aexit__ = AsyncMock()
|
||||
mock_stream.return_value = self._empty_async_gen()
|
||||
|
||||
generator = _astream_workflow_generator(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
thread_id="test_thread",
|
||||
resources=[],
|
||||
max_plan_iterations=3,
|
||||
max_step_num=10,
|
||||
max_search_results=5,
|
||||
auto_accepted_plan=True,
|
||||
interrupt_feedback="",
|
||||
mcp_settings={},
|
||||
enable_background_investigation=False,
|
||||
enable_web_search=True,
|
||||
report_style=ReportStyle.ACADEMIC,
|
||||
enable_deep_thinking=False,
|
||||
enable_clarification=False,
|
||||
max_clarification_rounds=3,
|
||||
)
|
||||
|
||||
async for _ in generator:
|
||||
pass
|
||||
|
||||
# Verify per-request MongoDB saver was created
|
||||
mock_saver_class.from_conn_string.assert_called_once()
|
||||
|
||||
async def _empty_async_gen(self):
|
||||
"""Helper to create an empty async generator."""
|
||||
if False:
|
||||
yield
|
||||
|
||||
Reference in New Issue
Block a user