test: add unit tests for global connection pool (Issue #778) (#780)

* test: add unit tests for global connection pool (Issue #778)

- Add TestLifespanFunction class with 9 tests for lifespan management:
  - PostgreSQL/MongoDB pool initialization success/failure
  - Cleanup on shutdown
  - Skip initialization when not configured

- Add TestGlobalConnectionPoolUsage class with 4 tests:
  - Using global pools when available
  - Fallback to per-request connections

- Fix missing dict_row import in app.py (bug from PR #757)

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Willem Jiang
2025-12-23 23:06:39 +08:00
committed by GitHub
parent 83e9d7c9e5
commit fb319aaa44
2 changed files with 472 additions and 1 deletions

View File

@@ -17,6 +17,7 @@ from langgraph.checkpoint.mongodb import AsyncMongoDBSaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.store.memory import InMemoryStore from langgraph.store.memory import InMemoryStore
from langgraph.types import Command from langgraph.types import Command
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool from psycopg_pool import AsyncConnectionPool
from src.config.configuration import get_recursion_limit from src.config.configuration import get_recursion_limit

View File

@@ -1,9 +1,10 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import base64 import base64
import os import os
from unittest.mock import MagicMock, mock_open, patch from unittest.mock import AsyncMock, MagicMock, mock_open, patch
import pytest import pytest
from fastapi import HTTPException from fastapi import HTTPException
@@ -1129,3 +1130,472 @@ class TestCreateInterruptEvent:
# Verify complex value is included (will be serialized as JSON) # Verify complex value is included (will be serialized as JSON)
assert '"id": "int-complex"' in result assert '"id": "int-complex"' in result
assert "Research AI" in result or "plan" in result assert "Research AI" in result or "plan" in result
class TestLifespanFunction:
"""Tests for the lifespan function and global connection pool management (Issue #778).
These tests verify correct initialization, error handling, and cleanup behavior
for PostgreSQL and MongoDB global connection pools.
"""
@pytest.mark.asyncio
@patch.dict(os.environ, {"LANGGRAPH_CHECKPOINT_SAVER": "false"})
async def test_lifespan_skips_initialization_when_checkpoint_not_configured(self):
"""Verify no pool initialization when LANGGRAPH_CHECKPOINT_SAVER=False."""
from src.server.app import lifespan
mock_app = MagicMock()
with patch("src.server.app.AsyncConnectionPool") as mock_pg_pool:
async with lifespan(mock_app):
pass
mock_pg_pool.assert_not_called()
@pytest.mark.asyncio
@patch.dict(
os.environ,
{"LANGGRAPH_CHECKPOINT_SAVER": "true", "LANGGRAPH_CHECKPOINT_DB_URL": ""},
)
async def test_lifespan_skips_initialization_when_url_empty(self):
"""Verify no pool initialization when checkpoint URL is empty."""
from src.server.app import lifespan
mock_app = MagicMock()
with patch("src.server.app.AsyncConnectionPool") as mock_pg_pool:
async with lifespan(mock_app):
pass
mock_pg_pool.assert_not_called()
@pytest.mark.asyncio
@patch.dict(
os.environ,
{
"LANGGRAPH_CHECKPOINT_SAVER": "true",
"LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
"PG_POOL_MIN_SIZE": "2",
"PG_POOL_MAX_SIZE": "10",
"PG_POOL_TIMEOUT": "30",
},
)
async def test_lifespan_postgresql_pool_initialization_success(self):
"""Test successful PostgreSQL connection pool initialization."""
from src.server.app import lifespan
mock_app = MagicMock()
mock_pool = MagicMock()
mock_pool.open = AsyncMock()
mock_pool.close = AsyncMock()
mock_checkpointer = MagicMock()
mock_checkpointer.setup = AsyncMock()
with (
patch("src.server.app.AsyncConnectionPool", return_value=mock_pool),
patch("src.server.app.AsyncPostgresSaver", return_value=mock_checkpointer),
):
async with lifespan(mock_app):
pass
mock_pool.open.assert_called_once()
mock_checkpointer.setup.assert_called_once()
mock_pool.close.assert_called_once()
@pytest.mark.asyncio
@patch.dict(
os.environ,
{
"LANGGRAPH_CHECKPOINT_SAVER": "true",
"LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
},
)
async def test_lifespan_postgresql_pool_initialization_failure(self):
"""Verify RuntimeError raised when PostgreSQL pool initialization fails."""
from src.server.app import lifespan
mock_app = MagicMock()
mock_pool = MagicMock()
mock_pool.open = AsyncMock(
side_effect=Exception("Connection refused")
)
with patch("src.server.app.AsyncConnectionPool", return_value=mock_pool):
with pytest.raises(RuntimeError) as exc_info:
async with lifespan(mock_app):
pass
assert "PostgreSQL" in str(exc_info.value) or "initialization failed" in str(exc_info.value)
@pytest.mark.asyncio
@patch.dict(
os.environ,
{
"LANGGRAPH_CHECKPOINT_SAVER": "true",
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
"MONGO_MIN_POOL_SIZE": "2",
"MONGO_MAX_POOL_SIZE": "10",
},
)
async def test_lifespan_mongodb_pool_initialization_success(self):
"""Test successful MongoDB connection pool initialization."""
from src.server.app import lifespan
mock_app = MagicMock()
mock_client = MagicMock()
mock_client.close = MagicMock()
mock_checkpointer = MagicMock()
mock_checkpointer.setup = AsyncMock()
# Create a mock motor module
mock_motor_asyncio = MagicMock()
mock_motor_asyncio.AsyncIOMotorClient = MagicMock(return_value=mock_client)
with (
patch.dict("sys.modules", {"motor": MagicMock(), "motor.motor_asyncio": mock_motor_asyncio}),
patch("src.server.app.AsyncMongoDBSaver", return_value=mock_checkpointer),
):
async with lifespan(mock_app):
pass
mock_checkpointer.setup.assert_called_once()
mock_client.close.assert_called_once()
@pytest.mark.asyncio
@patch.dict(
os.environ,
{
"LANGGRAPH_CHECKPOINT_SAVER": "true",
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
},
)
async def test_lifespan_mongodb_import_error(self):
"""Verify RuntimeError when motor package is missing."""
from src.server.app import lifespan
mock_app = MagicMock()
with patch.dict("sys.modules", {"motor": None, "motor.motor_asyncio": None}):
with pytest.raises(RuntimeError) as exc_info:
async with lifespan(mock_app):
pass
assert "motor" in str(exc_info.value).lower() or "MongoDB" in str(exc_info.value)
@pytest.mark.asyncio
@patch.dict(
os.environ,
{
"LANGGRAPH_CHECKPOINT_SAVER": "true",
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
},
)
async def test_lifespan_mongodb_connection_failure(self):
"""Verify RuntimeError on MongoDB connection failure."""
from src.server.app import lifespan
mock_app = MagicMock()
# Create a mock motor module that raises an exception
mock_motor_asyncio = MagicMock()
mock_motor_asyncio.AsyncIOMotorClient = MagicMock(
side_effect=Exception("Connection refused")
)
with patch.dict("sys.modules", {"motor": MagicMock(), "motor.motor_asyncio": mock_motor_asyncio}):
with pytest.raises(RuntimeError) as exc_info:
async with lifespan(mock_app):
pass
assert "MongoDB" in str(exc_info.value) or "initialized" in str(exc_info.value)
@pytest.mark.asyncio
@patch.dict(
os.environ,
{
"LANGGRAPH_CHECKPOINT_SAVER": "true",
"LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
},
)
async def test_lifespan_postgresql_cleanup_on_shutdown(self):
"""Verify PostgreSQL pool.close() is called during shutdown."""
from src.server.app import lifespan
mock_app = MagicMock()
mock_pool = MagicMock()
mock_pool.open = AsyncMock()
mock_pool.close = AsyncMock()
mock_checkpointer = MagicMock()
mock_checkpointer.setup = AsyncMock()
with (
patch("src.server.app.AsyncConnectionPool", return_value=mock_pool),
patch("src.server.app.AsyncPostgresSaver", return_value=mock_checkpointer),
):
async with lifespan(mock_app):
# Verify pool is open during app lifetime
mock_pool.open.assert_called_once()
# Verify pool is closed after context exit
mock_pool.close.assert_called_once()
@pytest.mark.asyncio
@patch.dict(
os.environ,
{
"LANGGRAPH_CHECKPOINT_SAVER": "true",
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
},
)
async def test_lifespan_mongodb_cleanup_on_shutdown(self):
"""Verify MongoDB client.close() is called during shutdown."""
from src.server.app import lifespan
mock_app = MagicMock()
mock_client = MagicMock()
mock_client.close = MagicMock()
mock_checkpointer = MagicMock()
mock_checkpointer.setup = AsyncMock()
# Create a mock motor module
mock_motor_asyncio = MagicMock()
mock_motor_asyncio.AsyncIOMotorClient = MagicMock(return_value=mock_client)
with (
patch.dict("sys.modules", {"motor": MagicMock(), "motor.motor_asyncio": mock_motor_asyncio}),
patch("src.server.app.AsyncMongoDBSaver", return_value=mock_checkpointer),
):
async with lifespan(mock_app):
pass
# Verify client is closed after context exit
mock_client.close.assert_called_once()
class TestGlobalConnectionPoolUsage:
"""Tests for _astream_workflow_generator using global connection pools (Issue #778).
These tests verify that the workflow generator correctly uses global pools
when available and falls back to per-request connections when not.
"""
@pytest.mark.asyncio
@patch.dict(
os.environ,
{
"LANGGRAPH_CHECKPOINT_SAVER": "true",
"LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
},
)
@patch("src.server.app.graph")
async def test_astream_uses_global_postgresql_pool_when_available(self, mock_graph):
"""Verify global _pg_checkpointer is used when available."""
mock_checkpointer = MagicMock()
async def mock_astream(*args, **kwargs):
yield ("agent1", "step1", {"test": "data"})
mock_graph.astream = mock_astream
with (
patch("src.server.app._pg_checkpointer", mock_checkpointer),
patch("src.server.app._pg_pool", MagicMock()),
patch("src.server.app._process_initial_messages"),
patch("src.server.app._stream_graph_events") as mock_stream,
):
mock_stream.return_value = self._empty_async_gen()
generator = _astream_workflow_generator(
messages=[{"role": "user", "content": "Hello"}],
thread_id="test_thread",
resources=[],
max_plan_iterations=3,
max_step_num=10,
max_search_results=5,
auto_accepted_plan=True,
interrupt_feedback="",
mcp_settings={},
enable_background_investigation=False,
enable_web_search=True,
report_style=ReportStyle.ACADEMIC,
enable_deep_thinking=False,
enable_clarification=False,
max_clarification_rounds=3,
)
async for _ in generator:
pass
# Verify global checkpointer was assigned to graph
assert mock_graph.checkpointer == mock_checkpointer
@pytest.mark.asyncio
@patch.dict(
os.environ,
{
"LANGGRAPH_CHECKPOINT_SAVER": "true",
"LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
},
)
@patch("src.server.app.graph")
async def test_astream_falls_back_to_per_request_postgresql(self, mock_graph):
"""Verify fallback to per-request connection when _pg_checkpointer is None."""
mock_pool_instance = MagicMock()
mock_checkpointer = MagicMock()
mock_checkpointer.setup = AsyncMock()
async def mock_astream(*args, **kwargs):
yield ("agent1", "step1", {"test": "data"})
mock_graph.astream = mock_astream
with (
patch("src.server.app._pg_checkpointer", None),
patch("src.server.app._pg_pool", None),
patch("src.server.app._process_initial_messages"),
patch("src.server.app.AsyncConnectionPool") as mock_pool_class,
patch("src.server.app.AsyncPostgresSaver", return_value=mock_checkpointer),
patch("src.server.app._stream_graph_events") as mock_stream,
):
mock_pool_class.return_value.__aenter__ = AsyncMock(return_value=mock_pool_instance)
mock_pool_class.return_value.__aexit__ = AsyncMock()
mock_stream.return_value = self._empty_async_gen()
generator = _astream_workflow_generator(
messages=[{"role": "user", "content": "Hello"}],
thread_id="test_thread",
resources=[],
max_plan_iterations=3,
max_step_num=10,
max_search_results=5,
auto_accepted_plan=True,
interrupt_feedback="",
mcp_settings={},
enable_background_investigation=False,
enable_web_search=True,
report_style=ReportStyle.ACADEMIC,
enable_deep_thinking=False,
enable_clarification=False,
max_clarification_rounds=3,
)
async for _ in generator:
pass
# Verify per-request connection pool was created
mock_pool_class.assert_called_once()
@pytest.mark.asyncio
@patch.dict(
os.environ,
{
"LANGGRAPH_CHECKPOINT_SAVER": "true",
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
},
)
@patch("src.server.app.graph")
async def test_astream_uses_global_mongodb_pool_when_available(self, mock_graph):
"""Verify global _mongo_checkpointer is used when available."""
mock_checkpointer = MagicMock()
async def mock_astream(*args, **kwargs):
yield ("agent1", "step1", {"test": "data"})
mock_graph.astream = mock_astream
with (
patch("src.server.app._mongo_checkpointer", mock_checkpointer),
patch("src.server.app._mongo_client", MagicMock()),
patch("src.server.app._process_initial_messages"),
patch("src.server.app._stream_graph_events") as mock_stream,
):
mock_stream.return_value = self._empty_async_gen()
generator = _astream_workflow_generator(
messages=[{"role": "user", "content": "Hello"}],
thread_id="test_thread",
resources=[],
max_plan_iterations=3,
max_step_num=10,
max_search_results=5,
auto_accepted_plan=True,
interrupt_feedback="",
mcp_settings={},
enable_background_investigation=False,
enable_web_search=True,
report_style=ReportStyle.ACADEMIC,
enable_deep_thinking=False,
enable_clarification=False,
max_clarification_rounds=3,
)
async for _ in generator:
pass
# Verify global checkpointer was assigned to graph
assert mock_graph.checkpointer == mock_checkpointer
@pytest.mark.asyncio
@patch.dict(
os.environ,
{
"LANGGRAPH_CHECKPOINT_SAVER": "true",
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
},
)
@patch("src.server.app.graph")
async def test_astream_falls_back_to_per_request_mongodb(self, mock_graph):
"""Verify fallback to per-request connection when _mongo_checkpointer is None."""
mock_checkpointer = MagicMock()
async def mock_astream(*args, **kwargs):
yield ("agent1", "step1", {"test": "data"})
mock_graph.astream = mock_astream
with (
patch("src.server.app._mongo_checkpointer", None),
patch("src.server.app._mongo_client", None),
patch("src.server.app._process_initial_messages"),
patch("src.server.app.AsyncMongoDBSaver") as mock_saver_class,
patch("src.server.app._stream_graph_events") as mock_stream,
):
mock_saver_class.from_conn_string.return_value.__aenter__ = AsyncMock(
return_value=mock_checkpointer
)
mock_saver_class.from_conn_string.return_value.__aexit__ = AsyncMock()
mock_stream.return_value = self._empty_async_gen()
generator = _astream_workflow_generator(
messages=[{"role": "user", "content": "Hello"}],
thread_id="test_thread",
resources=[],
max_plan_iterations=3,
max_step_num=10,
max_search_results=5,
auto_accepted_plan=True,
interrupt_feedback="",
mcp_settings={},
enable_background_investigation=False,
enable_web_search=True,
report_style=ReportStyle.ACADEMIC,
enable_deep_thinking=False,
enable_clarification=False,
max_clarification_rounds=3,
)
async for _ in generator:
pass
# Verify per-request MongoDB saver was created
mock_saver_class.from_conn_string.assert_called_once()
async def _empty_async_gen(self):
"""Helper to create an empty async generator."""
if False:
yield