mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-24 22:54:46 +08:00
* test: add unit tests for global connection pool (Issue #778) - Add TestLifespanFunction class with 9 tests for lifespan management: - PostgreSQL/MongoDB pool initialization success/failure - Cleanup on shutdown - Skip initialization when not configured - Add TestGlobalConnectionPoolUsage class with 4 tests: - Using global pools when available - Fallback to per-request connections - Fix missing dict_row import in app.py (bug from PR #757) * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -17,6 +17,7 @@ from langgraph.checkpoint.mongodb import AsyncMongoDBSaver
|
|||||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||||
from langgraph.store.memory import InMemoryStore
|
from langgraph.store.memory import InMemoryStore
|
||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
|
from psycopg.rows import dict_row
|
||||||
from psycopg_pool import AsyncConnectionPool
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
from src.config.configuration import get_recursion_limit
|
from src.config.configuration import get_recursion_limit
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import os
|
import os
|
||||||
from unittest.mock import MagicMock, mock_open, patch
|
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
@@ -1129,3 +1130,472 @@ class TestCreateInterruptEvent:
|
|||||||
# Verify complex value is included (will be serialized as JSON)
|
# Verify complex value is included (will be serialized as JSON)
|
||||||
assert '"id": "int-complex"' in result
|
assert '"id": "int-complex"' in result
|
||||||
assert "Research AI" in result or "plan" in result
|
assert "Research AI" in result or "plan" in result
|
||||||
|
|
||||||
|
|
||||||
|
class TestLifespanFunction:
|
||||||
|
"""Tests for the lifespan function and global connection pool management (Issue #778).
|
||||||
|
|
||||||
|
These tests verify correct initialization, error handling, and cleanup behavior
|
||||||
|
for PostgreSQL and MongoDB global connection pools.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch.dict(os.environ, {"LANGGRAPH_CHECKPOINT_SAVER": "false"})
|
||||||
|
async def test_lifespan_skips_initialization_when_checkpoint_not_configured(self):
|
||||||
|
"""Verify no pool initialization when LANGGRAPH_CHECKPOINT_SAVER=False."""
|
||||||
|
from src.server.app import lifespan
|
||||||
|
|
||||||
|
mock_app = MagicMock()
|
||||||
|
|
||||||
|
with patch("src.server.app.AsyncConnectionPool") as mock_pg_pool:
|
||||||
|
async with lifespan(mock_app):
|
||||||
|
pass
|
||||||
|
|
||||||
|
mock_pg_pool.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{"LANGGRAPH_CHECKPOINT_SAVER": "true", "LANGGRAPH_CHECKPOINT_DB_URL": ""},
|
||||||
|
)
|
||||||
|
async def test_lifespan_skips_initialization_when_url_empty(self):
|
||||||
|
"""Verify no pool initialization when checkpoint URL is empty."""
|
||||||
|
from src.server.app import lifespan
|
||||||
|
|
||||||
|
mock_app = MagicMock()
|
||||||
|
|
||||||
|
with patch("src.server.app.AsyncConnectionPool") as mock_pg_pool:
|
||||||
|
async with lifespan(mock_app):
|
||||||
|
pass
|
||||||
|
|
||||||
|
mock_pg_pool.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{
|
||||||
|
"LANGGRAPH_CHECKPOINT_SAVER": "true",
|
||||||
|
"LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
|
||||||
|
"PG_POOL_MIN_SIZE": "2",
|
||||||
|
"PG_POOL_MAX_SIZE": "10",
|
||||||
|
"PG_POOL_TIMEOUT": "30",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def test_lifespan_postgresql_pool_initialization_success(self):
|
||||||
|
"""Test successful PostgreSQL connection pool initialization."""
|
||||||
|
from src.server.app import lifespan
|
||||||
|
|
||||||
|
mock_app = MagicMock()
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool.open = AsyncMock()
|
||||||
|
mock_pool.close = AsyncMock()
|
||||||
|
|
||||||
|
mock_checkpointer = MagicMock()
|
||||||
|
mock_checkpointer.setup = AsyncMock()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("src.server.app.AsyncConnectionPool", return_value=mock_pool),
|
||||||
|
patch("src.server.app.AsyncPostgresSaver", return_value=mock_checkpointer),
|
||||||
|
):
|
||||||
|
async with lifespan(mock_app):
|
||||||
|
pass
|
||||||
|
|
||||||
|
mock_pool.open.assert_called_once()
|
||||||
|
mock_checkpointer.setup.assert_called_once()
|
||||||
|
mock_pool.close.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{
|
||||||
|
"LANGGRAPH_CHECKPOINT_SAVER": "true",
|
||||||
|
"LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def test_lifespan_postgresql_pool_initialization_failure(self):
|
||||||
|
"""Verify RuntimeError raised when PostgreSQL pool initialization fails."""
|
||||||
|
from src.server.app import lifespan
|
||||||
|
|
||||||
|
mock_app = MagicMock()
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool.open = AsyncMock(
|
||||||
|
side_effect=Exception("Connection refused")
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("src.server.app.AsyncConnectionPool", return_value=mock_pool):
|
||||||
|
with pytest.raises(RuntimeError) as exc_info:
|
||||||
|
async with lifespan(mock_app):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert "PostgreSQL" in str(exc_info.value) or "initialization failed" in str(exc_info.value)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{
|
||||||
|
"LANGGRAPH_CHECKPOINT_SAVER": "true",
|
||||||
|
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
|
||||||
|
"MONGO_MIN_POOL_SIZE": "2",
|
||||||
|
"MONGO_MAX_POOL_SIZE": "10",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def test_lifespan_mongodb_pool_initialization_success(self):
|
||||||
|
"""Test successful MongoDB connection pool initialization."""
|
||||||
|
from src.server.app import lifespan
|
||||||
|
|
||||||
|
mock_app = MagicMock()
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.close = MagicMock()
|
||||||
|
|
||||||
|
mock_checkpointer = MagicMock()
|
||||||
|
mock_checkpointer.setup = AsyncMock()
|
||||||
|
|
||||||
|
# Create a mock motor module
|
||||||
|
mock_motor_asyncio = MagicMock()
|
||||||
|
mock_motor_asyncio.AsyncIOMotorClient = MagicMock(return_value=mock_client)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.dict("sys.modules", {"motor": MagicMock(), "motor.motor_asyncio": mock_motor_asyncio}),
|
||||||
|
patch("src.server.app.AsyncMongoDBSaver", return_value=mock_checkpointer),
|
||||||
|
):
|
||||||
|
async with lifespan(mock_app):
|
||||||
|
pass
|
||||||
|
|
||||||
|
mock_checkpointer.setup.assert_called_once()
|
||||||
|
mock_client.close.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{
|
||||||
|
"LANGGRAPH_CHECKPOINT_SAVER": "true",
|
||||||
|
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def test_lifespan_mongodb_import_error(self):
|
||||||
|
"""Verify RuntimeError when motor package is missing."""
|
||||||
|
from src.server.app import lifespan
|
||||||
|
|
||||||
|
mock_app = MagicMock()
|
||||||
|
|
||||||
|
with patch.dict("sys.modules", {"motor": None, "motor.motor_asyncio": None}):
|
||||||
|
with pytest.raises(RuntimeError) as exc_info:
|
||||||
|
async with lifespan(mock_app):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert "motor" in str(exc_info.value).lower() or "MongoDB" in str(exc_info.value)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{
|
||||||
|
"LANGGRAPH_CHECKPOINT_SAVER": "true",
|
||||||
|
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def test_lifespan_mongodb_connection_failure(self):
|
||||||
|
"""Verify RuntimeError on MongoDB connection failure."""
|
||||||
|
from src.server.app import lifespan
|
||||||
|
|
||||||
|
mock_app = MagicMock()
|
||||||
|
|
||||||
|
# Create a mock motor module that raises an exception
|
||||||
|
mock_motor_asyncio = MagicMock()
|
||||||
|
mock_motor_asyncio.AsyncIOMotorClient = MagicMock(
|
||||||
|
side_effect=Exception("Connection refused")
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.dict("sys.modules", {"motor": MagicMock(), "motor.motor_asyncio": mock_motor_asyncio}):
|
||||||
|
with pytest.raises(RuntimeError) as exc_info:
|
||||||
|
async with lifespan(mock_app):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert "MongoDB" in str(exc_info.value) or "initialized" in str(exc_info.value)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{
|
||||||
|
"LANGGRAPH_CHECKPOINT_SAVER": "true",
|
||||||
|
"LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def test_lifespan_postgresql_cleanup_on_shutdown(self):
|
||||||
|
"""Verify PostgreSQL pool.close() is called during shutdown."""
|
||||||
|
from src.server.app import lifespan
|
||||||
|
|
||||||
|
mock_app = MagicMock()
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_pool.open = AsyncMock()
|
||||||
|
mock_pool.close = AsyncMock()
|
||||||
|
|
||||||
|
mock_checkpointer = MagicMock()
|
||||||
|
mock_checkpointer.setup = AsyncMock()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("src.server.app.AsyncConnectionPool", return_value=mock_pool),
|
||||||
|
patch("src.server.app.AsyncPostgresSaver", return_value=mock_checkpointer),
|
||||||
|
):
|
||||||
|
async with lifespan(mock_app):
|
||||||
|
# Verify pool is open during app lifetime
|
||||||
|
mock_pool.open.assert_called_once()
|
||||||
|
|
||||||
|
# Verify pool is closed after context exit
|
||||||
|
mock_pool.close.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{
|
||||||
|
"LANGGRAPH_CHECKPOINT_SAVER": "true",
|
||||||
|
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def test_lifespan_mongodb_cleanup_on_shutdown(self):
|
||||||
|
"""Verify MongoDB client.close() is called during shutdown."""
|
||||||
|
from src.server.app import lifespan
|
||||||
|
|
||||||
|
mock_app = MagicMock()
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.close = MagicMock()
|
||||||
|
|
||||||
|
mock_checkpointer = MagicMock()
|
||||||
|
mock_checkpointer.setup = AsyncMock()
|
||||||
|
|
||||||
|
# Create a mock motor module
|
||||||
|
mock_motor_asyncio = MagicMock()
|
||||||
|
mock_motor_asyncio.AsyncIOMotorClient = MagicMock(return_value=mock_client)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.dict("sys.modules", {"motor": MagicMock(), "motor.motor_asyncio": mock_motor_asyncio}),
|
||||||
|
patch("src.server.app.AsyncMongoDBSaver", return_value=mock_checkpointer),
|
||||||
|
):
|
||||||
|
async with lifespan(mock_app):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Verify client is closed after context exit
|
||||||
|
mock_client.close.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestGlobalConnectionPoolUsage:
|
||||||
|
"""Tests for _astream_workflow_generator using global connection pools (Issue #778).
|
||||||
|
|
||||||
|
These tests verify that the workflow generator correctly uses global pools
|
||||||
|
when available and falls back to per-request connections when not.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{
|
||||||
|
"LANGGRAPH_CHECKPOINT_SAVER": "true",
|
||||||
|
"LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
@patch("src.server.app.graph")
|
||||||
|
async def test_astream_uses_global_postgresql_pool_when_available(self, mock_graph):
|
||||||
|
"""Verify global _pg_checkpointer is used when available."""
|
||||||
|
mock_checkpointer = MagicMock()
|
||||||
|
|
||||||
|
async def mock_astream(*args, **kwargs):
|
||||||
|
yield ("agent1", "step1", {"test": "data"})
|
||||||
|
|
||||||
|
mock_graph.astream = mock_astream
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("src.server.app._pg_checkpointer", mock_checkpointer),
|
||||||
|
patch("src.server.app._pg_pool", MagicMock()),
|
||||||
|
patch("src.server.app._process_initial_messages"),
|
||||||
|
patch("src.server.app._stream_graph_events") as mock_stream,
|
||||||
|
):
|
||||||
|
mock_stream.return_value = self._empty_async_gen()
|
||||||
|
|
||||||
|
generator = _astream_workflow_generator(
|
||||||
|
messages=[{"role": "user", "content": "Hello"}],
|
||||||
|
thread_id="test_thread",
|
||||||
|
resources=[],
|
||||||
|
max_plan_iterations=3,
|
||||||
|
max_step_num=10,
|
||||||
|
max_search_results=5,
|
||||||
|
auto_accepted_plan=True,
|
||||||
|
interrupt_feedback="",
|
||||||
|
mcp_settings={},
|
||||||
|
enable_background_investigation=False,
|
||||||
|
enable_web_search=True,
|
||||||
|
report_style=ReportStyle.ACADEMIC,
|
||||||
|
enable_deep_thinking=False,
|
||||||
|
enable_clarification=False,
|
||||||
|
max_clarification_rounds=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
async for _ in generator:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Verify global checkpointer was assigned to graph
|
||||||
|
assert mock_graph.checkpointer == mock_checkpointer
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{
|
||||||
|
"LANGGRAPH_CHECKPOINT_SAVER": "true",
|
||||||
|
"LANGGRAPH_CHECKPOINT_DB_URL": "postgresql://localhost:5432/test",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
@patch("src.server.app.graph")
|
||||||
|
async def test_astream_falls_back_to_per_request_postgresql(self, mock_graph):
|
||||||
|
"""Verify fallback to per-request connection when _pg_checkpointer is None."""
|
||||||
|
mock_pool_instance = MagicMock()
|
||||||
|
mock_checkpointer = MagicMock()
|
||||||
|
mock_checkpointer.setup = AsyncMock()
|
||||||
|
|
||||||
|
async def mock_astream(*args, **kwargs):
|
||||||
|
yield ("agent1", "step1", {"test": "data"})
|
||||||
|
|
||||||
|
mock_graph.astream = mock_astream
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("src.server.app._pg_checkpointer", None),
|
||||||
|
patch("src.server.app._pg_pool", None),
|
||||||
|
patch("src.server.app._process_initial_messages"),
|
||||||
|
patch("src.server.app.AsyncConnectionPool") as mock_pool_class,
|
||||||
|
patch("src.server.app.AsyncPostgresSaver", return_value=mock_checkpointer),
|
||||||
|
patch("src.server.app._stream_graph_events") as mock_stream,
|
||||||
|
):
|
||||||
|
mock_pool_class.return_value.__aenter__ = AsyncMock(return_value=mock_pool_instance)
|
||||||
|
mock_pool_class.return_value.__aexit__ = AsyncMock()
|
||||||
|
mock_stream.return_value = self._empty_async_gen()
|
||||||
|
|
||||||
|
generator = _astream_workflow_generator(
|
||||||
|
messages=[{"role": "user", "content": "Hello"}],
|
||||||
|
thread_id="test_thread",
|
||||||
|
resources=[],
|
||||||
|
max_plan_iterations=3,
|
||||||
|
max_step_num=10,
|
||||||
|
max_search_results=5,
|
||||||
|
auto_accepted_plan=True,
|
||||||
|
interrupt_feedback="",
|
||||||
|
mcp_settings={},
|
||||||
|
enable_background_investigation=False,
|
||||||
|
enable_web_search=True,
|
||||||
|
report_style=ReportStyle.ACADEMIC,
|
||||||
|
enable_deep_thinking=False,
|
||||||
|
enable_clarification=False,
|
||||||
|
max_clarification_rounds=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
async for _ in generator:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Verify per-request connection pool was created
|
||||||
|
mock_pool_class.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{
|
||||||
|
"LANGGRAPH_CHECKPOINT_SAVER": "true",
|
||||||
|
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
@patch("src.server.app.graph")
|
||||||
|
async def test_astream_uses_global_mongodb_pool_when_available(self, mock_graph):
|
||||||
|
"""Verify global _mongo_checkpointer is used when available."""
|
||||||
|
mock_checkpointer = MagicMock()
|
||||||
|
|
||||||
|
async def mock_astream(*args, **kwargs):
|
||||||
|
yield ("agent1", "step1", {"test": "data"})
|
||||||
|
|
||||||
|
mock_graph.astream = mock_astream
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("src.server.app._mongo_checkpointer", mock_checkpointer),
|
||||||
|
patch("src.server.app._mongo_client", MagicMock()),
|
||||||
|
patch("src.server.app._process_initial_messages"),
|
||||||
|
patch("src.server.app._stream_graph_events") as mock_stream,
|
||||||
|
):
|
||||||
|
mock_stream.return_value = self._empty_async_gen()
|
||||||
|
|
||||||
|
generator = _astream_workflow_generator(
|
||||||
|
messages=[{"role": "user", "content": "Hello"}],
|
||||||
|
thread_id="test_thread",
|
||||||
|
resources=[],
|
||||||
|
max_plan_iterations=3,
|
||||||
|
max_step_num=10,
|
||||||
|
max_search_results=5,
|
||||||
|
auto_accepted_plan=True,
|
||||||
|
interrupt_feedback="",
|
||||||
|
mcp_settings={},
|
||||||
|
enable_background_investigation=False,
|
||||||
|
enable_web_search=True,
|
||||||
|
report_style=ReportStyle.ACADEMIC,
|
||||||
|
enable_deep_thinking=False,
|
||||||
|
enable_clarification=False,
|
||||||
|
max_clarification_rounds=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
async for _ in generator:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Verify global checkpointer was assigned to graph
|
||||||
|
assert mock_graph.checkpointer == mock_checkpointer
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{
|
||||||
|
"LANGGRAPH_CHECKPOINT_SAVER": "true",
|
||||||
|
"LANGGRAPH_CHECKPOINT_DB_URL": "mongodb://localhost:27017/test",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
@patch("src.server.app.graph")
|
||||||
|
async def test_astream_falls_back_to_per_request_mongodb(self, mock_graph):
|
||||||
|
"""Verify fallback to per-request connection when _mongo_checkpointer is None."""
|
||||||
|
mock_checkpointer = MagicMock()
|
||||||
|
|
||||||
|
async def mock_astream(*args, **kwargs):
|
||||||
|
yield ("agent1", "step1", {"test": "data"})
|
||||||
|
|
||||||
|
mock_graph.astream = mock_astream
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("src.server.app._mongo_checkpointer", None),
|
||||||
|
patch("src.server.app._mongo_client", None),
|
||||||
|
patch("src.server.app._process_initial_messages"),
|
||||||
|
patch("src.server.app.AsyncMongoDBSaver") as mock_saver_class,
|
||||||
|
patch("src.server.app._stream_graph_events") as mock_stream,
|
||||||
|
):
|
||||||
|
mock_saver_class.from_conn_string.return_value.__aenter__ = AsyncMock(
|
||||||
|
return_value=mock_checkpointer
|
||||||
|
)
|
||||||
|
mock_saver_class.from_conn_string.return_value.__aexit__ = AsyncMock()
|
||||||
|
mock_stream.return_value = self._empty_async_gen()
|
||||||
|
|
||||||
|
generator = _astream_workflow_generator(
|
||||||
|
messages=[{"role": "user", "content": "Hello"}],
|
||||||
|
thread_id="test_thread",
|
||||||
|
resources=[],
|
||||||
|
max_plan_iterations=3,
|
||||||
|
max_step_num=10,
|
||||||
|
max_search_results=5,
|
||||||
|
auto_accepted_plan=True,
|
||||||
|
interrupt_feedback="",
|
||||||
|
mcp_settings={},
|
||||||
|
enable_background_investigation=False,
|
||||||
|
enable_web_search=True,
|
||||||
|
report_style=ReportStyle.ACADEMIC,
|
||||||
|
enable_deep_thinking=False,
|
||||||
|
enable_clarification=False,
|
||||||
|
max_clarification_rounds=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
async for _ in generator:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Verify per-request MongoDB saver was created
|
||||||
|
mock_saver_class.from_conn_string.assert_called_once()
|
||||||
|
|
||||||
|
async def _empty_async_gen(self):
|
||||||
|
"""Helper to create an empty async generator."""
|
||||||
|
if False:
|
||||||
|
yield
|
||||||
|
|||||||
Reference in New Issue
Block a user