mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-18 20:14:44 +08:00
feat: creating mogodb and postgres mock instance in checkpoint test (#561)
* fix: using mongomock for the checkpoint test * Add postgres mock setting to the unit test * Added utils file of postgres_mock_utils * fixed the runtime loading error of deerflow server
This commit is contained in:
@@ -1,17 +1,32 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import mongomock
|
||||
from unittest.mock import patch, MagicMock
|
||||
import src.graph.checkpoint as checkpoint
|
||||
from postgres_mock_utils import PostgreSQLMockInstance
|
||||
|
||||
POSTGRES_URL = "postgresql://postgres:postgres@localhost:5432/checkpointing_db"
|
||||
MONGO_URL = "mongodb://admin:admin@localhost:27017/checkpointing_db?authSource=admin"
|
||||
|
||||
def has_real_db_connection():
|
||||
# Check the environment if the MongoDB server is available
|
||||
enabled = os.getenv("DB_TESTS_ENABLED", "false")
|
||||
if enabled.lower() == "true":
|
||||
return True
|
||||
return False
|
||||
|
||||
def test_with_local_postgres_db():
|
||||
"""Ensure the ChatStreamManager can be initialized with a local PostgreSQL DB."""
|
||||
manager = checkpoint.ChatStreamManager(
|
||||
checkpoint_saver=True,
|
||||
db_uri=POSTGRES_URL,
|
||||
with patch('psycopg.connect') as mock_connect:
|
||||
# Setup mock PostgreSQL connection
|
||||
pg_mock = PostgreSQLMockInstance()
|
||||
mock_connect.return_value = pg_mock.connect()
|
||||
manager = checkpoint.ChatStreamManager(
|
||||
checkpoint_saver=True,
|
||||
db_uri=POSTGRES_URL,
|
||||
)
|
||||
assert manager.postgres_conn is not None
|
||||
assert manager.mongo_client is None
|
||||
@@ -19,12 +34,17 @@ def test_with_local_postgres_db():
|
||||
|
||||
def test_with_local_mongo_db():
|
||||
"""Ensure the ChatStreamManager can be initialized with a local MongoDB."""
|
||||
manager = checkpoint.ChatStreamManager(
|
||||
checkpoint_saver=True,
|
||||
db_uri=MONGO_URL,
|
||||
)
|
||||
assert manager.mongo_db is not None
|
||||
assert manager.postgres_conn is None
|
||||
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
|
||||
# Setup mongomock
|
||||
mock_client = mongomock.MongoClient()
|
||||
mock_mongo_client.return_value = mock_client
|
||||
|
||||
manager = checkpoint.ChatStreamManager(
|
||||
checkpoint_saver=True,
|
||||
db_uri=MONGO_URL,
|
||||
)
|
||||
assert manager.mongo_db is not None
|
||||
assert manager.postgres_conn is None
|
||||
|
||||
|
||||
def test_init_without_checkpoint_saver():
|
||||
@@ -58,30 +78,25 @@ def test_process_stream_partial_buffer_postgres(monkeypatch):
|
||||
assert "hello" in values
|
||||
|
||||
|
||||
def test_process_stream_partial_buffer_mongo(monkeypatch):
|
||||
"""Partial chunks should be buffered; Mongo init is stubbed to no-op."""
|
||||
|
||||
# Patch Mongo init to no-op for speed
|
||||
def _no_mongo(self):
|
||||
self.mongo_client = None
|
||||
self.mongo_db = None
|
||||
|
||||
monkeypatch.setattr(
|
||||
checkpoint.ChatStreamManager, "_init_mongodb", _no_mongo, raising=True
|
||||
)
|
||||
|
||||
manager = checkpoint.ChatStreamManager(
|
||||
checkpoint_saver=True,
|
||||
db_uri=MONGO_URL,
|
||||
)
|
||||
result = manager.process_stream_message("t2", "hello", finish_reason="partial")
|
||||
assert result is True
|
||||
# Verify the chunk was stored in the in-memory store
|
||||
items = manager.store.search(("messages", "t2"), limit=10)
|
||||
values = [it.dict()["value"] for it in items]
|
||||
assert "hello" in values
|
||||
def test_process_stream_partial_buffer_mongo():
|
||||
"""Partial chunks should be buffered; Use mongomock instead of real MongoDB."""
|
||||
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
|
||||
# Setup mongomock
|
||||
mock_client = mongomock.MongoClient()
|
||||
mock_mongo_client.return_value = mock_client
|
||||
|
||||
manager = checkpoint.ChatStreamManager(
|
||||
checkpoint_saver=True,
|
||||
db_uri=MONGO_URL,
|
||||
)
|
||||
result = manager.process_stream_message("t2", "hello", finish_reason="partial")
|
||||
assert result is True
|
||||
# Verify the chunk was stored in the in-memory store
|
||||
items = manager.store.search(("messages", "t2"), limit=10)
|
||||
values = [it.dict()["value"] for it in items]
|
||||
assert "hello" in values
|
||||
|
||||
@pytest.mark.skipif(not has_real_db_connection(), reason="PostgreSQL Server is not available")
|
||||
def test_persist_postgresql_local_db():
|
||||
"""Ensure that the ChatStreamManager can persist to a local PostgreSQL DB."""
|
||||
manager = checkpoint.ChatStreamManager(
|
||||
@@ -101,7 +116,8 @@ def test_persist_postgresql_local_db():
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_persist_postgresql_called_with_aggregated_chunks(monkeypatch):
|
||||
@pytest.mark.skipif(not has_real_db_connection(), reason="PostgreSQL Server is not available")
|
||||
def test_persist_postgresql_called_with_aggregated_chunks():
|
||||
"""On 'stop', aggregated chunks should be passed to PostgreSQL persist method."""
|
||||
manager = checkpoint.ChatStreamManager(
|
||||
checkpoint_saver=True,
|
||||
@@ -134,24 +150,42 @@ def test_persist_not_attempted_when_saver_disabled():
|
||||
|
||||
|
||||
def test_persist_mongodb_local_db():
|
||||
"""Ensure that the ChatStreamManager can persist to a local MongoDB."""
|
||||
manager = checkpoint.ChatStreamManager(
|
||||
checkpoint_saver=True,
|
||||
db_uri=MONGO_URL,
|
||||
)
|
||||
assert manager.mongo_db is not None
|
||||
assert manager.postgres_conn is None
|
||||
# Simulate a message to persist
|
||||
thread_id = "test_thread"
|
||||
messages = ["This is a test message."]
|
||||
result = manager._persist_to_mongodb(thread_id, messages)
|
||||
assert result is True
|
||||
# Simulate a message with existing thread
|
||||
result = manager._persist_to_mongodb(thread_id, ["Another message."])
|
||||
assert result is True
|
||||
"""Ensure that the ChatStreamManager can persist to a mocked MongoDB."""
|
||||
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
|
||||
# Setup mongomock
|
||||
mock_client = mongomock.MongoClient()
|
||||
mock_mongo_client.return_value = mock_client
|
||||
|
||||
manager = checkpoint.ChatStreamManager(
|
||||
checkpoint_saver=True,
|
||||
db_uri=MONGO_URL,
|
||||
)
|
||||
assert manager.mongo_db is not None
|
||||
assert manager.postgres_conn is None
|
||||
|
||||
# Simulate a message to persist
|
||||
thread_id = "test_thread"
|
||||
messages = ["This is a test message."]
|
||||
result = manager._persist_to_mongodb(thread_id, messages)
|
||||
assert result is True
|
||||
|
||||
# Verify data was persisted in mock
|
||||
collection = manager.mongo_db.chat_streams
|
||||
doc = collection.find_one({"thread_id": thread_id})
|
||||
assert doc is not None
|
||||
assert doc["messages"] == messages
|
||||
|
||||
# Simulate a message with existing thread
|
||||
result = manager._persist_to_mongodb(thread_id, ["Another message."])
|
||||
assert result is True
|
||||
|
||||
# Verify update worked
|
||||
doc = collection.find_one({"thread_id": thread_id})
|
||||
assert doc["messages"] == ["Another message."]
|
||||
|
||||
|
||||
def test_persist_mongodb_called_with_aggregated_chunks(monkeypatch):
|
||||
@pytest.mark.skipif(not has_real_db_connection(), reason="MongoDB server is not available")
|
||||
def test_persist_mongodb_called_with_aggregated_chunks():
|
||||
"""On 'stop', aggregated chunks should be passed to MongoDB persist method."""
|
||||
|
||||
manager = checkpoint.ChatStreamManager(
|
||||
@@ -205,25 +239,36 @@ def test_unsupported_db_uri_scheme():
|
||||
|
||||
def test_process_stream_with_interrupt_finish_reason():
|
||||
"""Test that 'interrupt' finish_reason triggers persistence like 'stop'."""
|
||||
manager = checkpoint.ChatStreamManager(
|
||||
checkpoint_saver=True,
|
||||
db_uri=MONGO_URL,
|
||||
)
|
||||
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
|
||||
# Setup mongomock
|
||||
mock_client = mongomock.MongoClient()
|
||||
mock_mongo_client.return_value = mock_client
|
||||
|
||||
manager = checkpoint.ChatStreamManager(
|
||||
checkpoint_saver=True,
|
||||
db_uri=MONGO_URL,
|
||||
)
|
||||
|
||||
# Add partial message
|
||||
assert (
|
||||
manager.process_stream_message(
|
||||
"int_test", "Interrupted", finish_reason="partial"
|
||||
# Add partial message
|
||||
assert (
|
||||
manager.process_stream_message(
|
||||
"int_test", "Interrupted", finish_reason="partial"
|
||||
)
|
||||
is True
|
||||
)
|
||||
is True
|
||||
)
|
||||
# Interrupt should trigger persistence
|
||||
assert (
|
||||
manager.process_stream_message(
|
||||
"int_test", " message", finish_reason="interrupt"
|
||||
# Interrupt should trigger persistence
|
||||
assert (
|
||||
manager.process_stream_message(
|
||||
"int_test", " message", finish_reason="interrupt"
|
||||
)
|
||||
is True
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
# Verify persistence occurred
|
||||
collection = manager.mongo_db.chat_streams
|
||||
doc = collection.find_one({"thread_id": "int_test"})
|
||||
assert doc is not None
|
||||
assert doc["messages"] == ["Interrupted", " message"]
|
||||
|
||||
|
||||
def test_postgresql_connection_failure(monkeypatch):
|
||||
@@ -348,64 +393,39 @@ def test_multiple_threads_isolation():
|
||||
assert "msg2" not in thread1_values
|
||||
|
||||
|
||||
def test_mongodb_insert_and_update_paths(monkeypatch):
|
||||
"""Exercise MongoDB insert, update, and exception branches."""
|
||||
def test_mongodb_insert_and_update_paths():
|
||||
"""Exercise MongoDB insert, update, and exception branches using mongomock."""
|
||||
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
|
||||
# Setup mongomock
|
||||
mock_client = mongomock.MongoClient()
|
||||
mock_mongo_client.return_value = mock_client
|
||||
|
||||
# Fake Mongo classes
|
||||
class FakeUpdateResult:
|
||||
def __init__(self, modified_count):
|
||||
self.modified_count = modified_count
|
||||
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
|
||||
|
||||
class FakeInsertResult:
|
||||
def __init__(self, inserted_id):
|
||||
self.inserted_id = inserted_id
|
||||
# Insert success (new thread)
|
||||
assert manager._persist_to_mongodb("th1", ["message1"]) is True
|
||||
|
||||
# Verify insert worked
|
||||
collection = manager.mongo_db.chat_streams
|
||||
doc = collection.find_one({"thread_id": "th1"})
|
||||
assert doc is not None
|
||||
assert doc["messages"] == ["message1"]
|
||||
|
||||
class FakeCollection:
|
||||
def __init__(self, mode="insert_success"):
|
||||
self.mode = mode
|
||||
# Update success (existing thread)
|
||||
assert manager._persist_to_mongodb("th1", ["message2"]) is True
|
||||
|
||||
# Verify update worked
|
||||
doc = collection.find_one({"thread_id": "th1"})
|
||||
assert doc["messages"] == ["message2"]
|
||||
|
||||
def find_one(self, query):
|
||||
if self.mode.startswith("insert"):
|
||||
return None
|
||||
return {"thread_id": query["thread_id"]}
|
||||
|
||||
def update_one(self, q, s):
|
||||
if self.mode == "update_success":
|
||||
return FakeUpdateResult(1)
|
||||
return FakeUpdateResult(0)
|
||||
|
||||
def insert_one(self, doc):
|
||||
if self.mode == "insert_success":
|
||||
return FakeInsertResult("ok")
|
||||
if self.mode == "insert_none":
|
||||
return FakeInsertResult(None)
|
||||
raise RuntimeError("boom")
|
||||
|
||||
class FakeMongoDB:
|
||||
def __init__(self, mode):
|
||||
self.chat_streams = FakeCollection(mode)
|
||||
|
||||
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
|
||||
|
||||
# Insert success
|
||||
manager.mongo_db = FakeMongoDB("insert_success")
|
||||
assert manager._persist_to_mongodb("th1", ["a"]) is True
|
||||
|
||||
# Insert returns None id => False
|
||||
manager.mongo_db = FakeMongoDB("insert_none")
|
||||
assert manager._persist_to_mongodb("th2", ["a"]) is False
|
||||
|
||||
# Insert raises => False
|
||||
manager.mongo_db = FakeMongoDB("insert_raise")
|
||||
assert manager._persist_to_mongodb("th3", ["a"]) is False
|
||||
|
||||
# Update success
|
||||
manager.mongo_db = FakeMongoDB("update_success")
|
||||
assert manager._persist_to_mongodb("th4", ["a"]) is True
|
||||
|
||||
# Update modifies 0 => False
|
||||
manager.mongo_db = FakeMongoDB("update_zero")
|
||||
assert manager._persist_to_mongodb("th5", ["a"]) is False
|
||||
# Test error case by mocking collection methods
|
||||
original_find_one = collection.find_one
|
||||
collection.find_one = MagicMock(side_effect=RuntimeError("Database error"))
|
||||
|
||||
assert manager._persist_to_mongodb("th2", ["message"]) is False
|
||||
|
||||
# Restore original method
|
||||
collection.find_one = original_find_one
|
||||
|
||||
|
||||
def test_postgresql_insert_update_and_error_paths():
|
||||
@@ -570,38 +590,23 @@ def test_context_manager_calls_close(monkeypatch):
|
||||
|
||||
|
||||
def test_init_mongodb_success_and_failure(monkeypatch):
|
||||
"""MongoDB init should succeed with a valid client and fail gracefully otherwise."""
|
||||
|
||||
class FakeAdmin:
|
||||
def command(self, name):
|
||||
assert name == "ping"
|
||||
|
||||
class DummyDB:
|
||||
pass
|
||||
|
||||
class FakeClient:
|
||||
def __init__(self, uri):
|
||||
self.uri = uri
|
||||
self.admin = FakeAdmin()
|
||||
self.checkpointing_db = DummyDB()
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
# Success path
|
||||
monkeypatch.setattr(checkpoint, "MongoClient", lambda uri: FakeClient(uri))
|
||||
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
|
||||
assert manager.mongo_db is not None
|
||||
"""MongoDB init should succeed with mongomock and fail gracefully with errors."""
|
||||
|
||||
# Success path with mongomock
|
||||
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
|
||||
mock_client = mongomock.MongoClient()
|
||||
mock_mongo_client.return_value = mock_client
|
||||
|
||||
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
|
||||
assert manager.mongo_db is not None
|
||||
|
||||
# Failure path
|
||||
class Boom:
|
||||
def __init__(self, uri):
|
||||
raise RuntimeError("fail connect")
|
||||
|
||||
monkeypatch.setattr(checkpoint, "MongoClient", Boom)
|
||||
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
|
||||
# Should have no mongo_db set on failure
|
||||
assert getattr(manager, "mongo_db", None) is None
|
||||
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
|
||||
mock_mongo_client.side_effect = RuntimeError("Connection failed")
|
||||
|
||||
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
|
||||
# Should have no mongo_db set on failure
|
||||
assert getattr(manager, "mongo_db", None) is None
|
||||
|
||||
|
||||
def test_init_postgresql_calls_connect_and_create_table(monkeypatch):
|
||||
|
||||
Reference in New Issue
Block a user