feat: creating mogodb and postgres mock instance in checkpoint test (#561)

* fix: using mongomock for the checkpoint test

* Add postgres mock setting to the unit test

* Added utils file of postgres_mock_utils

* fixed the runtime loading error of deerflow server
This commit is contained in:
Willem Jiang
2025-09-09 22:49:11 +08:00
committed by GitHub
parent 7138ba36bc
commit 4c17d88029
6 changed files with 463 additions and 151 deletions

View File

@@ -0,0 +1,147 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import pytest
import tempfile
import shutil
from pathlib import Path
from unittest.mock import MagicMock, patch
from typing import Dict, Any, Optional
import psycopg
class PostgreSQLMockInstance:
"""Utility class for managing PostgreSQL mock instances."""
def __init__(self, database_name: str = "test_db"):
self.database_name = database_name
self.temp_dir: Optional[Path] = None
self.mock_connection: Optional[MagicMock] = None
self.mock_data: Dict[str, Any] = {}
self._setup_mock_data()
def _setup_mock_data(self):
"""Initialize mock data storage."""
self.mock_data = {
"chat_streams": {}, # thread_id -> record
"table_exists": False,
"connection_active": True
}
def connect(self) -> MagicMock:
"""Create a mock PostgreSQL connection."""
self.mock_connection = MagicMock()
self._setup_mock_methods()
return self.mock_connection
def _setup_mock_methods(self):
"""Setup mock methods for PostgreSQL operations."""
if not self.mock_connection:
return
# Mock cursor context manager
mock_cursor = MagicMock()
mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
mock_cursor.__exit__ = MagicMock(return_value=False)
# Setup cursor operations
mock_cursor.execute = MagicMock(side_effect=self._mock_execute)
mock_cursor.fetchone = MagicMock(side_effect=self._mock_fetchone)
mock_cursor.rowcount = 0
# Setup connection operations
self.mock_connection.cursor = MagicMock(return_value=mock_cursor)
self.mock_connection.commit = MagicMock()
self.mock_connection.rollback = MagicMock()
self.mock_connection.close = MagicMock()
# Store cursor for external access
self._mock_cursor = mock_cursor
def _mock_execute(self, sql: str, params=None):
"""Mock SQL execution."""
sql_upper = sql.upper().strip()
if "CREATE TABLE" in sql_upper:
self.mock_data["table_exists"] = True
self._mock_cursor.rowcount = 0
elif "SELECT" in sql_upper and "chat_streams" in sql_upper:
# Mock SELECT query
if params and len(params) > 0:
thread_id = params[0]
if thread_id in self.mock_data["chat_streams"]:
self._mock_cursor._fetch_result = self.mock_data["chat_streams"][thread_id]
else:
self._mock_cursor._fetch_result = None
else:
self._mock_cursor._fetch_result = None
elif "UPDATE" in sql_upper and "chat_streams" in sql_upper:
# Mock UPDATE query
if params and len(params) >= 2:
messages, thread_id = params[0], params[1]
if thread_id in self.mock_data["chat_streams"]:
self.mock_data["chat_streams"][thread_id] = {
"id": thread_id,
"thread_id": thread_id,
"messages": messages
}
self._mock_cursor.rowcount = 1
else:
self._mock_cursor.rowcount = 0
elif "INSERT" in sql_upper and "chat_streams" in sql_upper:
# Mock INSERT query
if params and len(params) >= 2:
thread_id, messages = params[0], params[1]
self.mock_data["chat_streams"][thread_id] = {
"id": thread_id,
"thread_id": thread_id,
"messages": messages
}
self._mock_cursor.rowcount = 1
def _mock_fetchone(self):
"""Mock fetchone operation."""
return getattr(self._mock_cursor, '_fetch_result', None)
def disconnect(self):
"""Cleanup mock connection."""
if self.mock_connection:
self.mock_connection.close()
self._setup_mock_data() # Reset data
def reset_data(self):
"""Reset all mock data."""
self._setup_mock_data()
def get_table_count(self, table_name: str) -> int:
"""Get record count in a table."""
if table_name == "chat_streams":
return len(self.mock_data["chat_streams"])
return 0
def create_test_data(self, table_name: str, records: list):
"""Insert test data into a table."""
if table_name == "chat_streams":
for record in records:
thread_id = record.get("thread_id")
if thread_id:
self.mock_data["chat_streams"][thread_id] = record
@pytest.fixture
def mock_postgresql():
"""Create a PostgreSQL mock instance."""
instance = PostgreSQLMockInstance()
instance.connect()
yield instance
instance.disconnect()
@pytest.fixture
def clean_mock_postgresql():
"""Create a clean PostgreSQL mock instance that resets between tests."""
instance = PostgreSQLMockInstance()
instance.connect()
instance.reset_data()
yield instance
instance.disconnect()

View File

@@ -1,17 +1,32 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import os
import pytest
import mongomock
from unittest.mock import patch, MagicMock
import src.graph.checkpoint as checkpoint
from postgres_mock_utils import PostgreSQLMockInstance
POSTGRES_URL = "postgresql://postgres:postgres@localhost:5432/checkpointing_db"
MONGO_URL = "mongodb://admin:admin@localhost:27017/checkpointing_db?authSource=admin"
def has_real_db_connection():
# Check the environment if the MongoDB server is available
enabled = os.getenv("DB_TESTS_ENABLED", "false")
if enabled.lower() == "true":
return True
return False
def test_with_local_postgres_db():
"""Ensure the ChatStreamManager can be initialized with a local PostgreSQL DB."""
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=POSTGRES_URL,
with patch('psycopg.connect') as mock_connect:
# Setup mock PostgreSQL connection
pg_mock = PostgreSQLMockInstance()
mock_connect.return_value = pg_mock.connect()
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=POSTGRES_URL,
)
assert manager.postgres_conn is not None
assert manager.mongo_client is None
@@ -19,12 +34,17 @@ def test_with_local_postgres_db():
def test_with_local_mongo_db():
"""Ensure the ChatStreamManager can be initialized with a local MongoDB."""
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
assert manager.mongo_db is not None
assert manager.postgres_conn is None
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
assert manager.mongo_db is not None
assert manager.postgres_conn is None
def test_init_without_checkpoint_saver():
@@ -58,30 +78,25 @@ def test_process_stream_partial_buffer_postgres(monkeypatch):
assert "hello" in values
def test_process_stream_partial_buffer_mongo(monkeypatch):
"""Partial chunks should be buffered; Mongo init is stubbed to no-op."""
# Patch Mongo init to no-op for speed
def _no_mongo(self):
self.mongo_client = None
self.mongo_db = None
monkeypatch.setattr(
checkpoint.ChatStreamManager, "_init_mongodb", _no_mongo, raising=True
)
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
result = manager.process_stream_message("t2", "hello", finish_reason="partial")
assert result is True
# Verify the chunk was stored in the in-memory store
items = manager.store.search(("messages", "t2"), limit=10)
values = [it.dict()["value"] for it in items]
assert "hello" in values
def test_process_stream_partial_buffer_mongo():
"""Partial chunks should be buffered; Use mongomock instead of real MongoDB."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
result = manager.process_stream_message("t2", "hello", finish_reason="partial")
assert result is True
# Verify the chunk was stored in the in-memory store
items = manager.store.search(("messages", "t2"), limit=10)
values = [it.dict()["value"] for it in items]
assert "hello" in values
@pytest.mark.skipif(not has_real_db_connection(), reason="PostgreSQL Server is not available")
def test_persist_postgresql_local_db():
"""Ensure that the ChatStreamManager can persist to a local PostgreSQL DB."""
manager = checkpoint.ChatStreamManager(
@@ -101,7 +116,8 @@ def test_persist_postgresql_local_db():
assert result is True
def test_persist_postgresql_called_with_aggregated_chunks(monkeypatch):
@pytest.mark.skipif(not has_real_db_connection(), reason="PostgreSQL Server is not available")
def test_persist_postgresql_called_with_aggregated_chunks():
"""On 'stop', aggregated chunks should be passed to PostgreSQL persist method."""
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
@@ -134,24 +150,42 @@ def test_persist_not_attempted_when_saver_disabled():
def test_persist_mongodb_local_db():
"""Ensure that the ChatStreamManager can persist to a local MongoDB."""
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
assert manager.mongo_db is not None
assert manager.postgres_conn is None
# Simulate a message to persist
thread_id = "test_thread"
messages = ["This is a test message."]
result = manager._persist_to_mongodb(thread_id, messages)
assert result is True
# Simulate a message with existing thread
result = manager._persist_to_mongodb(thread_id, ["Another message."])
assert result is True
"""Ensure that the ChatStreamManager can persist to a mocked MongoDB."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
assert manager.mongo_db is not None
assert manager.postgres_conn is None
# Simulate a message to persist
thread_id = "test_thread"
messages = ["This is a test message."]
result = manager._persist_to_mongodb(thread_id, messages)
assert result is True
# Verify data was persisted in mock
collection = manager.mongo_db.chat_streams
doc = collection.find_one({"thread_id": thread_id})
assert doc is not None
assert doc["messages"] == messages
# Simulate a message with existing thread
result = manager._persist_to_mongodb(thread_id, ["Another message."])
assert result is True
# Verify update worked
doc = collection.find_one({"thread_id": thread_id})
assert doc["messages"] == ["Another message."]
def test_persist_mongodb_called_with_aggregated_chunks(monkeypatch):
@pytest.mark.skipif(not has_real_db_connection(), reason="MongoDB server is not available")
def test_persist_mongodb_called_with_aggregated_chunks():
"""On 'stop', aggregated chunks should be passed to MongoDB persist method."""
manager = checkpoint.ChatStreamManager(
@@ -205,25 +239,36 @@ def test_unsupported_db_uri_scheme():
def test_process_stream_with_interrupt_finish_reason():
"""Test that 'interrupt' finish_reason triggers persistence like 'stop'."""
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
# Add partial message
assert (
manager.process_stream_message(
"int_test", "Interrupted", finish_reason="partial"
# Add partial message
assert (
manager.process_stream_message(
"int_test", "Interrupted", finish_reason="partial"
)
is True
)
is True
)
# Interrupt should trigger persistence
assert (
manager.process_stream_message(
"int_test", " message", finish_reason="interrupt"
# Interrupt should trigger persistence
assert (
manager.process_stream_message(
"int_test", " message", finish_reason="interrupt"
)
is True
)
is True
)
# Verify persistence occurred
collection = manager.mongo_db.chat_streams
doc = collection.find_one({"thread_id": "int_test"})
assert doc is not None
assert doc["messages"] == ["Interrupted", " message"]
def test_postgresql_connection_failure(monkeypatch):
@@ -348,64 +393,39 @@ def test_multiple_threads_isolation():
assert "msg2" not in thread1_values
def test_mongodb_insert_and_update_paths(monkeypatch):
"""Exercise MongoDB insert, update, and exception branches."""
def test_mongodb_insert_and_update_paths():
"""Exercise MongoDB insert, update, and exception branches using mongomock."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
# Fake Mongo classes
class FakeUpdateResult:
def __init__(self, modified_count):
self.modified_count = modified_count
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
class FakeInsertResult:
def __init__(self, inserted_id):
self.inserted_id = inserted_id
# Insert success (new thread)
assert manager._persist_to_mongodb("th1", ["message1"]) is True
# Verify insert worked
collection = manager.mongo_db.chat_streams
doc = collection.find_one({"thread_id": "th1"})
assert doc is not None
assert doc["messages"] == ["message1"]
class FakeCollection:
def __init__(self, mode="insert_success"):
self.mode = mode
# Update success (existing thread)
assert manager._persist_to_mongodb("th1", ["message2"]) is True
# Verify update worked
doc = collection.find_one({"thread_id": "th1"})
assert doc["messages"] == ["message2"]
def find_one(self, query):
if self.mode.startswith("insert"):
return None
return {"thread_id": query["thread_id"]}
def update_one(self, q, s):
if self.mode == "update_success":
return FakeUpdateResult(1)
return FakeUpdateResult(0)
def insert_one(self, doc):
if self.mode == "insert_success":
return FakeInsertResult("ok")
if self.mode == "insert_none":
return FakeInsertResult(None)
raise RuntimeError("boom")
class FakeMongoDB:
def __init__(self, mode):
self.chat_streams = FakeCollection(mode)
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
# Insert success
manager.mongo_db = FakeMongoDB("insert_success")
assert manager._persist_to_mongodb("th1", ["a"]) is True
# Insert returns None id => False
manager.mongo_db = FakeMongoDB("insert_none")
assert manager._persist_to_mongodb("th2", ["a"]) is False
# Insert raises => False
manager.mongo_db = FakeMongoDB("insert_raise")
assert manager._persist_to_mongodb("th3", ["a"]) is False
# Update success
manager.mongo_db = FakeMongoDB("update_success")
assert manager._persist_to_mongodb("th4", ["a"]) is True
# Update modifies 0 => False
manager.mongo_db = FakeMongoDB("update_zero")
assert manager._persist_to_mongodb("th5", ["a"]) is False
# Test error case by mocking collection methods
original_find_one = collection.find_one
collection.find_one = MagicMock(side_effect=RuntimeError("Database error"))
assert manager._persist_to_mongodb("th2", ["message"]) is False
# Restore original method
collection.find_one = original_find_one
def test_postgresql_insert_update_and_error_paths():
@@ -570,38 +590,23 @@ def test_context_manager_calls_close(monkeypatch):
def test_init_mongodb_success_and_failure(monkeypatch):
"""MongoDB init should succeed with a valid client and fail gracefully otherwise."""
class FakeAdmin:
def command(self, name):
assert name == "ping"
class DummyDB:
pass
class FakeClient:
def __init__(self, uri):
self.uri = uri
self.admin = FakeAdmin()
self.checkpointing_db = DummyDB()
def close(self):
pass
# Success path
monkeypatch.setattr(checkpoint, "MongoClient", lambda uri: FakeClient(uri))
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
assert manager.mongo_db is not None
"""MongoDB init should succeed with mongomock and fail gracefully with errors."""
# Success path with mongomock
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
assert manager.mongo_db is not None
# Failure path
class Boom:
def __init__(self, uri):
raise RuntimeError("fail connect")
monkeypatch.setattr(checkpoint, "MongoClient", Boom)
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
# Should have no mongo_db set on failure
assert getattr(manager, "mongo_db", None) is None
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
mock_mongo_client.side_effect = RuntimeError("Connection failed")
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
# Should have no mongo_db set on failure
assert getattr(manager, "mongo_db", None) is None
def test_init_postgresql_calls_connect_and_create_table(monkeypatch):