feat: support dify in rag module (#550)

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
Chayton Bai
2025-09-16 20:30:45 +08:00
committed by GitHub
parent 5085bf8ee9
commit 7694bb5d72
19 changed files with 407 additions and 87 deletions

View File

@@ -2,15 +2,18 @@
# SPDX-License-Identifier: MIT
import os
import pytest
from unittest.mock import MagicMock, patch
import mongomock
from unittest.mock import patch, MagicMock
import src.graph.checkpoint as checkpoint
import pytest
from postgres_mock_utils import PostgreSQLMockInstance
import src.graph.checkpoint as checkpoint
POSTGRES_URL = "postgresql://postgres:postgres@localhost:5432/checkpointing_db"
MONGO_URL = "mongodb://admin:admin@localhost:27017/checkpointing_db?authSource=admin"
def has_real_db_connection():
# Check the environment if the MongoDB server is available
enabled = os.getenv("DB_TESTS_ENABLED", "false")
@@ -18,27 +21,28 @@ def has_real_db_connection():
return True
return False
def test_with_local_postgres_db():
"""Ensure the ChatStreamManager can be initialized with a local PostgreSQL DB."""
with patch('psycopg.connect') as mock_connect:
with patch("psycopg.connect") as mock_connect:
# Setup mock PostgreSQL connection
pg_mock = PostgreSQLMockInstance()
mock_connect.return_value = pg_mock.connect()
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=POSTGRES_URL,
)
)
assert manager.postgres_conn is not None
assert manager.mongo_client is None
def test_with_local_mongo_db():
"""Ensure the ChatStreamManager can be initialized with a local MongoDB."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
@@ -80,7 +84,7 @@ def test_process_stream_partial_buffer_postgres(monkeypatch):
def test_process_stream_partial_buffer_mongo():
"""Partial chunks should be buffered; Use mongomock instead of real MongoDB."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
@@ -96,7 +100,10 @@ def test_process_stream_partial_buffer_mongo():
values = [it.dict()["value"] for it in items]
assert "hello" in values
@pytest.mark.skipif(not has_real_db_connection(), reason="PostgreSQL Server is not available")
@pytest.mark.skipif(
not has_real_db_connection(), reason="PostgreSQL Server is not available"
)
def test_persist_postgresql_local_db():
"""Ensure that the ChatStreamManager can persist to a local PostgreSQL DB."""
manager = checkpoint.ChatStreamManager(
@@ -116,7 +123,9 @@ def test_persist_postgresql_local_db():
assert result is True
@pytest.mark.skipif(not has_real_db_connection(), reason="PostgreSQL Server is not available")
@pytest.mark.skipif(
not has_real_db_connection(), reason="PostgreSQL Server is not available"
)
def test_persist_postgresql_called_with_aggregated_chunks():
"""On 'stop', aggregated chunks should be passed to PostgreSQL persist method."""
manager = checkpoint.ChatStreamManager(
@@ -151,40 +160,42 @@ def test_persist_not_attempted_when_saver_disabled():
def test_persist_mongodb_local_db():
"""Ensure that the ChatStreamManager can persist to a mocked MongoDB."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
assert manager.mongo_db is not None
assert manager.postgres_conn is None
# Simulate a message to persist
thread_id = "test_thread"
messages = ["This is a test message."]
result = manager._persist_to_mongodb(thread_id, messages)
assert result is True
# Verify data was persisted in mock
collection = manager.mongo_db.chat_streams
doc = collection.find_one({"thread_id": thread_id})
assert doc is not None
assert doc["messages"] == messages
# Simulate a message with existing thread
result = manager._persist_to_mongodb(thread_id, ["Another message."])
assert result is True
# Verify update worked
doc = collection.find_one({"thread_id": thread_id})
assert doc["messages"] == ["Another message."]
@pytest.mark.skipif(not has_real_db_connection(), reason="MongoDB server is not available")
@pytest.mark.skipif(
not has_real_db_connection(), reason="MongoDB server is not available"
)
def test_persist_mongodb_called_with_aggregated_chunks():
"""On 'stop', aggregated chunks should be passed to MongoDB persist method."""
@@ -239,11 +250,11 @@ def test_unsupported_db_uri_scheme():
def test_process_stream_with_interrupt_finish_reason():
"""Test that 'interrupt' finish_reason triggers persistence like 'stop'."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
@@ -263,7 +274,7 @@ def test_process_stream_with_interrupt_finish_reason():
)
is True
)
# Verify persistence occurred
collection = manager.mongo_db.chat_streams
doc = collection.find_one({"thread_id": "int_test"})
@@ -395,7 +406,7 @@ def test_multiple_threads_isolation():
def test_mongodb_insert_and_update_paths():
"""Exercise MongoDB insert, update, and exception branches using mongomock."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
@@ -404,7 +415,7 @@ def test_mongodb_insert_and_update_paths():
# Insert success (new thread)
assert manager._persist_to_mongodb("th1", ["message1"]) is True
# Verify insert worked
collection = manager.mongo_db.chat_streams
doc = collection.find_one({"thread_id": "th1"})
@@ -413,7 +424,7 @@ def test_mongodb_insert_and_update_paths():
# Update success (existing thread)
assert manager._persist_to_mongodb("th1", ["message2"]) is True
# Verify update worked
doc = collection.find_one({"thread_id": "th1"})
assert doc["messages"] == ["message2"]
@@ -421,9 +432,9 @@ def test_mongodb_insert_and_update_paths():
# Test error case by mocking collection methods
original_find_one = collection.find_one
collection.find_one = MagicMock(side_effect=RuntimeError("Database error"))
assert manager._persist_to_mongodb("th2", ["message"]) is False
# Restore original method
collection.find_one = original_find_one
@@ -591,19 +602,19 @@ def test_context_manager_calls_close(monkeypatch):
def test_init_mongodb_success_and_failure(monkeypatch):
"""MongoDB init should succeed with mongomock and fail gracefully with errors."""
# Success path with mongomock
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
assert manager.mongo_db is not None
# Failure path
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
mock_mongo_client.side_effect = RuntimeError("Connection failed")
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
# Should have no mongo_db set on failure
assert getattr(manager, "mongo_db", None) is None