feat: support dify in rag module (#550)

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
Chayton Bai
2025-09-16 20:30:45 +08:00
committed by GitHub
parent 5085bf8ee9
commit 7694bb5d72
19 changed files with 407 additions and 87 deletions

View File

@@ -41,13 +41,18 @@ TAVILY_API_KEY=tvly-xxx
# RAGFLOW_RETRIEVAL_SIZE=10
# RAGFLOW_CROSS_LANGUAGES=English,Chinese,Spanish,French,German,Japanese,Korean # Optional. To use RAGFlow's cross-language search, please separate each language with a single comma
# RAG_PROVIDER=dify
# DIFY_API_URL="https://api.dify.ai/v1"
# DIFY_API_KEY="dataset-xxx"
# MOI is a hybrid database that mainly serves enterprise users (https://www.matrixorigin.io/matrixone-intelligence)
# RAG_PROVIDER=moi
# MOI_API_URL="https://freetier-01.cn-hangzhou.cluster.matrixonecloud.cn"
# MOI_API_URL="https://cluster.matrixonecloud.cn"
# MOI_API_KEY="xxx-xxx-xxx-xxx"
# MOI_RETRIEVAL_SIZE=10
# MOI_LIST_LIMIT=10
# RAG_PROVIDER: milvus (using free milvus instance on zilliz cloud: https://docs.zilliz.com/docs/quick-start )
# RAG_PROVIDER=milvus
# MILVUS_URI=<endpoint_of_self_hosted_milvus_or_zilliz_cloud>

View File

@@ -4,10 +4,11 @@
"""
Server script for running the DeerFlow API.
"""
import os
import asyncio
import argparse
import asyncio
import logging
import os
import signal
import sys

View File

@@ -8,9 +8,9 @@ from typing import Any, Optional
from langchain_core.runnables import RunnableConfig
from src.config.loader import get_bool_env, get_int_env, get_str_env
from src.config.report_style import ReportStyle
from src.rag.retriever import Resource
from src.config.loader import get_str_env, get_int_env, get_bool_env
logger = logging.getLogger(__name__)

View File

@@ -22,6 +22,7 @@ SELECTED_SEARCH_ENGINE = os.getenv("SEARCH_API", SearchEngine.TAVILY.value)
class RAGProvider(enum.Enum):
DIFY = "dify"
RAGFLOW = "ragflow"
VIKINGDB_KNOWLEDGE_BASE = "vikingdb_knowledge_base"
MOI = "moi"

View File

@@ -6,10 +6,12 @@ import logging
import uuid
from datetime import datetime
from typing import List, Optional, Tuple
import psycopg
from langgraph.store.memory import InMemoryStore
from psycopg.rows import dict_row
from pymongo import MongoClient
from langgraph.store.memory import InMemoryStore
from src.config.loader import get_bool_env, get_str_env

View File

@@ -10,7 +10,6 @@ from langchain_core.language_models import BaseChatModel
from langchain_deepseek import ChatDeepSeek
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from typing import get_args
from src.config import load_yaml_config
from src.config.agents import LLMType

View File

@@ -211,7 +211,6 @@ class ChatDashscope(ChatOpenAI):
and hasattr(response.choices[0], "message")
and hasattr(response.choices[0].message, "reasoning_content")
):
reasoning_content = response.choices[0].message.reasoning_content
if reasoning_content and chat_result.generations:
chat_result.generations[0].message.additional_kwargs[

View File

@@ -2,6 +2,7 @@
# SPDX-License-Identifier: MIT
from .builder import build_retriever
from .dify import DifyProvider
from .ragflow import RAGFlowProvider
from .moi import MOIProvider
from .retriever import Chunk, Document, Resource, Retriever
@@ -11,6 +12,7 @@ __all__ = [
Retriever,
Document,
Resource,
DifyProvider,
RAGFlowProvider,
MOIProvider,
VikingDBKnowledgeBaseProvider,

View File

@@ -2,14 +2,17 @@
# SPDX-License-Identifier: MIT
from src.config.tools import SELECTED_RAG_PROVIDER, RAGProvider
from src.rag.dify import DifyProvider
from src.rag.milvus import MilvusProvider
from src.rag.ragflow import RAGFlowProvider
from src.rag.moi import MOIProvider
from src.rag.retriever import Retriever
from src.rag.vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider
from src.rag.milvus import MilvusProvider
def build_retriever() -> Retriever | None:
if SELECTED_RAG_PROVIDER == RAGProvider.DIFY.value:
return DifyProvider()
if SELECTED_RAG_PROVIDER == RAGProvider.RAGFLOW.value:
return RAGFlowProvider()
elif SELECTED_RAG_PROVIDER == RAGProvider.MOI.value:

132
src/rag/dify.py Normal file
View File

@@ -0,0 +1,132 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import os
from urllib.parse import urlparse
import requests
from src.rag.retriever import Chunk, Document, Resource, Retriever
class DifyProvider(Retriever):
    """
    DifyProvider is a provider that uses dify to retrieve documents.

    Configuration comes from the environment:
      * DIFY_API_URL - base URL of the Dify API (e.g. https://api.dify.ai/v1)
      * DIFY_API_KEY - dataset API key ("dataset-...")
    """

    api_url: str
    api_key: str

    def __init__(self):
        """Read the Dify endpoint and credentials from the environment.

        Raises:
            ValueError: if DIFY_API_URL or DIFY_API_KEY is not set.
        """
        api_url = os.getenv("DIFY_API_URL")
        if not api_url:
            raise ValueError("DIFY_API_URL is not set")
        self.api_url = api_url

        api_key = os.getenv("DIFY_API_KEY")
        if not api_key:
            raise ValueError("DIFY_API_KEY is not set")
        self.api_key = api_key

    def query_relevant_documents(
        self, query: str, resources: list[Resource] | None = None
    ) -> list[Document]:
        """Retrieve chunks for *query* from every dataset in *resources*.

        Args:
            query: Natural-language search query.
            resources: Resources whose URIs name the Dify datasets to
                search ("rag://dataset/<id>"). An empty or None value
                short-circuits to an empty result.

        Returns:
            One Document per distinct source document, each carrying the
            matching chunks and their similarity scores.

        Raises:
            Exception: if the Dify API responds with a non-200 status.
        """
        # NOTE: default changed from a mutable `[]` literal to None to avoid
        # the shared-mutable-default pitfall; observable behavior is identical.
        if not resources:
            return []

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        # Aggregate chunks by source-document id across all queried datasets.
        all_documents: dict[str, Document] = {}

        for resource in resources:
            dataset_id, _ = parse_uri(resource.uri)

            payload = {
                "query": query,
                "retrieval_model": {
                    "search_method": "hybrid_search",
                    "reranking_enable": False,
                    "weights": {
                        "weight_type": "customized",
                        "keyword_setting": {"keyword_weight": 0.3},
                        "vector_setting": {"vector_weight": 0.7},
                    },
                    "top_k": 3,
                    "score_threshold_enabled": True,
                    "score_threshold": 0.5,
                },
            }

            response = requests.post(
                f"{self.api_url}/datasets/{dataset_id}/retrieve",
                headers=headers,
                json=payload,
            )

            if response.status_code != 200:
                raise Exception(f"Failed to query documents: {response.text}")

            result = response.json()
            # Fixed default: "records" is a list in the API response, not a
            # dict (the original `{}` default only worked by accident because
            # iterating an empty dict yields nothing).
            records = result.get("records", [])
            for record in records:
                segment = record.get("segment")
                if not segment:
                    continue

                document_info = segment.get("document")
                if not document_info:
                    continue

                doc_id = document_info.get("id")
                doc_name = document_info.get("name")
                if not doc_id or not doc_name:
                    continue

                if doc_id not in all_documents:
                    all_documents[doc_id] = Document(
                        id=doc_id, title=doc_name, chunks=[]
                    )

                chunk = Chunk(
                    content=segment.get("content", ""),
                    similarity=record.get("score", 0.0),
                )
                all_documents[doc_id].chunks.append(chunk)

        return list(all_documents.values())

    def list_resources(self, query: str | None = None) -> list[Resource]:
        """List the datasets available through the configured Dify API.

        Args:
            query: Optional keyword used to filter datasets server-side.

        Returns:
            One Resource per dataset, with a "rag://dataset/<id>" URI.

        Raises:
            Exception: if the Dify API responds with a non-200 status.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        params = {}
        if query:
            params["keyword"] = query

        response = requests.get(
            f"{self.api_url}/datasets", headers=headers, params=params
        )

        if response.status_code != 200:
            raise Exception(f"Failed to list resources: {response.text}")

        result = response.json()
        resources = []
        # Loop variable renamed: the original reused `item` for both the raw
        # API record and the Resource built from it (shadowing).
        for dataset in result.get("data", []):
            resources.append(
                Resource(
                    uri=f"rag://dataset/{dataset.get('id')}",
                    title=dataset.get("name", ""),
                    description=dataset.get("description", ""),
                )
            )

        return resources
def parse_uri(uri: str) -> tuple[str, str]:
    """Split a ``rag://dataset/<id>#<doc>`` URI into (dataset_id, document_id).

    Args:
        uri: Resource URI of the form ``rag://dataset/<dataset_id>#<doc_id>``.

    Returns:
        Tuple of (dataset_id, fragment); the fragment is "" when absent.

    Raises:
        ValueError: if the scheme is not ``rag`` or no dataset id is present.
    """
    parsed = urlparse(uri)
    if parsed.scheme != "rag":
        raise ValueError(f"Invalid URI: {uri}")
    segments = parsed.path.split("/")
    # Path looks like "/<dataset_id>"; guard against a missing id so a
    # malformed URI raises ValueError instead of an opaque IndexError.
    if len(segments) < 2 or not segments[1]:
        raise ValueError(f"Invalid URI: {uri}")
    return segments[1], parsed.fragment

View File

@@ -7,11 +7,12 @@ from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set
from langchain_milvus.vectorstores import Milvus as LangchainMilvus
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
from langchain_openai import OpenAIEmbeddings
from openai import OpenAI
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
from src.config.loader import get_bool_env, get_int_env, get_str_env
from src.rag.retriever import Chunk, Document, Resource, Retriever
from src.config.loader import get_bool_env, get_str_env, get_int_env
logger = logging.getLogger(__name__)
@@ -466,7 +467,7 @@ class MilvusRetriever(Retriever):
resources.append(
Resource(
uri=r.get(self.url_field, "")
or f"milvus://{r.get(self.id_field,'')}",
or f"milvus://{r.get(self.id_field, '')}",
title=r.get(self.title_field, "")
or r.get(self.id_field, "Unnamed"),
description="Stored Milvus document",
@@ -476,21 +477,23 @@ class MilvusRetriever(Retriever):
# Use similarity_search_by_vector for lightweight listing.
# If a query is provided embed it; else use a zero vector.
docs: Iterable[Any] = self.client.similarity_search(
query, k=100, expr="source == 'examples'" # Limit to 100 results
query,
k=100,
expr="source == 'examples'", # Limit to 100 results
)
for d in docs:
meta = getattr(d, "metadata", {}) or {}
# check if the resource is in the list of resources
if resources and any(
r.uri == meta.get(self.url_field, "")
or r.uri == f"milvus://{meta.get(self.id_field,'')}"
or r.uri == f"milvus://{meta.get(self.id_field, '')}"
for r in resources
):
continue
resources.append(
Resource(
uri=meta.get(self.url_field, "")
or f"milvus://{meta.get(self.id_field,'')}",
or f"milvus://{meta.get(self.id_field, '')}",
title=meta.get(self.title_field, "")
or meta.get(self.id_field, "Unnamed"),
description="Stored Milvus document",

View File

@@ -11,10 +11,10 @@ from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, StreamingResponse
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage
from langgraph.types import Command
from langgraph.store.memory import InMemoryStore
from langgraph.checkpoint.mongodb import AsyncMongoDBSaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.store.memory import InMemoryStore
from langgraph.types import Command
from psycopg_pool import AsyncConnectionPool
from src.config.configuration import get_recursion_limit
@@ -22,6 +22,7 @@ from src.config.loader import get_bool_env, get_str_env
from src.config.report_style import ReportStyle
from src.config.tools import SELECTED_RAG_PROVIDER
from src.graph.builder import build_graph_with_memory
from src.graph.checkpoint import chat_stream_message
from src.llms.llm import get_configured_llm_models
from src.podcast.graph.builder import build_graph as build_podcast_graph
from src.ppt.graph.builder import build_graph as build_ppt_graph
@@ -47,7 +48,6 @@ from src.server.rag_request import (
RAGResourcesResponse,
)
from src.tools import VolcengineTTS
from src.graph.checkpoint import chat_stream_message
from src.utils.json_utils import sanitize_args
logger = logging.getLogger(__name__)

View File

@@ -97,7 +97,8 @@ async def load_mcp_tools(
)
return await _get_tools_from_client_session(
sse_client(url=url, headers=headers, timeout=timeout_seconds), timeout_seconds
sse_client(url=url, headers=headers, timeout=timeout_seconds),
timeout_seconds,
)
elif server_type == "streamable_http":
@@ -107,7 +108,10 @@ async def load_mcp_tools(
)
return await _get_tools_from_client_session(
streamablehttp_client(url=url, headers=headers, timeout=timeout_seconds), timeout_seconds,
streamablehttp_client(
url=url, headers=headers, timeout=timeout_seconds
),
timeout_seconds,
)
else:

View File

@@ -4,6 +4,7 @@
import json
import logging
from typing import Any
import json_repair
logger = logging.getLogger(__name__)

View File

@@ -1,81 +1,85 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import pytest
import tempfile
import shutil
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional
from unittest.mock import MagicMock, patch
from typing import Dict, Any, Optional
import psycopg
import pytest
class PostgreSQLMockInstance:
"""Utility class for managing PostgreSQL mock instances."""
def __init__(self, database_name: str = "test_db"):
self.database_name = database_name
self.temp_dir: Optional[Path] = None
self.mock_connection: Optional[MagicMock] = None
self.mock_data: Dict[str, Any] = {}
self._setup_mock_data()
def _setup_mock_data(self):
"""Initialize mock data storage."""
self.mock_data = {
"chat_streams": {}, # thread_id -> record
"table_exists": False,
"connection_active": True
"connection_active": True,
}
def connect(self) -> MagicMock:
"""Create a mock PostgreSQL connection."""
self.mock_connection = MagicMock()
self._setup_mock_methods()
return self.mock_connection
def _setup_mock_methods(self):
"""Setup mock methods for PostgreSQL operations."""
if not self.mock_connection:
return
# Mock cursor context manager
mock_cursor = MagicMock()
mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
mock_cursor.__exit__ = MagicMock(return_value=False)
# Setup cursor operations
mock_cursor.execute = MagicMock(side_effect=self._mock_execute)
mock_cursor.fetchone = MagicMock(side_effect=self._mock_fetchone)
mock_cursor.rowcount = 0
# Setup connection operations
self.mock_connection.cursor = MagicMock(return_value=mock_cursor)
self.mock_connection.commit = MagicMock()
self.mock_connection.rollback = MagicMock()
self.mock_connection.close = MagicMock()
# Store cursor for external access
self._mock_cursor = mock_cursor
def _mock_execute(self, sql: str, params=None):
"""Mock SQL execution."""
sql_upper = sql.upper().strip()
if "CREATE TABLE" in sql_upper:
self.mock_data["table_exists"] = True
self._mock_cursor.rowcount = 0
elif "SELECT" in sql_upper and "chat_streams" in sql_upper:
# Mock SELECT query
if params and len(params) > 0:
thread_id = params[0]
if thread_id in self.mock_data["chat_streams"]:
self._mock_cursor._fetch_result = self.mock_data["chat_streams"][thread_id]
self._mock_cursor._fetch_result = self.mock_data["chat_streams"][
thread_id
]
else:
self._mock_cursor._fetch_result = None
else:
self._mock_cursor._fetch_result = None
elif "UPDATE" in sql_upper and "chat_streams" in sql_upper:
# Mock UPDATE query
if params and len(params) >= 2:
@@ -84,12 +88,12 @@ class PostgreSQLMockInstance:
self.mock_data["chat_streams"][thread_id] = {
"id": thread_id,
"thread_id": thread_id,
"messages": messages
"messages": messages,
}
self._mock_cursor.rowcount = 1
else:
self._mock_cursor.rowcount = 0
elif "INSERT" in sql_upper and "chat_streams" in sql_upper:
# Mock INSERT query
if params and len(params) >= 2:
@@ -97,30 +101,30 @@ class PostgreSQLMockInstance:
self.mock_data["chat_streams"][thread_id] = {
"id": thread_id,
"thread_id": thread_id,
"messages": messages
"messages": messages,
}
self._mock_cursor.rowcount = 1
def _mock_fetchone(self):
"""Mock fetchone operation."""
return getattr(self._mock_cursor, '_fetch_result', None)
return getattr(self._mock_cursor, "_fetch_result", None)
def disconnect(self):
"""Cleanup mock connection."""
if self.mock_connection:
self.mock_connection.close()
self._setup_mock_data() # Reset data
def reset_data(self):
"""Reset all mock data."""
self._setup_mock_data()
def get_table_count(self, table_name: str) -> int:
"""Get record count in a table."""
if table_name == "chat_streams":
return len(self.mock_data["chat_streams"])
return 0
def create_test_data(self, table_name: str, records: list):
"""Insert test data into a table."""
if table_name == "chat_streams":
@@ -129,6 +133,7 @@ class PostgreSQLMockInstance:
if thread_id:
self.mock_data["chat_streams"][thread_id] = record
@pytest.fixture
def mock_postgresql():
"""Create a PostgreSQL mock instance."""
@@ -137,6 +142,7 @@ def mock_postgresql():
yield instance
instance.disconnect()
@pytest.fixture
def clean_mock_postgresql():
"""Create a clean PostgreSQL mock instance that resets between tests."""
@@ -144,4 +150,4 @@ def clean_mock_postgresql():
instance.connect()
instance.reset_data()
yield instance
instance.disconnect()
instance.disconnect()

View File

@@ -2,15 +2,18 @@
# SPDX-License-Identifier: MIT
import os
import pytest
from unittest.mock import MagicMock, patch
import mongomock
from unittest.mock import patch, MagicMock
import src.graph.checkpoint as checkpoint
import pytest
from postgres_mock_utils import PostgreSQLMockInstance
import src.graph.checkpoint as checkpoint
POSTGRES_URL = "postgresql://postgres:postgres@localhost:5432/checkpointing_db"
MONGO_URL = "mongodb://admin:admin@localhost:27017/checkpointing_db?authSource=admin"
def has_real_db_connection():
# Check the environment if the MongoDB server is available
enabled = os.getenv("DB_TESTS_ENABLED", "false")
@@ -18,27 +21,28 @@ def has_real_db_connection():
return True
return False
def test_with_local_postgres_db():
"""Ensure the ChatStreamManager can be initialized with a local PostgreSQL DB."""
with patch('psycopg.connect') as mock_connect:
with patch("psycopg.connect") as mock_connect:
# Setup mock PostgreSQL connection
pg_mock = PostgreSQLMockInstance()
mock_connect.return_value = pg_mock.connect()
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=POSTGRES_URL,
)
)
assert manager.postgres_conn is not None
assert manager.mongo_client is None
def test_with_local_mongo_db():
"""Ensure the ChatStreamManager can be initialized with a local MongoDB."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
@@ -80,7 +84,7 @@ def test_process_stream_partial_buffer_postgres(monkeypatch):
def test_process_stream_partial_buffer_mongo():
"""Partial chunks should be buffered; Use mongomock instead of real MongoDB."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
@@ -96,7 +100,10 @@ def test_process_stream_partial_buffer_mongo():
values = [it.dict()["value"] for it in items]
assert "hello" in values
@pytest.mark.skipif(not has_real_db_connection(), reason="PostgreSQL Server is not available")
@pytest.mark.skipif(
not has_real_db_connection(), reason="PostgreSQL Server is not available"
)
def test_persist_postgresql_local_db():
"""Ensure that the ChatStreamManager can persist to a local PostgreSQL DB."""
manager = checkpoint.ChatStreamManager(
@@ -116,7 +123,9 @@ def test_persist_postgresql_local_db():
assert result is True
@pytest.mark.skipif(not has_real_db_connection(), reason="PostgreSQL Server is not available")
@pytest.mark.skipif(
not has_real_db_connection(), reason="PostgreSQL Server is not available"
)
def test_persist_postgresql_called_with_aggregated_chunks():
"""On 'stop', aggregated chunks should be passed to PostgreSQL persist method."""
manager = checkpoint.ChatStreamManager(
@@ -151,40 +160,42 @@ def test_persist_not_attempted_when_saver_disabled():
def test_persist_mongodb_local_db():
"""Ensure that the ChatStreamManager can persist to a mocked MongoDB."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
)
assert manager.mongo_db is not None
assert manager.postgres_conn is None
# Simulate a message to persist
thread_id = "test_thread"
messages = ["This is a test message."]
result = manager._persist_to_mongodb(thread_id, messages)
assert result is True
# Verify data was persisted in mock
collection = manager.mongo_db.chat_streams
doc = collection.find_one({"thread_id": thread_id})
assert doc is not None
assert doc["messages"] == messages
# Simulate a message with existing thread
result = manager._persist_to_mongodb(thread_id, ["Another message."])
assert result is True
# Verify update worked
doc = collection.find_one({"thread_id": thread_id})
assert doc["messages"] == ["Another message."]
@pytest.mark.skipif(not has_real_db_connection(), reason="MongoDB server is not available")
@pytest.mark.skipif(
not has_real_db_connection(), reason="MongoDB server is not available"
)
def test_persist_mongodb_called_with_aggregated_chunks():
"""On 'stop', aggregated chunks should be passed to MongoDB persist method."""
@@ -239,11 +250,11 @@ def test_unsupported_db_uri_scheme():
def test_process_stream_with_interrupt_finish_reason():
"""Test that 'interrupt' finish_reason triggers persistence like 'stop'."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(
checkpoint_saver=True,
db_uri=MONGO_URL,
@@ -263,7 +274,7 @@ def test_process_stream_with_interrupt_finish_reason():
)
is True
)
# Verify persistence occurred
collection = manager.mongo_db.chat_streams
doc = collection.find_one({"thread_id": "int_test"})
@@ -395,7 +406,7 @@ def test_multiple_threads_isolation():
def test_mongodb_insert_and_update_paths():
"""Exercise MongoDB insert, update, and exception branches using mongomock."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
@@ -404,7 +415,7 @@ def test_mongodb_insert_and_update_paths():
# Insert success (new thread)
assert manager._persist_to_mongodb("th1", ["message1"]) is True
# Verify insert worked
collection = manager.mongo_db.chat_streams
doc = collection.find_one({"thread_id": "th1"})
@@ -413,7 +424,7 @@ def test_mongodb_insert_and_update_paths():
# Update success (existing thread)
assert manager._persist_to_mongodb("th1", ["message2"]) is True
# Verify update worked
doc = collection.find_one({"thread_id": "th1"})
assert doc["messages"] == ["message2"]
@@ -421,9 +432,9 @@ def test_mongodb_insert_and_update_paths():
# Test error case by mocking collection methods
original_find_one = collection.find_one
collection.find_one = MagicMock(side_effect=RuntimeError("Database error"))
assert manager._persist_to_mongodb("th2", ["message"]) is False
# Restore original method
collection.find_one = original_find_one
@@ -591,19 +602,19 @@ def test_context_manager_calls_close(monkeypatch):
def test_init_mongodb_success_and_failure(monkeypatch):
"""MongoDB init should succeed with mongomock and fail gracefully with errors."""
# Success path with mongomock
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
assert manager.mongo_db is not None
# Failure path
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
mock_mongo_client.side_effect = RuntimeError("Connection failed")
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)
# Should have no mongo_db set on failure
assert getattr(manager, "mongo_db", None) is None

View File

@@ -2,23 +2,21 @@
# SPDX-License-Identifier: MIT
import pytest
from langchain_core.messages import (
AIMessageChunk,
ChatMessageChunk,
FunctionMessageChunk,
HumanMessageChunk,
SystemMessageChunk,
FunctionMessageChunk,
ToolMessageChunk,
)
from src.llms import llm as llm_module
from langchain_core.messages import ChatMessageChunk
from src.llms.providers import dashscope as dashscope_module
from src.llms.providers.dashscope import (
ChatDashscope,
_convert_delta_to_message_chunk,
_convert_chunk_to_generation_chunk,
_convert_delta_to_message_chunk,
)

154
tests/unit/rag/test_dify.py Normal file
View File

@@ -0,0 +1,154 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import MagicMock, patch
import pytest
from src.rag.dify import DifyProvider, parse_uri
# Dummy classes to mock dependencies
class DummyResource:
    """Minimal stand-in for the rag Resource class used by these tests."""

    def __init__(self, uri, title="", description=""):
        self.uri, self.title, self.description = uri, title, description
class DummyChunk:
    """Minimal stand-in for the rag Chunk class used by these tests."""

    def __init__(self, content, similarity):
        self.content, self.similarity = content, similarity
class DummyDocument:
    """Minimal stand-in for the rag Document class used by these tests."""

    def __init__(self, id, title, chunks=None):
        self.id = id
        self.title = title
        # Falsy chunks (None or []) become a fresh empty list, matching the
        # original `chunks or []` semantics.
        self.chunks = chunks if chunks else []
# Patch imports in dify.py to use dummy classes
@pytest.fixture(autouse=True)
def patch_imports(monkeypatch):
    """Swap the rag data classes inside src.rag.dify for lightweight dummies.

    Fixed to use monkeypatch.setattr so the real classes are restored after
    each test; the previous direct attribute assignment leaked the dummy
    classes into every later test that imported src.rag.dify.
    """
    import src.rag.dify as dify

    monkeypatch.setattr(dify, "Resource", DummyResource)
    monkeypatch.setattr(dify, "Chunk", DummyChunk)
    monkeypatch.setattr(dify, "Document", DummyDocument)
    yield
def test_parse_uri_valid():
    """A well-formed rag:// URI yields its dataset id and fragment."""
    dataset_id, document_id = parse_uri("rag://dataset/123#abc")
    assert (dataset_id, document_id) == ("123", "abc")
def test_parse_uri_invalid():
    """URIs with a non-rag scheme are rejected with ValueError."""
    bad_uri = "http://dataset/123#abc"
    with pytest.raises(ValueError):
        parse_uri(bad_uri)
def test_init_env_vars(monkeypatch):
    """The provider reads its endpoint and key from the environment."""
    monkeypatch.setenv("DIFY_API_KEY", "key")
    monkeypatch.setenv("DIFY_API_URL", "http://api")

    dify = DifyProvider()

    assert dify.api_url == "http://api"
    assert dify.api_key == "key"
def test_init_missing_env(monkeypatch):
    """Construction fails fast when either environment variable is absent."""
    # Missing URL, key present.
    monkeypatch.delenv("DIFY_API_URL", raising=False)
    monkeypatch.setenv("DIFY_API_KEY", "key")
    with pytest.raises(ValueError):
        DifyProvider()

    # URL present, missing key.
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.delenv("DIFY_API_KEY", raising=False)
    with pytest.raises(ValueError):
        DifyProvider()
@patch("src.rag.dify.requests.post")
def test_query_relevant_documents_success(mock_post, monkeypatch):
    """A 200 retrieval response is folded into Document/Chunk objects."""
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.setenv("DIFY_API_KEY", "key")

    api_payload = {
        "records": [
            {
                "segment": {
                    "content": "chunk text",
                    "document": {
                        "id": "doc456",
                        "name": "Doc Title",
                    },
                },
                "score": 0.9,
            }
        ]
    }
    response = MagicMock(status_code=200)
    response.json.return_value = api_payload
    mock_post.return_value = response

    provider = DifyProvider()
    docs = provider.query_relevant_documents(
        "query", [DummyResource("rag://dataset/123#doc456")]
    )

    assert len(docs) == 1
    doc = docs[0]
    assert (doc.id, doc.title) == ("doc456", "Doc Title")
    assert len(doc.chunks) == 1
    assert (doc.chunks[0].content, doc.chunks[0].similarity) == ("chunk text", 0.9)
@patch("src.rag.dify.requests.post")
def test_query_relevant_documents_error(mock_post, monkeypatch):
    """A non-200 retrieval response surfaces as an exception."""
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.setenv("DIFY_API_KEY", "key")

    failure = MagicMock(status_code=400, text="error")
    mock_post.return_value = failure

    provider = DifyProvider()
    with pytest.raises(Exception):
        provider.query_relevant_documents(
            "query", [DummyResource("rag://dataset/123#doc456")]
        )
@patch("src.rag.dify.requests.get")
def test_list_resources_success(mock_get, monkeypatch):
    """Datasets from the API are mapped to rag:// Resource entries."""
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.setenv("DIFY_API_KEY", "key")

    response = MagicMock(status_code=200)
    response.json.return_value = {
        "data": [
            {"id": "123", "name": "Dataset1", "description": "desc1"},
            {"id": "456", "name": "Dataset2", "description": "desc2"},
        ]
    }
    mock_get.return_value = response

    resources = DifyProvider().list_resources()

    assert len(resources) == 2
    expected = [
        ("rag://dataset/123", "Dataset1", "desc1"),
        ("rag://dataset/456", "Dataset2", "desc2"),
    ]
    for resource, (uri, title, description) in zip(resources, expected):
        assert resource.uri == uri
        assert resource.title == title
        assert resource.description == description
@patch("src.rag.dify.requests.get")
def test_list_resources_error(mock_get, monkeypatch):
    """A non-200 listing response surfaces as an exception."""
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.setenv("DIFY_API_KEY", "key")

    failure = MagicMock(status_code=500, text="fail")
    mock_get.return_value = failure

    with pytest.raises(Exception):
        DifyProvider().list_resources()

View File

@@ -2,9 +2,11 @@
# SPDX-License-Identifier: MIT
from __future__ import annotations
from uuid import uuid4
from types import SimpleNamespace
from pathlib import Path
from types import SimpleNamespace
from uuid import uuid4
import pytest
import src.rag.milvus as milvus_mod
@@ -13,7 +15,6 @@ from src.rag.retriever import Resource
class DummyEmbedding:
def __init__(self, **kwargs):
self.kwargs = kwargs
@@ -369,9 +370,7 @@ def test_create_collection_lite(monkeypatch):
def list_collections(self): # noqa: D401
return [] # empty triggers creation
def create_collection(
self, collection_name, schema, index_params
): # noqa: D401
def create_collection(self, collection_name, schema, index_params): # noqa: D401
created["name"] = collection_name
created["schema"] = schema
created["index"] = index_params