From fb319aaa44b212e3d43f691c943b9e0ed33e1162 Mon Sep 17 00:00:00 2001 From: Willem Jiang Date: Tue, 23 Dec 2025 23:06:39 +0800 Subject: [PATCH] test: add unit tests for global connection pool (Issue #778) (#780) * test: add unit tests for global connection pool (Issue #778) - Add TestLifespanFunction class with 9 tests for lifespan management: - PostgreSQL/MongoDB pool initialization success/failure - Cleanup on shutdown - Skip initialization when not configured - Add TestGlobalConnectionPoolUsage class with 4 tests: - Using global pools when available - Fallback to per-request connections - Fix missing dict_row import in app.py (bug from PR #757) * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/server/app.py | 1 + tests/unit/server/test_app.py | 472 +++++++++++++++++++++++++++++++++- 2 files changed, 472 insertions(+), 1 deletion(-) diff --git a/src/server/app.py b/src/server/app.py index e68be37..951d318 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -17,6 +17,7 @@ from langgraph.checkpoint.mongodb import AsyncMongoDBSaver from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from langgraph.store.memory import InMemoryStore from langgraph.types import Command +from psycopg.rows import dict_row from psycopg_pool import AsyncConnectionPool from src.config.configuration import get_recursion_limit diff --git a/tests/unit/server/test_app.py b/tests/unit/server/test_app.py index 981c48b..94b22c1 100644 --- a/tests/unit/server/test_app.py +++ b/tests/unit/server/test_app.py @@ -1,9 +1,10 @@ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # SPDX-License-Identifier: MIT + import base64 import os -from unittest.mock import MagicMock, mock_open, patch +from unittest.mock import AsyncMock, MagicMock, mock_open, patch import pytest from fastapi import HTTPException @@ -1129,3 +1130,472 @@ class TestCreateInterruptEvent: # Verify complex value is included (will be serialized as JSON) assert '"id": "int-complex"' in result assert "Research AI" in result or "plan" in result + + +class TestLifespanFunction: + """Tests for the lifespan function and global connection pool management (Issue #778). + + These tests verify correct initialization, error handling, and cleanup behavior + for PostgreSQL and MongoDB global connection pools. + """ + + @pytest.mark.asyncio + @patch.dict(os.environ, {"LANGGRAPH_CHECKPOINT_SAVER": "false"}) + async def test_lifespan_skips_initialization_when_checkpoint_not_configured(self): + """Verify no pool initialization when LANGGRAPH_CHECKPOINT_SAVER=False.""" + from src.server.app import lifespan + + mock_app = MagicMock() + + with patch("src.server.app.AsyncConnectionPool") as mock_pg_pool: + async with lifespan(mock_app): + pass + + mock_pg_pool.assert_not_called() + + @pytest.mark.asyncio + @patch.dict( + os.environ, + {"LANGGRAPH_CHECKPOINT_SAVER": "true", "LANGGRAPH_CHECKPOINT_DB_URL": ""}, + ) + async def test_lifespan_skips_initialization_when_url_empty(self): + """Verify no pool initialization when checkpoint URL is empty.""" + from src.server.app import lifespan + + mock_app = MagicMock() + + with patch("src.server.app.AsyncConnectionPool") as mock_pg_pool: + async with lifespan(mock_app): + pass + + mock_pg_pool.assert_not_called() + + @pytest.mark.asyncio + @patch.dict( + os.environ, + { + "LANGGRAPH_CHECKPOINT_SAVER": "true", + "LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test", + "PG_POOL_MIN_SIZE": "2", + "PG_POOL_MAX_SIZE": "10", + "PG_POOL_TIMEOUT": "30", + }, + ) + async def test_lifespan_postgresql_pool_initialization_success(self): + """Test successful PostgreSQL connection pool initialization.""" + from src.server.app import lifespan + + mock_app = MagicMock() + mock_pool = MagicMock() + mock_pool.open = AsyncMock() + mock_pool.close = AsyncMock() + + mock_checkpointer = MagicMock() + mock_checkpointer.setup = AsyncMock() + + with ( + patch("src.server.app.AsyncConnectionPool", return_value=mock_pool), + patch("src.server.app.AsyncPostgresSaver", return_value=mock_checkpointer), + ): + async with lifespan(mock_app): + pass + + mock_pool.open.assert_called_once() + mock_checkpointer.setup.assert_called_once() + mock_pool.close.assert_called_once() + + @pytest.mark.asyncio + @patch.dict( + os.environ, + { + "LANGGRAPH_CHECKPOINT_SAVER": "true", + "LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test", + }, + ) + async def test_lifespan_postgresql_pool_initialization_failure(self): + """Verify RuntimeError raised when PostgreSQL pool initialization fails.""" + from src.server.app import lifespan + + mock_app = MagicMock() + mock_pool = MagicMock() + mock_pool.open = AsyncMock( + side_effect=Exception("Connection refused") + ) + + with patch("src.server.app.AsyncConnectionPool", return_value=mock_pool): + with pytest.raises(RuntimeError) as exc_info: + async with lifespan(mock_app): + pass + + assert "PostgreSQL" in str(exc_info.value) or "initialization failed" in str(exc_info.value) + + @pytest.mark.asyncio + @patch.dict( + os.environ, + { + "LANGGRAPH_CHECKPOINT_SAVER": "true", + "LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test", + "MONGO_MIN_POOL_SIZE": "2", + "MONGO_MAX_POOL_SIZE": "10", + }, + ) + async def test_lifespan_mongodb_pool_initialization_success(self): + """Test successful MongoDB connection pool initialization.""" + from src.server.app import lifespan + + mock_app = MagicMock() + mock_client = MagicMock() + mock_client.close = MagicMock() + + mock_checkpointer = MagicMock() + mock_checkpointer.setup = AsyncMock() + + # Create a mock motor module + mock_motor_asyncio = MagicMock() + mock_motor_asyncio.AsyncIOMotorClient = MagicMock(return_value=mock_client) + + with ( + patch.dict("sys.modules", {"motor": MagicMock(), "motor.motor_asyncio": mock_motor_asyncio}), + patch("src.server.app.AsyncMongoDBSaver", return_value=mock_checkpointer), + ): + async with lifespan(mock_app): + pass + + mock_checkpointer.setup.assert_called_once() + mock_client.close.assert_called_once() + + @pytest.mark.asyncio + @patch.dict( + os.environ, + { + "LANGGRAPH_CHECKPOINT_SAVER": "true", + "LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test", + }, + ) + async def test_lifespan_mongodb_import_error(self): + """Verify RuntimeError when motor package is missing.""" + from src.server.app import lifespan + + mock_app = MagicMock() + + with patch.dict("sys.modules", {"motor": None, "motor.motor_asyncio": None}): + with pytest.raises(RuntimeError) as exc_info: + async with lifespan(mock_app): + pass + + assert "motor" in str(exc_info.value).lower() or "MongoDB" in str(exc_info.value) + + @pytest.mark.asyncio + @patch.dict( + os.environ, + { + "LANGGRAPH_CHECKPOINT_SAVER": "true", + "LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test", + }, + ) + async def test_lifespan_mongodb_connection_failure(self): + """Verify RuntimeError on MongoDB connection failure.""" + from src.server.app import lifespan + + mock_app = MagicMock() + + # Create a mock motor module that raises an exception + mock_motor_asyncio = MagicMock() + mock_motor_asyncio.AsyncIOMotorClient = MagicMock( + side_effect=Exception("Connection refused") + ) + + with patch.dict("sys.modules", {"motor": MagicMock(), "motor.motor_asyncio": mock_motor_asyncio}): + with pytest.raises(RuntimeError) as exc_info: + async with lifespan(mock_app): + pass + + assert "MongoDB" in str(exc_info.value) or "initialized" in str(exc_info.value) + + @pytest.mark.asyncio + @patch.dict( + os.environ, + { + "LANGGRAPH_CHECKPOINT_SAVER": "true", + "LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test", + }, + ) + async def test_lifespan_postgresql_cleanup_on_shutdown(self): + """Verify PostgreSQL pool.close() is called during shutdown.""" + from src.server.app import lifespan + + mock_app = MagicMock() + mock_pool = MagicMock() + mock_pool.open = AsyncMock() + mock_pool.close = AsyncMock() + + mock_checkpointer = MagicMock() + mock_checkpointer.setup = AsyncMock() + + with ( + patch("src.server.app.AsyncConnectionPool", return_value=mock_pool), + patch("src.server.app.AsyncPostgresSaver", return_value=mock_checkpointer), + ): + async with lifespan(mock_app): + # Verify pool is open during app lifetime + mock_pool.open.assert_called_once() + + # Verify pool is closed after context exit + mock_pool.close.assert_called_once() + + @pytest.mark.asyncio + @patch.dict( + os.environ, + { + "LANGGRAPH_CHECKPOINT_SAVER": "true", + "LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test", + }, + ) + async def test_lifespan_mongodb_cleanup_on_shutdown(self): + """Verify MongoDB client.close() is called during shutdown.""" + from src.server.app import lifespan + + mock_app = MagicMock() + mock_client = MagicMock() + mock_client.close = MagicMock() + + mock_checkpointer = MagicMock() + mock_checkpointer.setup = AsyncMock() + + # Create a mock motor module + mock_motor_asyncio = MagicMock() + mock_motor_asyncio.AsyncIOMotorClient = MagicMock(return_value=mock_client) + + with ( + patch.dict("sys.modules", {"motor": MagicMock(), "motor.motor_asyncio": mock_motor_asyncio}), + patch("src.server.app.AsyncMongoDBSaver", return_value=mock_checkpointer), + ): + async with lifespan(mock_app): + pass + + # Verify client is closed after context exit + mock_client.close.assert_called_once() + + +class TestGlobalConnectionPoolUsage: + """Tests for _astream_workflow_generator using global connection pools (Issue #778). + + These tests verify that the workflow generator correctly uses global pools + when available and falls back to per-request connections when not. + """ + + @pytest.mark.asyncio + @patch.dict( + os.environ, + { + "LANGGRAPH_CHECKPOINT_SAVER": "true", + "LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test", + }, + ) + @patch("src.server.app.graph") + async def test_astream_uses_global_postgresql_pool_when_available(self, mock_graph): + """Verify global _pg_checkpointer is used when available.""" + mock_checkpointer = MagicMock() + + async def mock_astream(*args, **kwargs): + yield ("agent1", "step1", {"test": "data"}) + + mock_graph.astream = mock_astream + + with ( + patch("src.server.app._pg_checkpointer", mock_checkpointer), + patch("src.server.app._pg_pool", MagicMock()), + patch("src.server.app._process_initial_messages"), + patch("src.server.app._stream_graph_events") as mock_stream, + ): + mock_stream.return_value = self._empty_async_gen() + + generator = _astream_workflow_generator( + messages=[{"role": "user", "content": "Hello"}], + thread_id="test_thread", + resources=[], + max_plan_iterations=3, + max_step_num=10, + max_search_results=5, + auto_accepted_plan=True, + interrupt_feedback="", + mcp_settings={}, + enable_background_investigation=False, + enable_web_search=True, + report_style=ReportStyle.ACADEMIC, + enable_deep_thinking=False, + enable_clarification=False, + max_clarification_rounds=3, + ) + + async for _ in generator: + pass + + # Verify global checkpointer was assigned to graph + assert mock_graph.checkpointer == mock_checkpointer + + @pytest.mark.asyncio + @patch.dict( + os.environ, + { + "LANGGRAPH_CHECKPOINT_SAVER": "true", + "LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test", + }, + ) + @patch("src.server.app.graph") + async def test_astream_falls_back_to_per_request_postgresql(self, mock_graph): + """Verify fallback to per-request connection when _pg_checkpointer is None.""" + mock_pool_instance = MagicMock() + mock_checkpointer = MagicMock() + mock_checkpointer.setup = AsyncMock() + + async def mock_astream(*args, **kwargs): + yield ("agent1", "step1", {"test": "data"}) + + mock_graph.astream = mock_astream + + with ( + patch("src.server.app._pg_checkpointer", None), + patch("src.server.app._pg_pool", None), + patch("src.server.app._process_initial_messages"), + patch("src.server.app.AsyncConnectionPool") as mock_pool_class, + patch("src.server.app.AsyncPostgresSaver", return_value=mock_checkpointer), + patch("src.server.app._stream_graph_events") as mock_stream, + ): + mock_pool_class.return_value.__aenter__ = AsyncMock(return_value=mock_pool_instance) + mock_pool_class.return_value.__aexit__ = AsyncMock() + mock_stream.return_value = self._empty_async_gen() + + generator = _astream_workflow_generator( + messages=[{"role": "user", "content": "Hello"}], + thread_id="test_thread", + resources=[], + max_plan_iterations=3, + max_step_num=10, + max_search_results=5, + auto_accepted_plan=True, + interrupt_feedback="", + mcp_settings={}, + enable_background_investigation=False, + enable_web_search=True, + report_style=ReportStyle.ACADEMIC, + enable_deep_thinking=False, + enable_clarification=False, + max_clarification_rounds=3, + ) + + async for _ in generator: + pass + + # Verify per-request connection pool was created + mock_pool_class.assert_called_once() + + @pytest.mark.asyncio + @patch.dict( + os.environ, + { + "LANGGRAPH_CHECKPOINT_SAVER": "true", + "LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test", + }, + ) + @patch("src.server.app.graph") + async def test_astream_uses_global_mongodb_pool_when_available(self, mock_graph): + """Verify global _mongo_checkpointer is used when available.""" + mock_checkpointer = MagicMock() + + async def mock_astream(*args, **kwargs): + yield ("agent1", "step1", {"test": "data"}) + + mock_graph.astream = mock_astream + + with ( + patch("src.server.app._mongo_checkpointer", mock_checkpointer), + patch("src.server.app._mongo_client", MagicMock()), + patch("src.server.app._process_initial_messages"), + patch("src.server.app._stream_graph_events") as mock_stream, + ): + mock_stream.return_value = self._empty_async_gen() + + generator = _astream_workflow_generator( + messages=[{"role": "user", "content": "Hello"}], + thread_id="test_thread", + resources=[], + max_plan_iterations=3, + max_step_num=10, + max_search_results=5, + auto_accepted_plan=True, + interrupt_feedback="", + mcp_settings={}, + enable_background_investigation=False, + enable_web_search=True, + report_style=ReportStyle.ACADEMIC, + enable_deep_thinking=False, + enable_clarification=False, + max_clarification_rounds=3, + ) + + async for _ in generator: + pass + + # Verify global checkpointer was assigned to graph + assert mock_graph.checkpointer == mock_checkpointer + + @pytest.mark.asyncio + @patch.dict( + os.environ, + { + "LANGGRAPH_CHECKPOINT_SAVER": "true", + "LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test", + }, + ) + @patch("src.server.app.graph") + async def test_astream_falls_back_to_per_request_mongodb(self, mock_graph): + """Verify fallback to per-request connection when _mongo_checkpointer is None.""" + mock_checkpointer = MagicMock() + + async def mock_astream(*args, **kwargs): + yield ("agent1", "step1", {"test": "data"}) + + mock_graph.astream = mock_astream + + with ( + patch("src.server.app._mongo_checkpointer", None), + patch("src.server.app._mongo_client", None), + patch("src.server.app._process_initial_messages"), + patch("src.server.app.AsyncMongoDBSaver") as mock_saver_class, + patch("src.server.app._stream_graph_events") as mock_stream, + ): + mock_saver_class.from_conn_string.return_value.__aenter__ = AsyncMock( + return_value=mock_checkpointer + ) + mock_saver_class.from_conn_string.return_value.__aexit__ = AsyncMock() + mock_stream.return_value = self._empty_async_gen() + + generator = _astream_workflow_generator( + messages=[{"role": "user", "content": "Hello"}], + thread_id="test_thread", + resources=[], + max_plan_iterations=3, + max_step_num=10, + max_search_results=5, + auto_accepted_plan=True, + interrupt_feedback="", + mcp_settings={}, + enable_background_investigation=False, + enable_web_search=True, + report_style=ReportStyle.ACADEMIC, + enable_deep_thinking=False, + enable_clarification=False, + max_clarification_rounds=3, + ) + + async for _ in generator: + pass + + # Verify per-request MongoDB saver was created + mock_saver_class.from_conn_string.assert_called_once() + + async def _empty_async_gen(self): + """Helper to create an empty async generator.""" + if False: + yield