feat: support dify in rag module (#550)

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
Chayton Bai
2025-09-16 20:30:45 +08:00
committed by GitHub
parent 5085bf8ee9
commit 7694bb5d72
19 changed files with 407 additions and 87 deletions

View File

@@ -1,81 +1,85 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import pytest
import tempfile
import shutil
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional
from unittest.mock import MagicMock, patch
from typing import Dict, Any, Optional
import psycopg
import pytest
class PostgreSQLMockInstance:
"""Utility class for managing PostgreSQL mock instances."""
def __init__(self, database_name: str = "test_db"):
    """Create a mock instance for the database called ``database_name``.

    No connection is created here: the connection and temp-dir slots start
    out as ``None`` and the mock data store is populated by
    ``_setup_mock_data``.
    """
    # Slots that get filled in later start out empty.
    self.mock_connection: Optional[MagicMock] = None
    self.temp_dir: Optional[Path] = None
    self.database_name = database_name
    # Placeholder value; replaced by the _setup_mock_data() call below.
    self.mock_data: Dict[str, Any] = {}
    self._setup_mock_data()
def _setup_mock_data(self):
    """Replace ``self.mock_data`` with a fresh, empty mock data store.

    The store holds three entries:

    * ``"chat_streams"`` -- dict mapping thread_id to its stored record
    * ``"table_exists"`` -- starts as ``False``
    * ``"connection_active"`` -- starts as ``True``
    """
    self.mock_data = {
        "chat_streams": {},  # thread_id -> record
        "table_exists": False,
        "connection_active": True,
    }
def connect(self) -> MagicMock:
    """Build a new mock PostgreSQL connection and return it.

    The connection is stored on ``self.mock_connection`` and then wired up
    by ``_setup_mock_methods`` before being handed back to the caller.
    """
    connection = MagicMock()
    self.mock_connection = connection
    self._setup_mock_methods()
    return self.mock_connection
def _setup_mock_methods(self):
    """Attach mock cursor/commit/rollback/close behaviour to the connection.

    Does nothing when ``self.mock_connection`` is falsy. Otherwise a single
    mock cursor is created, made usable as a context manager, routed to
    ``_mock_execute`` / ``_mock_fetchone``, and remembered on
    ``self._mock_cursor``.
    """
    connection = self.mock_connection
    if not connection:
        return

    cursor = MagicMock()

    # Context-manager protocol: entering yields the cursor itself and
    # exiting returns False.
    cursor.__enter__ = MagicMock(return_value=cursor)
    cursor.__exit__ = MagicMock(return_value=False)

    # Cursor operations are delegated to this instance's mock handlers.
    cursor.execute = MagicMock(side_effect=self._mock_execute)
    cursor.fetchone = MagicMock(side_effect=self._mock_fetchone)
    cursor.rowcount = 0

    # Every cursor() call on the connection returns the same mock cursor.
    connection.cursor = MagicMock(return_value=cursor)
    connection.commit = MagicMock()
    connection.rollback = MagicMock()
    connection.close = MagicMock()

    # Keep a handle on the cursor so the handlers can reach it.
    self._mock_cursor = cursor
def _mock_execute(self, sql: str, params=None):
"""Mock SQL execution."""
sql_upper = sql.upper().strip()
if "CREATE TABLE" in sql_upper:
self.mock_data["table_exists"] = True
self._mock_cursor.rowcount = 0
elif "SELECT" in sql_upper and "chat_streams" in sql_upper:
# Mock SELECT query
if params and len(params) > 0:
thread_id = params[0]
if thread_id in self.mock_data["chat_streams"]:
self._mock_cursor._fetch_result = self.mock_data["chat_streams"][thread_id]
self._mock_cursor._fetch_result = self.mock_data["chat_streams"][
thread_id
]
else:
self._mock_cursor._fetch_result = None
else:
self._mock_cursor._fetch_result = None
elif "UPDATE" in sql_upper and "chat_streams" in sql_upper:
# Mock UPDATE query
if params and len(params) >= 2:
@@ -84,12 +88,12 @@ class PostgreSQLMockInstance:
self.mock_data["chat_streams"][thread_id] = {
"id": thread_id,
"thread_id": thread_id,
"messages": messages
"messages": messages,
}
self._mock_cursor.rowcount = 1
else:
self._mock_cursor.rowcount = 0
elif "INSERT" in sql_upper and "chat_streams" in sql_upper:
# Mock INSERT query
if params and len(params) >= 2:
@@ -97,30 +101,30 @@ class PostgreSQLMockInstance:
self.mock_data["chat_streams"][thread_id] = {
"id": thread_id,
"thread_id": thread_id,
"messages": messages
"messages": messages,
}
self._mock_cursor.rowcount = 1
def _mock_fetchone(self):
    """Return the result currently stored on the mock cursor for ``fetchone``.

    Reads the ``_fetch_result`` attribute of ``self._mock_cursor`` and falls
    back to ``None`` when ``getattr`` cannot find that attribute.

    NOTE(review): ``_setup_mock_methods`` makes the cursor a ``MagicMock``,
    and ``MagicMock`` auto-creates attributes on first access, so the
    ``None`` fallback may never trigger before a result has been stored --
    confirm this is the intended behaviour.
    """
    return getattr(self._mock_cursor, "_fetch_result", None)
def disconnect(self):
"""Close the mock connection, if one exists, and rebuild the mock data store.

NOTE(review): indentation is not preserved in this view, so it cannot be told
here whether the ``_setup_mock_data()`` reset is unconditional or only runs
when a connection exists -- confirm against the original file.
"""
if self.mock_connection:
self.mock_connection.close()
self._setup_mock_data() # Reset data
def reset_data(self):
    """Discard everything in the mock data store and start from a clean one.

    Public wrapper around ``_setup_mock_data``.
    """
    self._setup_mock_data()
def get_table_count(self, table_name: str) -> int:
    """Return how many records the mock holds for ``table_name``.

    Only ``"chat_streams"`` is tracked; every other table name yields 0.
    """
    if table_name != "chat_streams":
        return 0
    return len(self.mock_data["chat_streams"])
def create_test_data(self, table_name: str, records: list):
"""Insert test data into a table."""
if table_name == "chat_streams":
@@ -129,6 +133,7 @@ class PostgreSQLMockInstance:
if thread_id:
self.mock_data["chat_streams"][thread_id] = record
@pytest.fixture
def mock_postgresql():
"""Create a PostgreSQL mock instance."""
@@ -137,6 +142,7 @@ def mock_postgresql():
yield instance
instance.disconnect()
@pytest.fixture
def clean_mock_postgresql():
"""Create a clean PostgreSQL mock instance that resets between tests."""
@@ -144,4 +150,4 @@ def clean_mock_postgresql():
instance.connect()
instance.reset_data()
yield instance
instance.disconnect()
instance.disconnect()

View File

@@ -2,15 +2,18 @@
# SPDX-License-Identifier: MIT
import os
import pytest
from unittest.mock import MagicMock, patch
import mongomock
from unittest.mock import patch, MagicMock
import src.graph.checkpoint as checkpoint
import pytest
from postgres_mock_utils import PostgreSQLMockInstance
import src.graph.checkpoint as checkpoint
# Connection URI passed to ChatStreamManager as db_uri in the PostgreSQL tests.
POSTGRES_URL = "postgresql://postgres:postgres@localhost:5432/checkpointing_db"
# Connection URI passed to ChatStreamManager as db_uri in the MongoDB tests.
MONGO_URL = "mongodb://admin:admin@localhost:27017/checkpointing_db?authSource=admin"
def has_real_db_connection():
# Check the environment if the MongoDB server is available
enabled = os.getenv("DB_TESTS_ENABLED", "false")
@@ -18,27 +21,28 @@ def has_real_db_connection():
return True
return False
def test_with_local_postgres_db():
    """Ensure the ChatStreamManager can be initialized with a local PostgreSQL DB.

    ``psycopg.connect`` is patched so the manager receives the mock
    connection built by ``PostgreSQLMockInstance``.
    """
    with patch("psycopg.connect") as mock_connect:
        # Setup mock PostgreSQL connection
        pg_mock = PostgreSQLMockInstance()
        mock_connect.return_value = pg_mock.connect()

        manager = checkpoint.ChatStreamManager(
            checkpoint_saver=True,
            db_uri=POSTGRES_URL,
        )

        # A PostgreSQL URI must select the PostgreSQL backend only.
        assert manager.postgres_conn is not None
        assert manager.mongo_client is None
def test_with_local_mongo_db():
"""Ensure the ChatStreamManager can be initialized with a local MongoDB."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
@@ -80,7 +84,7 @@ def test_process_stream_partial_buffer_postgres(monkeypatch):
def test_process_stream_partial_buffer_mongo():
"""Partial chunks should be buffered; Use mongomock instead of real MongoDB."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
@@ -96,7 +100,10 @@ def test_process_stream_partial_buffer_mongo():
values = [it.dict()["value"] for it in items]
assert "hello" in values
@pytest.mark.skipif(not has_real_db_connection(), reason="PostgreSQL Server is not available")
@pytest.mark.skipif(
not has_real_db_connection(), reason="PostgreSQL Server is not available"
)
def test_persist_postgresql_local_db():
"""Ensure that the ChatStreamManager can persist to a local PostgreSQL DB."""
manager = checkpoint.ChatStreamManager(
@@ -116,7 +123,9 @@ def test_persist_postgresql_local_db():
assert result is True
@pytest.mark.skipif(not has_real_db_connection(), reason="PostgreSQL Server is not available")
@pytest.mark.skipif(
not has_real_db_connection(), reason="PostgreSQL Server is not available"
)
def test_persist_postgresql_called_with_aggregated_chunks():
"""On 'stop', aggregated chunks should be passed to PostgreSQL persist method."""
manager = checkpoint.ChatStreamManager(
@@ -151,40 +160,42 @@ def test_persist_not_attempted_when_saver_disabled():
def test_persist_mongodb_local_db():
    """Ensure that the ChatStreamManager can persist to a mocked MongoDB.

    ``MongoClient`` is patched inside ``src.graph.checkpoint`` so the
    manager talks to a ``mongomock`` client. The test covers both the first
    write for a thread and a second write to the same thread.
    """
    with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
        # Setup mongomock
        mock_client = mongomock.MongoClient()
        mock_mongo_client.return_value = mock_client

        manager = checkpoint.ChatStreamManager(
            checkpoint_saver=True,
            db_uri=MONGO_URL,
        )
        assert manager.mongo_db is not None
        assert manager.postgres_conn is None

        # Simulate a message to persist
        thread_id = "test_thread"
        messages = ["This is a test message."]
        result = manager._persist_to_mongodb(thread_id, messages)
        assert result is True

        # Verify data was persisted in mock
        collection = manager.mongo_db.chat_streams
        doc = collection.find_one({"thread_id": thread_id})
        assert doc is not None
        assert doc["messages"] == messages

        # Simulate a message with existing thread
        result = manager._persist_to_mongodb(thread_id, ["Another message."])
        assert result is True

        # Verify update worked; check for None first so a missing document
        # fails as an assertion rather than a TypeError on subscripting.
        doc = collection.find_one({"thread_id": thread_id})
        assert doc is not None
        assert doc["messages"] == ["Another message."]
@pytest.mark.skipif(not has_real_db_connection(), reason="MongoDB server is not available")
@pytest.mark.skipif(
not has_real_db_connection(), reason="MongoDB server is not available"
)
def test_persist_mongodb_called_with_aggregated_chunks():
"""On 'stop', aggregated chunks should be passed to MongoDB persist method."""
@@ -239,11 +250,11 @@ def test_unsupported_db_uri_scheme():
def test_process_stream_with_interrupt_finish_reason():
"""Test that 'interrupt' finish_reason triggers persistence like 'stop'."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
@@ -263,7 +274,7 @@ def test_process_stream_with_interrupt_finish_reason():
)
is True
)
# Verify persistence occurred
collection = manager.mongo_db.chat_streams
doc = collection.find_one({"thread_id": "int_test"})
@@ -395,7 +406,7 @@ def test_multiple_threads_isolation():
def test_mongodb_insert_and_update_paths():
"""Exercise MongoDB insert, update, and exception branches using mongomock."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
@@ -404,7 +415,7 @@ def test_mongodb_insert_and_update_paths():
# Insert success (new thread)
assert manager._persist_to_mongodb("th1", ["message1"]) is True
# Verify insert worked
collection = manager.mongo_db.chat_streams
doc = collection.find_one({"thread_id": "th1"})
@@ -413,7 +424,7 @@ def test_mongodb_insert_and_update_paths():
# Update success (existing thread)
assert manager._persist_to_mongodb("th1", ["message2"]) is True
# Verify update worked
doc = collection.find_one({"thread_id": "th1"})
assert doc["messages"] == ["message2"]
@@ -421,9 +432,9 @@ def test_mongodb_insert_and_update_paths():
# Test error case by mocking collection methods
original_find_one = collection.find_one
collection.find_one = MagicMock(side_effect=RuntimeError("Database error"))
assert manager._persist_to_mongodb("th2", ["message"]) is False
# Restore original method
collection.find_one = original_find_one
@@ -591,19 +602,19 @@ def test_context_manager_calls_close(monkeypatch):
def test_init_mongodb_success_and_failure(monkeypatch):
    """MongoDB init should succeed with mongomock and fail gracefully with errors.

    ``monkeypatch`` is accepted but not used by the visible body; patching
    is done with ``unittest.mock.patch``.
    """
    # Success path with mongomock
    with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
        mock_client = mongomock.MongoClient()
        mock_mongo_client.return_value = mock_client

        manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
        assert manager.mongo_db is not None

    # Failure path
    with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
        mock_mongo_client.side_effect = RuntimeError("Connection failed")

        manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
        # Should have no mongo_db set on failure
        assert getattr(manager, "mongo_db", None) is None