test: add unit tests for global connection pool (Issue #778) (#780)

* test: add unit tests for global connection pool (Issue #778)

- Add TestLifespanFunction class with 9 tests for lifespan management:
  - PostgreSQL/MongoDB pool initialization success/failure
  - Cleanup on shutdown
  - Skip initialization when not configured

- Add TestGlobalConnectionPoolUsage class with 4 tests:
  - Using global pools when available
  - Fallback to per-request connections

- Fix missing dict_row import in app.py (bug from PR #757)

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Willem Jiang
2025-12-23 23:06:39 +08:00
committed by GitHub
parent 83e9d7c9e5
commit fb319aaa44
2 changed files with 472 additions and 1 deletions

View File

@@ -1,9 +1,10 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import base64
import os
from unittest.mock import MagicMock, mock_open, patch
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
import pytest
from fastapi import HTTPException
@@ -1129,3 +1130,472 @@ class TestCreateInterruptEvent:
# Verify complex value is included (will be serialized as JSON)
assert '"id": "int-complex"' in result
assert "Research AI" in result or "plan" in result
class TestLifespanFunction:
"""Tests for the lifespan function and global connection pool management (Issue #778).
These tests verify correct initialization, error handling, and cleanup behavior
for PostgreSQL and MongoDB global connection pools.
"""
@pytest.mark.asyncio
@patch.dict(os.environ, {"LANGGRAPH_CHECKPOINT_SAVER": "false"})
async def test_lifespan_skips_initialization_when_checkpoint_not_configured(self):
"""Verify no pool initialization when LANGGRAPH_CHECKPOINT_SAVER=False."""
from src.server.app import lifespan
mock_app = MagicMock()
with patch("src.server.app.AsyncConnectionPool") as mock_pg_pool:
async with lifespan(mock_app):
pass
mock_pg_pool.assert_not_called()
@pytest.mark.asyncio
@patch.dict(
os.environ,
{"LANGGRAPH_CHECKPOINT_SAVER": "true", "LANGGRAPH_CHECKPOINT_DB_URL": ""},
)
async def test_lifespan_skips_initialization_when_url_empty(self):
"""Verify no pool initialization when checkpoint URL is empty."""
from src.server.app import lifespan
mock_app = MagicMock()
with patch("src.server.app.AsyncConnectionPool") as mock_pg_pool:
async with lifespan(mock_app):
pass
mock_pg_pool.assert_not_called()
@pytest.mark.asyncio
@patch.dict(
os.environ,
{
"LANGGRAPH_CHECKPOINT_SAVER": "true",
"LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
"PG_POOL_MIN_SIZE": "2",
"PG_POOL_MAX_SIZE": "10",
"PG_POOL_TIMEOUT": "30",
},
)
async def test_lifespan_postgresql_pool_initialization_success(self):
"""Test successful PostgreSQL connection pool initialization."""
from src.server.app import lifespan
mock_app = MagicMock()
mock_pool = MagicMock()
mock_pool.open = AsyncMock()
mock_pool.close = AsyncMock()
mock_checkpointer = MagicMock()
mock_checkpointer.setup = AsyncMock()
with (
patch("src.server.app.AsyncConnectionPool", return_value=mock_pool),
patch("src.server.app.AsyncPostgresSaver", return_value=mock_checkpointer),
):
async with lifespan(mock_app):
pass
mock_pool.open.assert_called_once()
mock_checkpointer.setup.assert_called_once()
mock_pool.close.assert_called_once()
@pytest.mark.asyncio
@patch.dict(
os.environ,
{
"LANGGRAPH_CHECKPOINT_SAVER": "true",
"LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
},
)
async def test_lifespan_postgresql_pool_initialization_failure(self):
"""Verify RuntimeError raised when PostgreSQL pool initialization fails."""
from src.server.app import lifespan
mock_app = MagicMock()
mock_pool = MagicMock()
mock_pool.open = AsyncMock(
side_effect=Exception("Connection refused")
)
with patch("src.server.app.AsyncConnectionPool", return_value=mock_pool):
with pytest.raises(RuntimeError) as exc_info:
async with lifespan(mock_app):
pass
assert "PostgreSQL" in str(exc_info.value) or "initialization failed" in str(exc_info.value)
@pytest.mark.asyncio
@patch.dict(
os.environ,
{
"LANGGRAPH_CHECKPOINT_SAVER": "true",
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
"MONGO_MIN_POOL_SIZE": "2",
"MONGO_MAX_POOL_SIZE": "10",
},
)
async def test_lifespan_mongodb_pool_initialization_success(self):
"""Test successful MongoDB connection pool initialization."""
from src.server.app import lifespan
mock_app = MagicMock()
mock_client = MagicMock()
mock_client.close = MagicMock()
mock_checkpointer = MagicMock()
mock_checkpointer.setup = AsyncMock()
# Create a mock motor module
mock_motor_asyncio = MagicMock()
mock_motor_asyncio.AsyncIOMotorClient = MagicMock(return_value=mock_client)
with (
patch.dict("sys.modules", {"motor": MagicMock(), "motor.motor_asyncio": mock_motor_asyncio}),
patch("src.server.app.AsyncMongoDBSaver", return_value=mock_checkpointer),
):
async with lifespan(mock_app):
pass
mock_checkpointer.setup.assert_called_once()
mock_client.close.assert_called_once()
@pytest.mark.asyncio
@patch.dict(
os.environ,
{
"LANGGRAPH_CHECKPOINT_SAVER": "true",
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
},
)
async def test_lifespan_mongodb_import_error(self):
"""Verify RuntimeError when motor package is missing."""
from src.server.app import lifespan
mock_app = MagicMock()
with patch.dict("sys.modules", {"motor": None, "motor.motor_asyncio": None}):
with pytest.raises(RuntimeError) as exc_info:
async with lifespan(mock_app):
pass
assert "motor" in str(exc_info.value).lower() or "MongoDB" in str(exc_info.value)
@pytest.mark.asyncio
@patch.dict(
os.environ,
{
"LANGGRAPH_CHECKPOINT_SAVER": "true",
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
},
)
async def test_lifespan_mongodb_connection_failure(self):
"""Verify RuntimeError on MongoDB connection failure."""
from src.server.app import lifespan
mock_app = MagicMock()
# Create a mock motor module that raises an exception
mock_motor_asyncio = MagicMock()
mock_motor_asyncio.AsyncIOMotorClient = MagicMock(
side_effect=Exception("Connection refused")
)
with patch.dict("sys.modules", {"motor": MagicMock(), "motor.motor_asyncio": mock_motor_asyncio}):
with pytest.raises(RuntimeError) as exc_info:
async with lifespan(mock_app):
pass
assert "MongoDB" in str(exc_info.value) or "initialized" in str(exc_info.value)
@pytest.mark.asyncio
@patch.dict(
os.environ,
{
"LANGGRAPH_CHECKPOINT_SAVER": "true",
"LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
},
)
async def test_lifespan_postgresql_cleanup_on_shutdown(self):
"""Verify PostgreSQL pool.close() is called during shutdown."""
from src.server.app import lifespan
mock_app = MagicMock()
mock_pool = MagicMock()
mock_pool.open = AsyncMock()
mock_pool.close = AsyncMock()
mock_checkpointer = MagicMock()
mock_checkpointer.setup = AsyncMock()
with (
patch("src.server.app.AsyncConnectionPool", return_value=mock_pool),
patch("src.server.app.AsyncPostgresSaver", return_value=mock_checkpointer),
):
async with lifespan(mock_app):
# Verify pool is open during app lifetime
mock_pool.open.assert_called_once()
# Verify pool is closed after context exit
mock_pool.close.assert_called_once()
@pytest.mark.asyncio
@patch.dict(
os.environ,
{
"LANGGRAPH_CHECKPOINT_SAVER": "true",
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
},
)
async def test_lifespan_mongodb_cleanup_on_shutdown(self):
"""Verify MongoDB client.close() is called during shutdown."""
from src.server.app import lifespan
mock_app = MagicMock()
mock_client = MagicMock()
mock_client.close = MagicMock()
mock_checkpointer = MagicMock()
mock_checkpointer.setup = AsyncMock()
# Create a mock motor module
mock_motor_asyncio = MagicMock()
mock_motor_asyncio.AsyncIOMotorClient = MagicMock(return_value=mock_client)
with (
patch.dict("sys.modules", {"motor": MagicMock(), "motor.motor_asyncio": mock_motor_asyncio}),
patch("src.server.app.AsyncMongoDBSaver", return_value=mock_checkpointer),
):
async with lifespan(mock_app):
pass
# Verify client is closed after context exit
mock_client.close.assert_called_once()
class TestGlobalConnectionPoolUsage:
"""Tests for _astream_workflow_generator using global connection pools (Issue #778).
These tests verify that the workflow generator correctly uses global pools
when available and falls back to per-request connections when not.
"""
@pytest.mark.asyncio
@patch.dict(
os.environ,
{
"LANGGRAPH_CHECKPOINT_SAVER": "true",
"LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
},
)
@patch("src.server.app.graph")
async def test_astream_uses_global_postgresql_pool_when_available(self, mock_graph):
"""Verify global _pg_checkpointer is used when available."""
mock_checkpointer = MagicMock()
async def mock_astream(*args, **kwargs):
yield ("agent1", "step1", {"test": "data"})
mock_graph.astream = mock_astream
with (
patch("src.server.app._pg_checkpointer", mock_checkpointer),
patch("src.server.app._pg_pool", MagicMock()),
patch("src.server.app._process_initial_messages"),
patch("src.server.app._stream_graph_events") as mock_stream,
):
mock_stream.return_value = self._empty_async_gen()
generator = _astream_workflow_generator(
messages=[{"role": "user", "content": "Hello"}],
thread_id="test_thread",
resources=[],
max_plan_iterations=3,
max_step_num=10,
max_search_results=5,
auto_accepted_plan=True,
interrupt_feedback="",
mcp_settings={},
enable_background_investigation=False,
enable_web_search=True,
report_style=ReportStyle.ACADEMIC,
enable_deep_thinking=False,
enable_clarification=False,
max_clarification_rounds=3,
)
async for _ in generator:
pass
# Verify global checkpointer was assigned to graph
assert mock_graph.checkpointer == mock_checkpointer
@pytest.mark.asyncio
@patch.dict(
os.environ,
{
"LANGGRAPH_CHECKPOINT_SAVER": "true",
"LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
},
)
@patch("src.server.app.graph")
async def test_astream_falls_back_to_per_request_postgresql(self, mock_graph):
"""Verify fallback to per-request connection when _pg_checkpointer is None."""
mock_pool_instance = MagicMock()
mock_checkpointer = MagicMock()
mock_checkpointer.setup = AsyncMock()
async def mock_astream(*args, **kwargs):
yield ("agent1", "step1", {"test": "data"})
mock_graph.astream = mock_astream
with (
patch("src.server.app._pg_checkpointer", None),
patch("src.server.app._pg_pool", None),
patch("src.server.app._process_initial_messages"),
patch("src.server.app.AsyncConnectionPool") as mock_pool_class,
patch("src.server.app.AsyncPostgresSaver", return_value=mock_checkpointer),
patch("src.server.app._stream_graph_events") as mock_stream,
):
mock_pool_class.return_value.__aenter__ = AsyncMock(return_value=mock_pool_instance)
mock_pool_class.return_value.__aexit__ = AsyncMock()
mock_stream.return_value = self._empty_async_gen()
generator = _astream_workflow_generator(
messages=[{"role": "user", "content": "Hello"}],
thread_id="test_thread",
resources=[],
max_plan_iterations=3,
max_step_num=10,
max_search_results=5,
auto_accepted_plan=True,
interrupt_feedback="",
mcp_settings={},
enable_background_investigation=False,
enable_web_search=True,
report_style=ReportStyle.ACADEMIC,
enable_deep_thinking=False,
enable_clarification=False,
max_clarification_rounds=3,
)
async for _ in generator:
pass
# Verify per-request connection pool was created
mock_pool_class.assert_called_once()
@pytest.mark.asyncio
@patch.dict(
os.environ,
{
"LANGGRAPH_CHECKPOINT_SAVER": "true",
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
},
)
@patch("src.server.app.graph")
async def test_astream_uses_global_mongodb_pool_when_available(self, mock_graph):
"""Verify global _mongo_checkpointer is used when available."""
mock_checkpointer = MagicMock()
async def mock_astream(*args, **kwargs):
yield ("agent1", "step1", {"test": "data"})
mock_graph.astream = mock_astream
with (
patch("src.server.app._mongo_checkpointer", mock_checkpointer),
patch("src.server.app._mongo_client", MagicMock()),
patch("src.server.app._process_initial_messages"),
patch("src.server.app._stream_graph_events") as mock_stream,
):
mock_stream.return_value = self._empty_async_gen()
generator = _astream_workflow_generator(
messages=[{"role": "user", "content": "Hello"}],
thread_id="test_thread",
resources=[],
max_plan_iterations=3,
max_step_num=10,
max_search_results=5,
auto_accepted_plan=True,
interrupt_feedback="",
mcp_settings={},
enable_background_investigation=False,
enable_web_search=True,
report_style=ReportStyle.ACADEMIC,
enable_deep_thinking=False,
enable_clarification=False,
max_clarification_rounds=3,
)
async for _ in generator:
pass
# Verify global checkpointer was assigned to graph
assert mock_graph.checkpointer == mock_checkpointer
@pytest.mark.asyncio
@patch.dict(
os.environ,
{
"LANGGRAPH_CHECKPOINT_SAVER": "true",
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
},
)
@patch("src.server.app.graph")
async def test_astream_falls_back_to_per_request_mongodb(self, mock_graph):
"""Verify fallback to per-request connection when _mongo_checkpointer is None."""
mock_checkpointer = MagicMock()
async def mock_astream(*args, **kwargs):
yield ("agent1", "step1", {"test": "data"})
mock_graph.astream = mock_astream
with (
patch("src.server.app._mongo_checkpointer", None),
patch("src.server.app._mongo_client", None),
patch("src.server.app._process_initial_messages"),
patch("src.server.app.AsyncMongoDBSaver") as mock_saver_class,
patch("src.server.app._stream_graph_events") as mock_stream,
):
mock_saver_class.from_conn_string.return_value.__aenter__ = AsyncMock(
return_value=mock_checkpointer
)
mock_saver_class.from_conn_string.return_value.__aexit__ = AsyncMock()
mock_stream.return_value = self._empty_async_gen()
generator = _astream_workflow_generator(
messages=[{"role": "user", "content": "Hello"}],
thread_id="test_thread",
resources=[],
max_plan_iterations=3,
max_step_num=10,
max_search_results=5,
auto_accepted_plan=True,
interrupt_feedback="",
mcp_settings={},
enable_background_investigation=False,
enable_web_search=True,
report_style=ReportStyle.ACADEMIC,
enable_deep_thinking=False,
enable_clarification=False,
max_clarification_rounds=3,
)
async for _ in generator:
pass
# Verify per-request MongoDB saver was created
mock_saver_class.from_conn_string.assert_called_once()
async def _empty_async_gen(self):
"""Helper to create an empty async generator."""
if False:
yield