mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-05-02 10:10:44 +08:00
feat: support dify in rag module (#550)
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -1,81 +1,85 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from unittest.mock import MagicMock, patch
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
import psycopg
|
||||
import pytest
|
||||
|
||||
|
||||
class PostgreSQLMockInstance:
|
||||
"""Utility class for managing PostgreSQL mock instances."""
|
||||
|
||||
|
||||
def __init__(self, database_name: str = "test_db"):
|
||||
self.database_name = database_name
|
||||
self.temp_dir: Optional[Path] = None
|
||||
self.mock_connection: Optional[MagicMock] = None
|
||||
self.mock_data: Dict[str, Any] = {}
|
||||
self._setup_mock_data()
|
||||
|
||||
|
||||
def _setup_mock_data(self):
|
||||
"""Initialize mock data storage."""
|
||||
self.mock_data = {
|
||||
"chat_streams": {}, # thread_id -> record
|
||||
"table_exists": False,
|
||||
"connection_active": True
|
||||
"connection_active": True,
|
||||
}
|
||||
|
||||
|
||||
def connect(self) -> MagicMock:
|
||||
"""Create a mock PostgreSQL connection."""
|
||||
self.mock_connection = MagicMock()
|
||||
self._setup_mock_methods()
|
||||
return self.mock_connection
|
||||
|
||||
|
||||
def _setup_mock_methods(self):
|
||||
"""Setup mock methods for PostgreSQL operations."""
|
||||
if not self.mock_connection:
|
||||
return
|
||||
|
||||
|
||||
# Mock cursor context manager
|
||||
mock_cursor = MagicMock()
|
||||
mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
|
||||
mock_cursor.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
|
||||
# Setup cursor operations
|
||||
mock_cursor.execute = MagicMock(side_effect=self._mock_execute)
|
||||
mock_cursor.fetchone = MagicMock(side_effect=self._mock_fetchone)
|
||||
mock_cursor.rowcount = 0
|
||||
|
||||
|
||||
# Setup connection operations
|
||||
self.mock_connection.cursor = MagicMock(return_value=mock_cursor)
|
||||
self.mock_connection.commit = MagicMock()
|
||||
self.mock_connection.rollback = MagicMock()
|
||||
self.mock_connection.close = MagicMock()
|
||||
|
||||
|
||||
# Store cursor for external access
|
||||
self._mock_cursor = mock_cursor
|
||||
|
||||
|
||||
def _mock_execute(self, sql: str, params=None):
|
||||
"""Mock SQL execution."""
|
||||
sql_upper = sql.upper().strip()
|
||||
|
||||
|
||||
if "CREATE TABLE" in sql_upper:
|
||||
self.mock_data["table_exists"] = True
|
||||
self._mock_cursor.rowcount = 0
|
||||
|
||||
|
||||
elif "SELECT" in sql_upper and "chat_streams" in sql_upper:
|
||||
# Mock SELECT query
|
||||
if params and len(params) > 0:
|
||||
thread_id = params[0]
|
||||
if thread_id in self.mock_data["chat_streams"]:
|
||||
self._mock_cursor._fetch_result = self.mock_data["chat_streams"][thread_id]
|
||||
self._mock_cursor._fetch_result = self.mock_data["chat_streams"][
|
||||
thread_id
|
||||
]
|
||||
else:
|
||||
self._mock_cursor._fetch_result = None
|
||||
else:
|
||||
self._mock_cursor._fetch_result = None
|
||||
|
||||
|
||||
elif "UPDATE" in sql_upper and "chat_streams" in sql_upper:
|
||||
# Mock UPDATE query
|
||||
if params and len(params) >= 2:
|
||||
@@ -84,12 +88,12 @@ class PostgreSQLMockInstance:
|
||||
self.mock_data["chat_streams"][thread_id] = {
|
||||
"id": thread_id,
|
||||
"thread_id": thread_id,
|
||||
"messages": messages
|
||||
"messages": messages,
|
||||
}
|
||||
self._mock_cursor.rowcount = 1
|
||||
else:
|
||||
self._mock_cursor.rowcount = 0
|
||||
|
||||
|
||||
elif "INSERT" in sql_upper and "chat_streams" in sql_upper:
|
||||
# Mock INSERT query
|
||||
if params and len(params) >= 2:
|
||||
@@ -97,30 +101,30 @@ class PostgreSQLMockInstance:
|
||||
self.mock_data["chat_streams"][thread_id] = {
|
||||
"id": thread_id,
|
||||
"thread_id": thread_id,
|
||||
"messages": messages
|
||||
"messages": messages,
|
||||
}
|
||||
self._mock_cursor.rowcount = 1
|
||||
|
||||
|
||||
def _mock_fetchone(self):
|
||||
"""Mock fetchone operation."""
|
||||
return getattr(self._mock_cursor, '_fetch_result', None)
|
||||
|
||||
return getattr(self._mock_cursor, "_fetch_result", None)
|
||||
|
||||
def disconnect(self):
|
||||
"""Cleanup mock connection."""
|
||||
if self.mock_connection:
|
||||
self.mock_connection.close()
|
||||
self._setup_mock_data() # Reset data
|
||||
|
||||
|
||||
def reset_data(self):
|
||||
"""Reset all mock data."""
|
||||
self._setup_mock_data()
|
||||
|
||||
|
||||
def get_table_count(self, table_name: str) -> int:
|
||||
"""Get record count in a table."""
|
||||
if table_name == "chat_streams":
|
||||
return len(self.mock_data["chat_streams"])
|
||||
return 0
|
||||
|
||||
|
||||
def create_test_data(self, table_name: str, records: list):
|
||||
"""Insert test data into a table."""
|
||||
if table_name == "chat_streams":
|
||||
@@ -129,6 +133,7 @@ class PostgreSQLMockInstance:
|
||||
if thread_id:
|
||||
self.mock_data["chat_streams"][thread_id] = record
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_postgresql():
|
||||
"""Create a PostgreSQL mock instance."""
|
||||
@@ -137,6 +142,7 @@ def mock_postgresql():
|
||||
yield instance
|
||||
instance.disconnect()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def clean_mock_postgresql():
|
||||
"""Create a clean PostgreSQL mock instance that resets between tests."""
|
||||
@@ -144,4 +150,4 @@ def clean_mock_postgresql():
|
||||
instance.connect()
|
||||
instance.reset_data()
|
||||
yield instance
|
||||
instance.disconnect()
|
||||
instance.disconnect()
|
||||
|
||||
@@ -2,15 +2,18 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import mongomock
|
||||
from unittest.mock import patch, MagicMock
|
||||
import src.graph.checkpoint as checkpoint
|
||||
import pytest
|
||||
from postgres_mock_utils import PostgreSQLMockInstance
|
||||
|
||||
import src.graph.checkpoint as checkpoint
|
||||
|
||||
POSTGRES_URL = "postgresql://postgres:postgres@localhost:5432/checkpointing_db"
|
||||
MONGO_URL = "mongodb://admin:admin@localhost:27017/checkpointing_db?authSource=admin"
|
||||
|
||||
|
||||
def has_real_db_connection():
|
||||
# Check the environment if the MongoDB server is available
|
||||
enabled = os.getenv("DB_TESTS_ENABLED", "false")
|
||||
@@ -18,27 +21,28 @@ def has_real_db_connection():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def test_with_local_postgres_db():
|
||||
"""Ensure the ChatStreamManager can be initialized with a local PostgreSQL DB."""
|
||||
with patch('psycopg.connect') as mock_connect:
|
||||
with patch("psycopg.connect") as mock_connect:
|
||||
# Setup mock PostgreSQL connection
|
||||
pg_mock = PostgreSQLMockInstance()
|
||||
mock_connect.return_value = pg_mock.connect()
|
||||
manager = checkpoint.ChatStreamManager(
|
||||
checkpoint_saver=True,
|
||||
db_uri=POSTGRES_URL,
|
||||
)
|
||||
)
|
||||
assert manager.postgres_conn is not None
|
||||
assert manager.mongo_client is None
|
||||
|
||||
|
||||
def test_with_local_mongo_db():
|
||||
"""Ensure the ChatStreamManager can be initialized with a local MongoDB."""
|
||||
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
|
||||
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
|
||||
# Setup mongomock
|
||||
mock_client = mongomock.MongoClient()
|
||||
mock_mongo_client.return_value = mock_client
|
||||
|
||||
|
||||
manager = checkpoint.ChatStreamManager(
|
||||
checkpoint_saver=True,
|
||||
db_uri=MONGO_URL,
|
||||
@@ -80,7 +84,7 @@ def test_process_stream_partial_buffer_postgres(monkeypatch):
|
||||
|
||||
def test_process_stream_partial_buffer_mongo():
|
||||
"""Partial chunks should be buffered; Use mongomock instead of real MongoDB."""
|
||||
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
|
||||
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
|
||||
# Setup mongomock
|
||||
mock_client = mongomock.MongoClient()
|
||||
mock_mongo_client.return_value = mock_client
|
||||
@@ -96,7 +100,10 @@ def test_process_stream_partial_buffer_mongo():
|
||||
values = [it.dict()["value"] for it in items]
|
||||
assert "hello" in values
|
||||
|
||||
@pytest.mark.skipif(not has_real_db_connection(), reason="PostgreSQL Server is not available")
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not has_real_db_connection(), reason="PostgreSQL Server is not available"
|
||||
)
|
||||
def test_persist_postgresql_local_db():
|
||||
"""Ensure that the ChatStreamManager can persist to a local PostgreSQL DB."""
|
||||
manager = checkpoint.ChatStreamManager(
|
||||
@@ -116,7 +123,9 @@ def test_persist_postgresql_local_db():
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.skipif(not has_real_db_connection(), reason="PostgreSQL Server is not available")
|
||||
@pytest.mark.skipif(
|
||||
not has_real_db_connection(), reason="PostgreSQL Server is not available"
|
||||
)
|
||||
def test_persist_postgresql_called_with_aggregated_chunks():
|
||||
"""On 'stop', aggregated chunks should be passed to PostgreSQL persist method."""
|
||||
manager = checkpoint.ChatStreamManager(
|
||||
@@ -151,40 +160,42 @@ def test_persist_not_attempted_when_saver_disabled():
|
||||
|
||||
def test_persist_mongodb_local_db():
|
||||
"""Ensure that the ChatStreamManager can persist to a mocked MongoDB."""
|
||||
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
|
||||
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
|
||||
# Setup mongomock
|
||||
mock_client = mongomock.MongoClient()
|
||||
mock_mongo_client.return_value = mock_client
|
||||
|
||||
|
||||
manager = checkpoint.ChatStreamManager(
|
||||
checkpoint_saver=True,
|
||||
db_uri=MONGO_URL,
|
||||
)
|
||||
assert manager.mongo_db is not None
|
||||
assert manager.postgres_conn is None
|
||||
|
||||
|
||||
# Simulate a message to persist
|
||||
thread_id = "test_thread"
|
||||
messages = ["This is a test message."]
|
||||
result = manager._persist_to_mongodb(thread_id, messages)
|
||||
assert result is True
|
||||
|
||||
|
||||
# Verify data was persisted in mock
|
||||
collection = manager.mongo_db.chat_streams
|
||||
doc = collection.find_one({"thread_id": thread_id})
|
||||
assert doc is not None
|
||||
assert doc["messages"] == messages
|
||||
|
||||
|
||||
# Simulate a message with existing thread
|
||||
result = manager._persist_to_mongodb(thread_id, ["Another message."])
|
||||
assert result is True
|
||||
|
||||
|
||||
# Verify update worked
|
||||
doc = collection.find_one({"thread_id": thread_id})
|
||||
assert doc["messages"] == ["Another message."]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not has_real_db_connection(), reason="MongoDB server is not available")
|
||||
@pytest.mark.skipif(
|
||||
not has_real_db_connection(), reason="MongoDB server is not available"
|
||||
)
|
||||
def test_persist_mongodb_called_with_aggregated_chunks():
|
||||
"""On 'stop', aggregated chunks should be passed to MongoDB persist method."""
|
||||
|
||||
@@ -239,11 +250,11 @@ def test_unsupported_db_uri_scheme():
|
||||
|
||||
def test_process_stream_with_interrupt_finish_reason():
|
||||
"""Test that 'interrupt' finish_reason triggers persistence like 'stop'."""
|
||||
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
|
||||
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
|
||||
# Setup mongomock
|
||||
mock_client = mongomock.MongoClient()
|
||||
mock_mongo_client.return_value = mock_client
|
||||
|
||||
|
||||
manager = checkpoint.ChatStreamManager(
|
||||
checkpoint_saver=True,
|
||||
db_uri=MONGO_URL,
|
||||
@@ -263,7 +274,7 @@ def test_process_stream_with_interrupt_finish_reason():
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
# Verify persistence occurred
|
||||
collection = manager.mongo_db.chat_streams
|
||||
doc = collection.find_one({"thread_id": "int_test"})
|
||||
@@ -395,7 +406,7 @@ def test_multiple_threads_isolation():
|
||||
|
||||
def test_mongodb_insert_and_update_paths():
|
||||
"""Exercise MongoDB insert, update, and exception branches using mongomock."""
|
||||
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
|
||||
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
|
||||
# Setup mongomock
|
||||
mock_client = mongomock.MongoClient()
|
||||
mock_mongo_client.return_value = mock_client
|
||||
@@ -404,7 +415,7 @@ def test_mongodb_insert_and_update_paths():
|
||||
|
||||
# Insert success (new thread)
|
||||
assert manager._persist_to_mongodb("th1", ["message1"]) is True
|
||||
|
||||
|
||||
# Verify insert worked
|
||||
collection = manager.mongo_db.chat_streams
|
||||
doc = collection.find_one({"thread_id": "th1"})
|
||||
@@ -413,7 +424,7 @@ def test_mongodb_insert_and_update_paths():
|
||||
|
||||
# Update success (existing thread)
|
||||
assert manager._persist_to_mongodb("th1", ["message2"]) is True
|
||||
|
||||
|
||||
# Verify update worked
|
||||
doc = collection.find_one({"thread_id": "th1"})
|
||||
assert doc["messages"] == ["message2"]
|
||||
@@ -421,9 +432,9 @@ def test_mongodb_insert_and_update_paths():
|
||||
# Test error case by mocking collection methods
|
||||
original_find_one = collection.find_one
|
||||
collection.find_one = MagicMock(side_effect=RuntimeError("Database error"))
|
||||
|
||||
|
||||
assert manager._persist_to_mongodb("th2", ["message"]) is False
|
||||
|
||||
|
||||
# Restore original method
|
||||
collection.find_one = original_find_one
|
||||
|
||||
@@ -591,19 +602,19 @@ def test_context_manager_calls_close(monkeypatch):
|
||||
|
||||
def test_init_mongodb_success_and_failure(monkeypatch):
|
||||
"""MongoDB init should succeed with mongomock and fail gracefully with errors."""
|
||||
|
||||
|
||||
# Success path with mongomock
|
||||
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
|
||||
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
|
||||
mock_client = mongomock.MongoClient()
|
||||
mock_mongo_client.return_value = mock_client
|
||||
|
||||
|
||||
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
|
||||
assert manager.mongo_db is not None
|
||||
|
||||
# Failure path
|
||||
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
|
||||
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
|
||||
mock_mongo_client.side_effect = RuntimeError("Connection failed")
|
||||
|
||||
|
||||
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
|
||||
# Should have no mongo_db set on failure
|
||||
assert getattr(manager, "mongo_db", None) is None
|
||||
|
||||
@@ -2,23 +2,21 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessageChunk,
|
||||
ChatMessageChunk,
|
||||
FunctionMessageChunk,
|
||||
HumanMessageChunk,
|
||||
SystemMessageChunk,
|
||||
FunctionMessageChunk,
|
||||
ToolMessageChunk,
|
||||
)
|
||||
|
||||
from src.llms import llm as llm_module
|
||||
from langchain_core.messages import ChatMessageChunk
|
||||
from src.llms.providers import dashscope as dashscope_module
|
||||
|
||||
from src.llms.providers.dashscope import (
|
||||
ChatDashscope,
|
||||
_convert_delta_to_message_chunk,
|
||||
_convert_chunk_to_generation_chunk,
|
||||
_convert_delta_to_message_chunk,
|
||||
)
|
||||
|
||||
|
||||
|
||||
154
tests/unit/rag/test_dify.py
Normal file
154
tests/unit/rag/test_dify.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.rag.dify import DifyProvider, parse_uri
|
||||
|
||||
|
||||
# Dummy classes to mock dependencies
|
||||
class DummyResource:
|
||||
def __init__(self, uri, title="", description=""):
|
||||
self.uri = uri
|
||||
self.title = title
|
||||
self.description = description
|
||||
|
||||
|
||||
class DummyChunk:
|
||||
def __init__(self, content, similarity):
|
||||
self.content = content
|
||||
self.similarity = similarity
|
||||
|
||||
|
||||
class DummyDocument:
|
||||
def __init__(self, id, title, chunks=None):
|
||||
self.id = id
|
||||
self.title = title
|
||||
self.chunks = chunks or []
|
||||
|
||||
|
||||
# Patch imports in dify.py to use dummy classes
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_imports(monkeypatch):
|
||||
import src.rag.dify as dify
|
||||
|
||||
dify.Resource = DummyResource
|
||||
dify.Chunk = DummyChunk
|
||||
dify.Document = DummyDocument
|
||||
yield
|
||||
|
||||
|
||||
def test_parse_uri_valid():
|
||||
uri = "rag://dataset/123#abc"
|
||||
dataset_id, document_id = parse_uri(uri)
|
||||
assert dataset_id == "123"
|
||||
assert document_id == "abc"
|
||||
|
||||
|
||||
def test_parse_uri_invalid():
|
||||
with pytest.raises(ValueError):
|
||||
parse_uri("http://dataset/123#abc")
|
||||
|
||||
|
||||
def test_init_env_vars(monkeypatch):
|
||||
monkeypatch.setenv("DIFY_API_URL", "http://api")
|
||||
monkeypatch.setenv("DIFY_API_KEY", "key")
|
||||
provider = DifyProvider()
|
||||
assert provider.api_url == "http://api"
|
||||
assert provider.api_key == "key"
|
||||
|
||||
|
||||
def test_init_missing_env(monkeypatch):
|
||||
monkeypatch.delenv("DIFY_API_URL", raising=False)
|
||||
monkeypatch.setenv("DIFY_API_KEY", "key")
|
||||
with pytest.raises(ValueError):
|
||||
DifyProvider()
|
||||
monkeypatch.setenv("DIFY_API_URL", "http://api")
|
||||
monkeypatch.delenv("DIFY_API_KEY", raising=False)
|
||||
with pytest.raises(ValueError):
|
||||
DifyProvider()
|
||||
|
||||
|
||||
@patch("src.rag.dify.requests.post")
|
||||
def test_query_relevant_documents_success(mock_post, monkeypatch):
|
||||
monkeypatch.setenv("DIFY_API_URL", "http://api")
|
||||
monkeypatch.setenv("DIFY_API_KEY", "key")
|
||||
provider = DifyProvider()
|
||||
resource = DummyResource("rag://dataset/123#doc456")
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"records": [
|
||||
{
|
||||
"segment": {
|
||||
"content": "chunk text",
|
||||
"document": {
|
||||
"id": "doc456",
|
||||
"name": "Doc Title",
|
||||
},
|
||||
},
|
||||
"score": 0.9,
|
||||
}
|
||||
]
|
||||
}
|
||||
mock_post.return_value = mock_response
|
||||
docs = provider.query_relevant_documents("query", [resource])
|
||||
assert len(docs) == 1
|
||||
assert docs[0].id == "doc456"
|
||||
assert docs[0].title == "Doc Title"
|
||||
assert len(docs[0].chunks) == 1
|
||||
assert docs[0].chunks[0].content == "chunk text"
|
||||
assert docs[0].chunks[0].similarity == 0.9
|
||||
|
||||
|
||||
@patch("src.rag.dify.requests.post")
|
||||
def test_query_relevant_documents_error(mock_post, monkeypatch):
|
||||
monkeypatch.setenv("DIFY_API_URL", "http://api")
|
||||
monkeypatch.setenv("DIFY_API_KEY", "key")
|
||||
provider = DifyProvider()
|
||||
resource = DummyResource("rag://dataset/123#doc456")
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 400
|
||||
mock_response.text = "error"
|
||||
mock_post.return_value = mock_response
|
||||
with pytest.raises(Exception):
|
||||
provider.query_relevant_documents("query", [resource])
|
||||
|
||||
|
||||
@patch("src.rag.dify.requests.get")
|
||||
def test_list_resources_success(mock_get, monkeypatch):
|
||||
monkeypatch.setenv("DIFY_API_URL", "http://api")
|
||||
monkeypatch.setenv("DIFY_API_KEY", "key")
|
||||
provider = DifyProvider()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"data": [
|
||||
{"id": "123", "name": "Dataset1", "description": "desc1"},
|
||||
{"id": "456", "name": "Dataset2", "description": "desc2"},
|
||||
]
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
resources = provider.list_resources()
|
||||
assert len(resources) == 2
|
||||
assert resources[0].uri == "rag://dataset/123"
|
||||
assert resources[0].title == "Dataset1"
|
||||
assert resources[0].description == "desc1"
|
||||
assert resources[1].uri == "rag://dataset/456"
|
||||
assert resources[1].title == "Dataset2"
|
||||
assert resources[1].description == "desc2"
|
||||
|
||||
|
||||
@patch("src.rag.dify.requests.get")
|
||||
def test_list_resources_error(mock_get, monkeypatch):
|
||||
monkeypatch.setenv("DIFY_API_URL", "http://api")
|
||||
monkeypatch.setenv("DIFY_API_KEY", "key")
|
||||
provider = DifyProvider()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.text = "fail"
|
||||
mock_get.return_value = mock_response
|
||||
with pytest.raises(Exception):
|
||||
provider.list_resources()
|
||||
@@ -2,9 +2,11 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from __future__ import annotations
|
||||
from uuid import uuid4
|
||||
from types import SimpleNamespace
|
||||
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
import src.rag.milvus as milvus_mod
|
||||
@@ -13,7 +15,6 @@ from src.rag.retriever import Resource
|
||||
|
||||
|
||||
class DummyEmbedding:
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
|
||||
@@ -369,9 +370,7 @@ def test_create_collection_lite(monkeypatch):
|
||||
def list_collections(self): # noqa: D401
|
||||
return [] # empty triggers creation
|
||||
|
||||
def create_collection(
|
||||
self, collection_name, schema, index_params
|
||||
): # noqa: D401
|
||||
def create_collection(self, collection_name, schema, index_params): # noqa: D401
|
||||
created["name"] = collection_name
|
||||
created["schema"] = schema
|
||||
created["index"] = index_params
|
||||
|
||||
Reference in New Issue
Block a user