feat: support dify in rag module (#550)

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
Chayton Bai
2025-09-16 20:30:45 +08:00
committed by GitHub
parent 5085bf8ee9
commit 7694bb5d72
19 changed files with 407 additions and 87 deletions

View File

@@ -1,81 +1,85 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import pytest
import tempfile
import shutil
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional
from unittest.mock import MagicMock, patch
from typing import Dict, Any, Optional
import psycopg
import pytest
class PostgreSQLMockInstance:
"""Utility class for managing PostgreSQL mock instances."""
def __init__(self, database_name: str = "test_db"):
    """Store the database name and start with no connection and default mock data.

    Args:
        database_name: Name kept on the instance as ``database_name``.
    """
    # Placeholder only; _setup_mock_data() below replaces it with the defaults.
    self.mock_data: Dict[str, Any] = {}
    # No connection or temporary directory exists at construction time.
    self.mock_connection: Optional[MagicMock] = None
    self.temp_dir: Optional[Path] = None
    self.database_name = database_name
    self._setup_mock_data()
def _setup_mock_data(self):
    """Reset ``self.mock_data`` to its default state.

    The defaults are an empty ``chat_streams`` mapping, ``table_exists`` set
    to ``False`` and ``connection_active`` set to ``True``. A new dict is
    built on every call, so earlier records are discarded.
    """
    # Each key is listed exactly once; a repeated "connection_active" entry
    # without a separating comma would make this literal a syntax error.
    self.mock_data = {
        "chat_streams": {},  # thread_id -> record
        "table_exists": False,
        "connection_active": True,
    }
def connect(self) -> MagicMock:
    """Build a new ``MagicMock`` connection, wire up its mocked methods and return it."""
    connection = MagicMock()
    self.mock_connection = connection
    # Attaches the cursor/commit/rollback/close mocks to the new connection.
    self._setup_mock_methods()
    return self.mock_connection
def _setup_mock_methods(self):
    """Attach a mocked cursor and transaction methods to the current connection.

    Does nothing when ``self.mock_connection`` is unset. Otherwise the
    connection receives ``cursor``, ``commit``, ``rollback`` and ``close``
    mocks, and the cursor is also stored on ``self._mock_cursor``.
    """
    connection = self.mock_connection
    if not connection:
        return

    # The cursor is its own context manager: entering yields the cursor
    # itself, and __exit__ returns False so exceptions are not suppressed.
    cursor = MagicMock()
    cursor.__enter__ = MagicMock(return_value=cursor)
    cursor.__exit__ = MagicMock(return_value=False)

    # execute/fetchone are routed to this instance's mock implementations.
    cursor.execute = MagicMock(side_effect=self._mock_execute)
    cursor.fetchone = MagicMock(side_effect=self._mock_fetchone)
    cursor.rowcount = 0
    # MagicMock auto-creates any attribute on first access, so without an
    # explicit value a fetchone() issued before any SELECT would return a
    # truthy MagicMock instead of "no row".
    cursor._fetch_result = None

    # Connection-level operations.
    connection.cursor = MagicMock(return_value=cursor)
    connection.commit = MagicMock()
    connection.rollback = MagicMock()
    connection.close = MagicMock()

    # Keep a reference so other methods of this instance can reach the cursor.
    self._mock_cursor = cursor
def _mock_execute(self, sql: str, params=None):
"""Mock SQL execution."""
sql_upper = sql.upper().strip()
if "CREATE TABLE" in sql_upper:
self.mock_data["table_exists"] = True
self._mock_cursor.rowcount = 0
elif "SELECT" in sql_upper and "chat_streams" in sql_upper:
# Mock SELECT query
if params and len(params) > 0:
thread_id = params[0]
if thread_id in self.mock_data["chat_streams"]:
self._mock_cursor._fetch_result = self.mock_data["chat_streams"][thread_id]
self._mock_cursor._fetch_result = self.mock_data["chat_streams"][
thread_id
]
else:
self._mock_cursor._fetch_result = None
else:
self._mock_cursor._fetch_result = None
elif "UPDATE" in sql_upper and "chat_streams" in sql_upper:
# Mock UPDATE query
if params and len(params) >= 2:
@@ -84,12 +88,12 @@ class PostgreSQLMockInstance:
self.mock_data["chat_streams"][thread_id] = {
"id": thread_id,
"thread_id": thread_id,
"messages": messages
"messages": messages,
}
self._mock_cursor.rowcount = 1
else:
self._mock_cursor.rowcount = 0
elif "INSERT" in sql_upper and "chat_streams" in sql_upper:
# Mock INSERT query
if params and len(params) >= 2:
@@ -97,30 +101,30 @@ class PostgreSQLMockInstance:
self.mock_data["chat_streams"][thread_id] = {
"id": thread_id,
"thread_id": thread_id,
"messages": messages
"messages": messages,
}
self._mock_cursor.rowcount = 1
def _mock_fetchone(self):
    """Return the row currently staged on the mock cursor.

    Returns:
        The value of the cursor's ``_fetch_result`` attribute, or ``None``
        when the cursor has no such attribute.
    """
    # A single return: a second, identical return after this one would be
    # unreachable.
    return getattr(self._mock_cursor, "_fetch_result", None)
def disconnect(self):
    """Close the mock connection, if there is one, then restore the default mock data."""
    connection = self.mock_connection
    if connection:
        connection.close()
    # Reset data
    self._setup_mock_data()
def reset_data(self):
    """Discard all stored mock records and flags by rebuilding the defaults."""
    self._setup_mock_data()
def get_table_count(self, table_name: str) -> int:
    """Return how many records the mock holds for ``table_name``.

    Only ``"chat_streams"`` is tracked; any other table name yields ``0``.
    """
    if table_name != "chat_streams":
        return 0
    return len(self.mock_data["chat_streams"])
def create_test_data(self, table_name: str, records: list):
"""Insert test data into a table."""
if table_name == "chat_streams":
@@ -129,6 +133,7 @@ class PostgreSQLMockInstance:
if thread_id:
self.mock_data["chat_streams"][thread_id] = record
@pytest.fixture
def mock_postgresql():
"""Create a PostgreSQL mock instance."""
@@ -137,6 +142,7 @@ def mock_postgresql():
yield instance
instance.disconnect()
@pytest.fixture
def clean_mock_postgresql():
"""Create a clean PostgreSQL mock instance that resets between tests."""
@@ -144,4 +150,4 @@ def clean_mock_postgresql():
instance.connect()
instance.reset_data()
yield instance
instance.disconnect()
instance.disconnect()

View File

@@ -2,15 +2,18 @@
# SPDX-License-Identifier: MIT
import os
import pytest
from unittest.mock import MagicMock, patch
import mongomock
from unittest.mock import patch, MagicMock
import src.graph.checkpoint as checkpoint
import pytest
from postgres_mock_utils import PostgreSQLMockInstance
import src.graph.checkpoint as checkpoint
POSTGRES_URL = "postgresql://postgres:postgres@localhost:5432/checkpointing_db"
MONGO_URL = "mongodb://admin:admin@localhost:27017/checkpointing_db?authSource=admin"
def has_real_db_connection():
# Check the environment if the MongoDB server is available
enabled = os.getenv("DB_TESTS_ENABLED", "false")
@@ -18,27 +21,28 @@ def has_real_db_connection():
return True
return False
def test_with_local_postgres_db():
    """Ensure the ChatStreamManager can be initialized with a local PostgreSQL DB.

    ``psycopg.connect`` is patched to return a ``PostgreSQLMockInstance``
    connection, so no real server is contacted.
    """
    # One patch context only: repeating the same `with patch(...)` line would
    # nest a redundant second patch of the same target.
    with patch("psycopg.connect") as mock_connect:
        # Setup mock PostgreSQL connection
        pg_mock = PostgreSQLMockInstance()
        mock_connect.return_value = pg_mock.connect()

        manager = checkpoint.ChatStreamManager(
            checkpoint_saver=True,
            db_uri=POSTGRES_URL,
        )
        assert manager.postgres_conn is not None
        assert manager.mongo_client is None
def test_with_local_mongo_db():
"""Ensure the ChatStreamManager can be initialized with a local MongoDB."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
@@ -80,7 +84,7 @@ def test_process_stream_partial_buffer_postgres(monkeypatch):
def test_process_stream_partial_buffer_mongo():
"""Partial chunks should be buffered; Use mongomock instead of real MongoDB."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
@@ -96,7 +100,10 @@ def test_process_stream_partial_buffer_mongo():
values = [it.dict()["value"] for it in items]
assert "hello" in values
@pytest.mark.skipif(not has_real_db_connection(), reason="PostgreSQL Server is not available")
@pytest.mark.skipif(
not has_real_db_connection(), reason="PostgreSQL Server is not available"
)
def test_persist_postgresql_local_db():
"""Ensure that the ChatStreamManager can persist to a local PostgreSQL DB."""
manager = checkpoint.ChatStreamManager(
@@ -116,7 +123,9 @@ def test_persist_postgresql_local_db():
assert result is True
@pytest.mark.skipif(not has_real_db_connection(), reason="PostgreSQL Server is not available")
@pytest.mark.skipif(
not has_real_db_connection(), reason="PostgreSQL Server is not available"
)
def test_persist_postgresql_called_with_aggregated_chunks():
"""On 'stop', aggregated chunks should be passed to PostgreSQL persist method."""
manager = checkpoint.ChatStreamManager(
@@ -151,40 +160,42 @@ def test_persist_not_attempted_when_saver_disabled():
def test_persist_mongodb_local_db():
    """Ensure that the ChatStreamManager can persist to a mocked MongoDB.

    Covers both the first write for a thread and a second write to the same
    thread, checking the stored document after each.
    """
    # One patch context only: repeating the same `with patch(...)` line would
    # nest a redundant second patch of the same target.
    with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
        # Setup mongomock
        mock_client = mongomock.MongoClient()
        mock_mongo_client.return_value = mock_client

        manager = checkpoint.ChatStreamManager(
            checkpoint_saver=True,
            db_uri=MONGO_URL,
        )
        assert manager.mongo_db is not None
        assert manager.postgres_conn is None

        # Simulate a message to persist
        thread_id = "test_thread"
        messages = ["This is a test message."]
        result = manager._persist_to_mongodb(thread_id, messages)
        assert result is True

        # Verify data was persisted in mock
        collection = manager.mongo_db.chat_streams
        doc = collection.find_one({"thread_id": thread_id})
        assert doc is not None
        assert doc["messages"] == messages

        # Simulate a message with existing thread
        result = manager._persist_to_mongodb(thread_id, ["Another message."])
        assert result is True

        # Verify update worked
        doc = collection.find_one({"thread_id": thread_id})
        assert doc["messages"] == ["Another message."]
@pytest.mark.skipif(not has_real_db_connection(), reason="MongoDB server is not available")
@pytest.mark.skipif(
not has_real_db_connection(), reason="MongoDB server is not available"
)
def test_persist_mongodb_called_with_aggregated_chunks():
"""On 'stop', aggregated chunks should be passed to MongoDB persist method."""
@@ -239,11 +250,11 @@ def test_unsupported_db_uri_scheme():
def test_process_stream_with_interrupt_finish_reason():
"""Test that 'interrupt' finish_reason triggers persistence like 'stop'."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
@@ -263,7 +274,7 @@ def test_process_stream_with_interrupt_finish_reason():
)
is True
)
# Verify persistence occurred
collection = manager.mongo_db.chat_streams
doc = collection.find_one({"thread_id": "int_test"})
@@ -395,7 +406,7 @@ def test_multiple_threads_isolation():
def test_mongodb_insert_and_update_paths():
"""Exercise MongoDB insert, update, and exception branches using mongomock."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
@@ -404,7 +415,7 @@ def test_mongodb_insert_and_update_paths():
# Insert success (new thread)
assert manager._persist_to_mongodb("th1", ["message1"]) is True
# Verify insert worked
collection = manager.mongo_db.chat_streams
doc = collection.find_one({"thread_id": "th1"})
@@ -413,7 +424,7 @@ def test_mongodb_insert_and_update_paths():
# Update success (existing thread)
assert manager._persist_to_mongodb("th1", ["message2"]) is True
# Verify update worked
doc = collection.find_one({"thread_id": "th1"})
assert doc["messages"] == ["message2"]
@@ -421,9 +432,9 @@ def test_mongodb_insert_and_update_paths():
# Test error case by mocking collection methods
original_find_one = collection.find_one
collection.find_one = MagicMock(side_effect=RuntimeError("Database error"))
assert manager._persist_to_mongodb("th2", ["message"]) is False
# Restore original method
collection.find_one = original_find_one
@@ -591,19 +602,19 @@ def test_context_manager_calls_close(monkeypatch):
def test_init_mongodb_success_and_failure(monkeypatch):
    """MongoDB init should succeed with mongomock and fail gracefully with errors."""
    # Each path uses a single patch context: repeating the same
    # `with patch(...)` line would nest a redundant second patch.

    # Success path with mongomock
    with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
        mock_client = mongomock.MongoClient()
        mock_mongo_client.return_value = mock_client

        manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
        assert manager.mongo_db is not None

    # Failure path
    with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
        mock_mongo_client.side_effect = RuntimeError("Connection failed")

        manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
        # Should have no mongo_db set on failure
        assert getattr(manager, "mongo_db", None) is None

View File

@@ -2,23 +2,21 @@
# SPDX-License-Identifier: MIT
import pytest
from langchain_core.messages import (
AIMessageChunk,
ChatMessageChunk,
FunctionMessageChunk,
HumanMessageChunk,
SystemMessageChunk,
FunctionMessageChunk,
ToolMessageChunk,
)
from src.llms import llm as llm_module
from langchain_core.messages import ChatMessageChunk
from src.llms.providers import dashscope as dashscope_module
from src.llms.providers.dashscope import (
ChatDashscope,
_convert_delta_to_message_chunk,
_convert_chunk_to_generation_chunk,
_convert_delta_to_message_chunk,
)

154
tests/unit/rag/test_dify.py Normal file
View File

@@ -0,0 +1,154 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import MagicMock, patch
import pytest
from src.rag.dify import DifyProvider, parse_uri
# Dummy classes to mock dependencies
class DummyResource:
    """Minimal stand-in for a resource: a URI plus an optional title and description."""

    def __init__(self, uri, title="", description=""):
        self.uri, self.title, self.description = uri, title, description
class DummyChunk:
    """Minimal stand-in for a chunk: a piece of content and its similarity score."""

    def __init__(self, content, similarity):
        self.content, self.similarity = content, similarity
class DummyDocument:
    """Minimal stand-in for a document: an id, a title and a list of chunks."""

    def __init__(self, id, title, chunks=None):
        # A falsy ``chunks`` (None or an empty container) becomes a new empty list.
        self.chunks = chunks if chunks else []
        self.title = title
        self.id = id
# Patch imports in dify.py to use dummy classes
@pytest.fixture(autouse=True)
def patch_imports(monkeypatch):
    """Swap ``Resource``, ``Chunk`` and ``Document`` in ``src.rag.dify`` for the dummies.

    The replacements go through ``monkeypatch`` so the module's original
    attributes are restored after each test instead of leaking into other
    test modules that import ``src.rag.dify``.
    """
    import src.rag.dify as dify

    # raising=False keeps the behaviour of plain assignment: the attribute is
    # created if the module does not already define it.
    monkeypatch.setattr(dify, "Resource", DummyResource, raising=False)
    monkeypatch.setattr(dify, "Chunk", DummyChunk, raising=False)
    monkeypatch.setattr(dify, "Document", DummyDocument, raising=False)
    yield
def test_parse_uri_valid():
    """``rag://dataset/123#abc`` should parse to dataset id ``123`` and document id ``abc``."""
    dataset_id, document_id = parse_uri("rag://dataset/123#abc")
    assert (dataset_id, document_id) == ("123", "abc")
def test_parse_uri_invalid():
    """A URI using the ``http`` scheme should be rejected with ``ValueError``."""
    bad_uri = "http://dataset/123#abc"
    with pytest.raises(ValueError):
        parse_uri(bad_uri)
def test_init_env_vars(monkeypatch):
    """The provider should expose the URL and key found in the environment."""
    env = {"DIFY_API_URL": "http://api", "DIFY_API_KEY": "key"}
    for name, value in env.items():
        monkeypatch.setenv(name, value)

    provider = DifyProvider()

    assert provider.api_url == env["DIFY_API_URL"]
    assert provider.api_key == env["DIFY_API_KEY"]
def test_init_missing_env(monkeypatch):
    """Constructing the provider with either variable missing should raise ``ValueError``."""
    # (variable to remove, variable to keep, value for the kept variable)
    scenarios = (
        ("DIFY_API_URL", "DIFY_API_KEY", "key"),
        ("DIFY_API_KEY", "DIFY_API_URL", "http://api"),
    )
    for missing, present, value in scenarios:
        monkeypatch.delenv(missing, raising=False)
        monkeypatch.setenv(present, value)
        with pytest.raises(ValueError):
            DifyProvider()
@patch("src.rag.dify.requests.post")
def test_query_relevant_documents_success(mock_post, monkeypatch):
    """A 200 response carrying one record should yield one document holding one chunk."""
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.setenv("DIFY_API_KEY", "key")
    provider = DifyProvider()
    resource = DummyResource("rag://dataset/123#doc456")

    # Canned payload: a single scored segment belonging to document doc456.
    segment = {
        "content": "chunk text",
        "document": {"id": "doc456", "name": "Doc Title"},
    }
    response = MagicMock()
    response.status_code = 200
    response.json.return_value = {"records": [{"segment": segment, "score": 0.9}]}
    mock_post.return_value = response

    docs = provider.query_relevant_documents("query", [resource])

    assert len(docs) == 1
    doc = docs[0]
    assert doc.id == "doc456"
    assert doc.title == "Doc Title"
    assert len(doc.chunks) == 1
    chunk = doc.chunks[0]
    assert chunk.content == "chunk text"
    assert chunk.similarity == 0.9
@patch("src.rag.dify.requests.post")
def test_query_relevant_documents_error(mock_post, monkeypatch):
    """A 400 response from the POST call should make the query raise."""
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.setenv("DIFY_API_KEY", "key")
    provider = DifyProvider()
    resource = DummyResource("rag://dataset/123#doc456")

    # Failing HTTP response handed back by the patched requests.post.
    failure = MagicMock()
    failure.status_code = 400
    failure.text = "error"
    mock_post.return_value = failure

    with pytest.raises(Exception):
        provider.query_relevant_documents("query", [resource])
@patch("src.rag.dify.requests.get")
def test_list_resources_success(mock_get, monkeypatch):
    """A 200 response listing two datasets should yield two matching resources, in order."""
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.setenv("DIFY_API_KEY", "key")
    provider = DifyProvider()

    # Canned payload returned by the patched requests.get.
    datasets = [
        {"id": "123", "name": "Dataset1", "description": "desc1"},
        {"id": "456", "name": "Dataset2", "description": "desc2"},
    ]
    response = MagicMock()
    response.status_code = 200
    response.json.return_value = {"data": datasets}
    mock_get.return_value = response

    resources = provider.list_resources()

    assert len(resources) == 2
    first, second = resources[0], resources[1]
    assert first.uri == "rag://dataset/123"
    assert first.title == "Dataset1"
    assert first.description == "desc1"
    assert second.uri == "rag://dataset/456"
    assert second.title == "Dataset2"
    assert second.description == "desc2"
@patch("src.rag.dify.requests.get")
def test_list_resources_error(mock_get, monkeypatch):
    """A 500 response from the GET call should make the listing raise."""
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.setenv("DIFY_API_KEY", "key")
    provider = DifyProvider()

    # Failing HTTP response handed back by the patched requests.get.
    failure = MagicMock()
    failure.status_code = 500
    failure.text = "fail"
    mock_get.return_value = failure

    with pytest.raises(Exception):
        provider.list_resources()

View File

@@ -2,9 +2,11 @@
# SPDX-License-Identifier: MIT
from __future__ import annotations
from uuid import uuid4
from types import SimpleNamespace
from pathlib import Path
from types import SimpleNamespace
from uuid import uuid4
import pytest
import src.rag.milvus as milvus_mod
@@ -13,7 +15,6 @@ from src.rag.retriever import Resource
class DummyEmbedding:
def __init__(self, **kwargs):
self.kwargs = kwargs
@@ -369,9 +370,7 @@ def test_create_collection_lite(monkeypatch):
def list_collections(self): # noqa: D401
return [] # empty triggers creation
def create_collection(
self, collection_name, schema, index_params
): # noqa: D401
def create_collection(self, collection_name, schema, index_params): # noqa: D401
created["name"] = collection_name
created["schema"] = schema
created["index"] = index_params