feat: support dify in rag module (#550)

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
Chayton Bai
2025-09-16 20:30:45 +08:00
committed by GitHub
parent 5085bf8ee9
commit 7694bb5d72
19 changed files with 407 additions and 87 deletions

View File

@@ -41,13 +41,18 @@ TAVILY_API_KEY=tvly-xxx
# RAGFLOW_RETRIEVAL_SIZE=10 # RAGFLOW_RETRIEVAL_SIZE=10
# RAGFLOW_CROSS_LANGUAGES=English,Chinese,Spanish,French,German,Japanese,Korean # Optional. To use RAGFlow's cross-language search, please separate each language with a single comma # RAGFLOW_CROSS_LANGUAGES=English,Chinese,Spanish,French,German,Japanese,Korean # Optional. To use RAGFlow's cross-language search, please separate each language with a single comma
# RAG_PROVIDER=dify
# DIFY_API_URL="https://api.dify.ai/v1"
# DIFY_API_KEY="dataset-xxx"
# MOI is a hybrid database that mainly serves enterprise users (https://www.matrixorigin.io/matrixone-intelligence) # MOI is a hybrid database that mainly serves enterprise users (https://www.matrixorigin.io/matrixone-intelligence)
# RAG_PROVIDER=moi # RAG_PROVIDER=moi
# MOI_API_URL="https://freetier-01.cn-hangzhou.cluster.matrixonecloud.cn" # MOI_API_URL="https://cluster.matrixonecloud.cn"
# MOI_API_KEY="xxx-xxx-xxx-xxx" # MOI_API_KEY="xxx-xxx-xxx-xxx"
# MOI_RETRIEVAL_SIZE=10 # MOI_RETRIEVAL_SIZE=10
# MOI_LIST_LIMIT=10 # MOI_LIST_LIMIT=10
# RAG_PROVIDER: milvus (using free milvus instance on zilliz cloud: https://docs.zilliz.com/docs/quick-start ) # RAG_PROVIDER: milvus (using free milvus instance on zilliz cloud: https://docs.zilliz.com/docs/quick-start )
# RAG_PROVIDER=milvus # RAG_PROVIDER=milvus
# MILVUS_URI=<endpoint_of_self_hosted_milvus_or_zilliz_cloud> # MILVUS_URI=<endpoint_of_self_hosted_milvus_or_zilliz_cloud>

View File

@@ -4,10 +4,11 @@
""" """
Server script for running the DeerFlow API. Server script for running the DeerFlow API.
""" """
import os
import asyncio
import argparse import argparse
import asyncio
import logging import logging
import os
import signal import signal
import sys import sys

View File

@@ -8,9 +8,9 @@ from typing import Any, Optional
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from src.config.loader import get_bool_env, get_int_env, get_str_env
from src.config.report_style import ReportStyle from src.config.report_style import ReportStyle
from src.rag.retriever import Resource from src.rag.retriever import Resource
from src.config.loader import get_str_env, get_int_env, get_bool_env
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -22,6 +22,7 @@ SELECTED_SEARCH_ENGINE = os.getenv("SEARCH_API", SearchEngine.TAVILY.value)
class RAGProvider(enum.Enum): class RAGProvider(enum.Enum):
DIFY = "dify"
RAGFLOW = "ragflow" RAGFLOW = "ragflow"
VIKINGDB_KNOWLEDGE_BASE = "vikingdb_knowledge_base" VIKINGDB_KNOWLEDGE_BASE = "vikingdb_knowledge_base"
MOI = "moi" MOI = "moi"

View File

@@ -6,10 +6,12 @@ import logging
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import psycopg import psycopg
from langgraph.store.memory import InMemoryStore
from psycopg.rows import dict_row from psycopg.rows import dict_row
from pymongo import MongoClient from pymongo import MongoClient
from langgraph.store.memory import InMemoryStore
from src.config.loader import get_bool_env, get_str_env from src.config.loader import get_bool_env, get_str_env

View File

@@ -10,7 +10,6 @@ from langchain_core.language_models import BaseChatModel
from langchain_deepseek import ChatDeepSeek from langchain_deepseek import ChatDeepSeek
from langchain_google_genai import ChatGoogleGenerativeAI from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import AzureChatOpenAI, ChatOpenAI from langchain_openai import AzureChatOpenAI, ChatOpenAI
from typing import get_args
from src.config import load_yaml_config from src.config import load_yaml_config
from src.config.agents import LLMType from src.config.agents import LLMType

View File

@@ -211,7 +211,6 @@ class ChatDashscope(ChatOpenAI):
and hasattr(response.choices[0], "message") and hasattr(response.choices[0], "message")
and hasattr(response.choices[0].message, "reasoning_content") and hasattr(response.choices[0].message, "reasoning_content")
): ):
reasoning_content = response.choices[0].message.reasoning_content reasoning_content = response.choices[0].message.reasoning_content
if reasoning_content and chat_result.generations: if reasoning_content and chat_result.generations:
chat_result.generations[0].message.additional_kwargs[ chat_result.generations[0].message.additional_kwargs[

View File

@@ -2,6 +2,7 @@
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
from .builder import build_retriever from .builder import build_retriever
from .dify import DifyProvider
from .ragflow import RAGFlowProvider from .ragflow import RAGFlowProvider
from .moi import MOIProvider from .moi import MOIProvider
from .retriever import Chunk, Document, Resource, Retriever from .retriever import Chunk, Document, Resource, Retriever
@@ -11,6 +12,7 @@ __all__ = [
Retriever, Retriever,
Document, Document,
Resource, Resource,
DifyProvider,
RAGFlowProvider, RAGFlowProvider,
MOIProvider, MOIProvider,
VikingDBKnowledgeBaseProvider, VikingDBKnowledgeBaseProvider,

View File

@@ -2,14 +2,17 @@
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
from src.config.tools import SELECTED_RAG_PROVIDER, RAGProvider from src.config.tools import SELECTED_RAG_PROVIDER, RAGProvider
from src.rag.dify import DifyProvider
from src.rag.milvus import MilvusProvider
from src.rag.ragflow import RAGFlowProvider from src.rag.ragflow import RAGFlowProvider
from src.rag.moi import MOIProvider from src.rag.moi import MOIProvider
from src.rag.retriever import Retriever from src.rag.retriever import Retriever
from src.rag.vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider from src.rag.vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider
from src.rag.milvus import MilvusProvider
def build_retriever() -> Retriever | None: def build_retriever() -> Retriever | None:
if SELECTED_RAG_PROVIDER == RAGProvider.DIFY.value:
return DifyProvider()
if SELECTED_RAG_PROVIDER == RAGProvider.RAGFLOW.value: if SELECTED_RAG_PROVIDER == RAGProvider.RAGFLOW.value:
return RAGFlowProvider() return RAGFlowProvider()
elif SELECTED_RAG_PROVIDER == RAGProvider.MOI.value: elif SELECTED_RAG_PROVIDER == RAGProvider.MOI.value:

132
src/rag/dify.py Normal file
View File

@@ -0,0 +1,132 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import os
from urllib.parse import urlparse
import requests
from src.rag.retriever import Chunk, Document, Resource, Retriever
class DifyProvider(Retriever):
    """Retriever implementation backed by the Dify Knowledge API.

    Connection settings are read from the ``DIFY_API_URL`` and
    ``DIFY_API_KEY`` environment variables.
    """

    # Base URL of the Dify API, e.g. "https://api.dify.ai/v1".
    api_url: str
    # Dataset-scoped API key ("dataset-xxx").
    api_key: str

    def __init__(self) -> None:
        """Load configuration from the environment.

        Raises:
            ValueError: if DIFY_API_URL or DIFY_API_KEY is unset or empty.
        """
        api_url = os.getenv("DIFY_API_URL")
        if not api_url:
            raise ValueError("DIFY_API_URL is not set")
        self.api_url = api_url
        api_key = os.getenv("DIFY_API_KEY")
        if not api_key:
            raise ValueError("DIFY_API_KEY is not set")
        self.api_key = api_key

    def _headers(self) -> dict[str, str]:
        """Common HTTP headers for authenticated JSON requests."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def query_relevant_documents(
        self, query: str, resources: list[Resource] | None = None
    ) -> list[Document]:
        """Retrieve chunks relevant to *query* from the given Dify datasets.

        Args:
            query: Free-text search query.
            resources: Datasets to search, identified by
                ``rag://dataset/<id>`` URIs. When empty or None, no search
                is performed and an empty list is returned.

        Returns:
            Documents (each carrying its matching chunks), deduplicated by
            document id across datasets.

        Raises:
            Exception: if the Dify API returns a non-200 response.
        """
        # BUGFIX: previous signature used a mutable default (resources=[]).
        if not resources:
            return []
        all_documents: dict[str, Document] = {}
        for resource in resources:
            dataset_id, _ = parse_uri(resource.uri)
            payload = {
                "query": query,
                "retrieval_model": {
                    "search_method": "hybrid_search",
                    "reranking_enable": False,
                    "weights": {
                        "weight_type": "customized",
                        "keyword_setting": {"keyword_weight": 0.3},
                        "vector_setting": {"vector_weight": 0.7},
                    },
                    "top_k": 3,
                    "score_threshold_enabled": True,
                    "score_threshold": 0.5,
                },
            }
            response = requests.post(
                f"{self.api_url}/datasets/{dataset_id}/retrieve",
                headers=self._headers(),
                json=payload,
                timeout=30,  # avoid hanging forever on an unresponsive server
            )
            if response.status_code != 200:
                raise Exception(f"Failed to query documents: {response.text}")
            result = response.json()
            # BUGFIX: default was {}; "records" is a list in the API response.
            records = result.get("records", [])
            for record in records:
                segment = record.get("segment")
                if not segment:
                    continue
                document_info = segment.get("document")
                if not document_info:
                    continue
                doc_id = document_info.get("id")
                doc_name = document_info.get("name")
                if not doc_id or not doc_name:
                    continue
                if doc_id not in all_documents:
                    all_documents[doc_id] = Document(
                        id=doc_id, title=doc_name, chunks=[]
                    )
                chunk = Chunk(
                    content=segment.get("content", ""),
                    similarity=record.get("score", 0.0),
                )
                all_documents[doc_id].chunks.append(chunk)
        return list(all_documents.values())

    def list_resources(self, query: str | None = None) -> list[Resource]:
        """List datasets available to the configured API key.

        Args:
            query: Optional keyword filter forwarded to the Dify API.

        Returns:
            One Resource per dataset, with ``rag://dataset/<id>`` URIs.

        Raises:
            Exception: if the Dify API returns a non-200 response.
        """
        params = {"keyword": query} if query else {}
        response = requests.get(
            f"{self.api_url}/datasets",
            headers=self._headers(),
            params=params,
            timeout=30,  # avoid hanging forever on an unresponsive server
        )
        if response.status_code != 200:
            raise Exception(f"Failed to list resources: {response.text}")
        result = response.json()
        resources = []
        # NOTE: do not reuse the loop variable for the Resource (the original
        # shadowed `item`), which made the code harder to follow.
        for dataset in result.get("data", []):
            resources.append(
                Resource(
                    uri=f"rag://dataset/{dataset.get('id')}",
                    title=dataset.get("name", ""),
                    description=dataset.get("description", ""),
                )
            )
        return resources
def parse_uri(uri: str) -> tuple[str, str]:
    """Split a ``rag://dataset/<dataset_id>#<document_id>`` URI.

    Args:
        uri: Resource URI produced by :meth:`DifyProvider.list_resources`.

    Returns:
        ``(dataset_id, document_id)``; ``document_id`` is ``""`` when the
        URI carries no fragment.

    Raises:
        ValueError: if the scheme is not ``rag`` or the dataset id is
            missing (previously a malformed path raised IndexError).
    """
    parsed = urlparse(uri)
    if parsed.scheme != "rag":
        raise ValueError(f"Invalid URI: {uri}")
    segments = parsed.path.split("/")
    # path is "/<dataset_id>"; guard against a missing id so malformed
    # URIs fail with a consistent ValueError.
    if len(segments) < 2 or not segments[1]:
        raise ValueError(f"Invalid URI: {uri}")
    return segments[1], parsed.fragment

View File

@@ -7,11 +7,12 @@ from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set from typing import Any, Dict, Iterable, List, Optional, Sequence, Set
from langchain_milvus.vectorstores import Milvus as LangchainMilvus from langchain_milvus.vectorstores import Milvus as LangchainMilvus
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
from langchain_openai import OpenAIEmbeddings from langchain_openai import OpenAIEmbeddings
from openai import OpenAI from openai import OpenAI
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
from src.config.loader import get_bool_env, get_int_env, get_str_env
from src.rag.retriever import Chunk, Document, Resource, Retriever from src.rag.retriever import Chunk, Document, Resource, Retriever
from src.config.loader import get_bool_env, get_str_env, get_int_env
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -466,7 +467,7 @@ class MilvusRetriever(Retriever):
resources.append( resources.append(
Resource( Resource(
uri=r.get(self.url_field, "") uri=r.get(self.url_field, "")
or f"milvus://{r.get(self.id_field,'')}", or f"milvus://{r.get(self.id_field, '')}",
title=r.get(self.title_field, "") title=r.get(self.title_field, "")
or r.get(self.id_field, "Unnamed"), or r.get(self.id_field, "Unnamed"),
description="Stored Milvus document", description="Stored Milvus document",
@@ -476,21 +477,23 @@ class MilvusRetriever(Retriever):
# Use similarity_search_by_vector for lightweight listing. # Use similarity_search_by_vector for lightweight listing.
# If a query is provided embed it; else use a zero vector. # If a query is provided embed it; else use a zero vector.
docs: Iterable[Any] = self.client.similarity_search( docs: Iterable[Any] = self.client.similarity_search(
query, k=100, expr="source == 'examples'" # Limit to 100 results query,
k=100,
expr="source == 'examples'", # Limit to 100 results
) )
for d in docs: for d in docs:
meta = getattr(d, "metadata", {}) or {} meta = getattr(d, "metadata", {}) or {}
# check if the resource is in the list of resources # check if the resource is in the list of resources
if resources and any( if resources and any(
r.uri == meta.get(self.url_field, "") r.uri == meta.get(self.url_field, "")
or r.uri == f"milvus://{meta.get(self.id_field,'')}" or r.uri == f"milvus://{meta.get(self.id_field, '')}"
for r in resources for r in resources
): ):
continue continue
resources.append( resources.append(
Resource( Resource(
uri=meta.get(self.url_field, "") uri=meta.get(self.url_field, "")
or f"milvus://{meta.get(self.id_field,'')}", or f"milvus://{meta.get(self.id_field, '')}",
title=meta.get(self.title_field, "") title=meta.get(self.title_field, "")
or meta.get(self.id_field, "Unnamed"), or meta.get(self.id_field, "Unnamed"),
description="Stored Milvus document", description="Stored Milvus document",

View File

@@ -11,10 +11,10 @@ from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage
from langgraph.types import Command
from langgraph.store.memory import InMemoryStore
from langgraph.checkpoint.mongodb import AsyncMongoDBSaver from langgraph.checkpoint.mongodb import AsyncMongoDBSaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.store.memory import InMemoryStore
from langgraph.types import Command
from psycopg_pool import AsyncConnectionPool from psycopg_pool import AsyncConnectionPool
from src.config.configuration import get_recursion_limit from src.config.configuration import get_recursion_limit
@@ -22,6 +22,7 @@ from src.config.loader import get_bool_env, get_str_env
from src.config.report_style import ReportStyle from src.config.report_style import ReportStyle
from src.config.tools import SELECTED_RAG_PROVIDER from src.config.tools import SELECTED_RAG_PROVIDER
from src.graph.builder import build_graph_with_memory from src.graph.builder import build_graph_with_memory
from src.graph.checkpoint import chat_stream_message
from src.llms.llm import get_configured_llm_models from src.llms.llm import get_configured_llm_models
from src.podcast.graph.builder import build_graph as build_podcast_graph from src.podcast.graph.builder import build_graph as build_podcast_graph
from src.ppt.graph.builder import build_graph as build_ppt_graph from src.ppt.graph.builder import build_graph as build_ppt_graph
@@ -47,7 +48,6 @@ from src.server.rag_request import (
RAGResourcesResponse, RAGResourcesResponse,
) )
from src.tools import VolcengineTTS from src.tools import VolcengineTTS
from src.graph.checkpoint import chat_stream_message
from src.utils.json_utils import sanitize_args from src.utils.json_utils import sanitize_args
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -97,7 +97,8 @@ async def load_mcp_tools(
) )
return await _get_tools_from_client_session( return await _get_tools_from_client_session(
sse_client(url=url, headers=headers, timeout=timeout_seconds), timeout_seconds sse_client(url=url, headers=headers, timeout=timeout_seconds),
timeout_seconds,
) )
elif server_type == "streamable_http": elif server_type == "streamable_http":
@@ -107,7 +108,10 @@ async def load_mcp_tools(
) )
return await _get_tools_from_client_session( return await _get_tools_from_client_session(
streamablehttp_client(url=url, headers=headers, timeout=timeout_seconds), timeout_seconds, streamablehttp_client(
url=url, headers=headers, timeout=timeout_seconds
),
timeout_seconds,
) )
else: else:

View File

@@ -4,6 +4,7 @@
import json import json
import logging import logging
from typing import Any from typing import Any
import json_repair import json_repair
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -1,13 +1,15 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import pytest
import tempfile
import shutil import shutil
import tempfile
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from typing import Dict, Any, Optional
import psycopg import psycopg
import pytest
class PostgreSQLMockInstance: class PostgreSQLMockInstance:
"""Utility class for managing PostgreSQL mock instances.""" """Utility class for managing PostgreSQL mock instances."""
@@ -24,7 +26,7 @@ class PostgreSQLMockInstance:
self.mock_data = { self.mock_data = {
"chat_streams": {}, # thread_id -> record "chat_streams": {}, # thread_id -> record
"table_exists": False, "table_exists": False,
"connection_active": True "connection_active": True,
} }
def connect(self) -> MagicMock: def connect(self) -> MagicMock:
@@ -70,7 +72,9 @@ class PostgreSQLMockInstance:
if params and len(params) > 0: if params and len(params) > 0:
thread_id = params[0] thread_id = params[0]
if thread_id in self.mock_data["chat_streams"]: if thread_id in self.mock_data["chat_streams"]:
self._mock_cursor._fetch_result = self.mock_data["chat_streams"][thread_id] self._mock_cursor._fetch_result = self.mock_data["chat_streams"][
thread_id
]
else: else:
self._mock_cursor._fetch_result = None self._mock_cursor._fetch_result = None
else: else:
@@ -84,7 +88,7 @@ class PostgreSQLMockInstance:
self.mock_data["chat_streams"][thread_id] = { self.mock_data["chat_streams"][thread_id] = {
"id": thread_id, "id": thread_id,
"thread_id": thread_id, "thread_id": thread_id,
"messages": messages "messages": messages,
} }
self._mock_cursor.rowcount = 1 self._mock_cursor.rowcount = 1
else: else:
@@ -97,13 +101,13 @@ class PostgreSQLMockInstance:
self.mock_data["chat_streams"][thread_id] = { self.mock_data["chat_streams"][thread_id] = {
"id": thread_id, "id": thread_id,
"thread_id": thread_id, "thread_id": thread_id,
"messages": messages "messages": messages,
} }
self._mock_cursor.rowcount = 1 self._mock_cursor.rowcount = 1
def _mock_fetchone(self): def _mock_fetchone(self):
"""Mock fetchone operation.""" """Mock fetchone operation."""
return getattr(self._mock_cursor, '_fetch_result', None) return getattr(self._mock_cursor, "_fetch_result", None)
def disconnect(self): def disconnect(self):
"""Cleanup mock connection.""" """Cleanup mock connection."""
@@ -129,6 +133,7 @@ class PostgreSQLMockInstance:
if thread_id: if thread_id:
self.mock_data["chat_streams"][thread_id] = record self.mock_data["chat_streams"][thread_id] = record
@pytest.fixture @pytest.fixture
def mock_postgresql(): def mock_postgresql():
"""Create a PostgreSQL mock instance.""" """Create a PostgreSQL mock instance."""
@@ -137,6 +142,7 @@ def mock_postgresql():
yield instance yield instance
instance.disconnect() instance.disconnect()
@pytest.fixture @pytest.fixture
def clean_mock_postgresql(): def clean_mock_postgresql():
"""Create a clean PostgreSQL mock instance that resets between tests.""" """Create a clean PostgreSQL mock instance that resets between tests."""

View File

@@ -2,15 +2,18 @@
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import os import os
import pytest from unittest.mock import MagicMock, patch
import mongomock import mongomock
from unittest.mock import patch, MagicMock import pytest
import src.graph.checkpoint as checkpoint
from postgres_mock_utils import PostgreSQLMockInstance from postgres_mock_utils import PostgreSQLMockInstance
import src.graph.checkpoint as checkpoint
POSTGRES_URL = "postgresql://postgres:postgres@localhost:5432/checkpointing_db" POSTGRES_URL = "postgresql://postgres:postgres@localhost:5432/checkpointing_db"
MONGO_URL = "mongodb://admin:admin@localhost:27017/checkpointing_db?authSource=admin" MONGO_URL = "mongodb://admin:admin@localhost:27017/checkpointing_db?authSource=admin"
def has_real_db_connection(): def has_real_db_connection():
# Check the environment if the MongoDB server is available # Check the environment if the MongoDB server is available
enabled = os.getenv("DB_TESTS_ENABLED", "false") enabled = os.getenv("DB_TESTS_ENABLED", "false")
@@ -18,23 +21,24 @@ def has_real_db_connection():
return True return True
return False return False
def test_with_local_postgres_db(): def test_with_local_postgres_db():
"""Ensure the ChatStreamManager can be initialized with a local PostgreSQL DB.""" """Ensure the ChatStreamManager can be initialized with a local PostgreSQL DB."""
with patch('psycopg.connect') as mock_connect: with patch("psycopg.connect") as mock_connect:
# Setup mock PostgreSQL connection # Setup mock PostgreSQL connection
pg_mock = PostgreSQLMockInstance() pg_mock = PostgreSQLMockInstance()
mock_connect.return_value = pg_mock.connect() mock_connect.return_value = pg_mock.connect()
manager = checkpoint.ChatStreamManager( manager = checkpoint.ChatStreamManager(
checkpoint_saver=True, checkpoint_saver=True,
db_uri=POSTGRES_URL, db_uri=POSTGRES_URL,
) )
assert manager.postgres_conn is not None assert manager.postgres_conn is not None
assert manager.mongo_client is None assert manager.mongo_client is None
def test_with_local_mongo_db(): def test_with_local_mongo_db():
"""Ensure the ChatStreamManager can be initialized with a local MongoDB.""" """Ensure the ChatStreamManager can be initialized with a local MongoDB."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client: with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock # Setup mongomock
mock_client = mongomock.MongoClient() mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client mock_mongo_client.return_value = mock_client
@@ -80,7 +84,7 @@ def test_process_stream_partial_buffer_postgres(monkeypatch):
def test_process_stream_partial_buffer_mongo(): def test_process_stream_partial_buffer_mongo():
"""Partial chunks should be buffered; Use mongomock instead of real MongoDB.""" """Partial chunks should be buffered; Use mongomock instead of real MongoDB."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client: with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock # Setup mongomock
mock_client = mongomock.MongoClient() mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client mock_mongo_client.return_value = mock_client
@@ -96,7 +100,10 @@ def test_process_stream_partial_buffer_mongo():
values = [it.dict()["value"] for it in items] values = [it.dict()["value"] for it in items]
assert "hello" in values assert "hello" in values
@pytest.mark.skipif(not has_real_db_connection(), reason="PostgreSQL Server is not available")
@pytest.mark.skipif(
not has_real_db_connection(), reason="PostgreSQL Server is not available"
)
def test_persist_postgresql_local_db(): def test_persist_postgresql_local_db():
"""Ensure that the ChatStreamManager can persist to a local PostgreSQL DB.""" """Ensure that the ChatStreamManager can persist to a local PostgreSQL DB."""
manager = checkpoint.ChatStreamManager( manager = checkpoint.ChatStreamManager(
@@ -116,7 +123,9 @@ def test_persist_postgresql_local_db():
assert result is True assert result is True
@pytest.mark.skipif(not has_real_db_connection(), reason="PostgreSQL Server is not available") @pytest.mark.skipif(
not has_real_db_connection(), reason="PostgreSQL Server is not available"
)
def test_persist_postgresql_called_with_aggregated_chunks(): def test_persist_postgresql_called_with_aggregated_chunks():
"""On 'stop', aggregated chunks should be passed to PostgreSQL persist method.""" """On 'stop', aggregated chunks should be passed to PostgreSQL persist method."""
manager = checkpoint.ChatStreamManager( manager = checkpoint.ChatStreamManager(
@@ -151,7 +160,7 @@ def test_persist_not_attempted_when_saver_disabled():
def test_persist_mongodb_local_db(): def test_persist_mongodb_local_db():
"""Ensure that the ChatStreamManager can persist to a mocked MongoDB.""" """Ensure that the ChatStreamManager can persist to a mocked MongoDB."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client: with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock # Setup mongomock
mock_client = mongomock.MongoClient() mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client mock_mongo_client.return_value = mock_client
@@ -184,7 +193,9 @@ def test_persist_mongodb_local_db():
assert doc["messages"] == ["Another message."] assert doc["messages"] == ["Another message."]
@pytest.mark.skipif(not has_real_db_connection(), reason="MongoDB server is not available") @pytest.mark.skipif(
not has_real_db_connection(), reason="MongoDB server is not available"
)
def test_persist_mongodb_called_with_aggregated_chunks(): def test_persist_mongodb_called_with_aggregated_chunks():
"""On 'stop', aggregated chunks should be passed to MongoDB persist method.""" """On 'stop', aggregated chunks should be passed to MongoDB persist method."""
@@ -239,7 +250,7 @@ def test_unsupported_db_uri_scheme():
def test_process_stream_with_interrupt_finish_reason(): def test_process_stream_with_interrupt_finish_reason():
"""Test that 'interrupt' finish_reason triggers persistence like 'stop'.""" """Test that 'interrupt' finish_reason triggers persistence like 'stop'."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client: with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock # Setup mongomock
mock_client = mongomock.MongoClient() mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client mock_mongo_client.return_value = mock_client
@@ -395,7 +406,7 @@ def test_multiple_threads_isolation():
def test_mongodb_insert_and_update_paths(): def test_mongodb_insert_and_update_paths():
"""Exercise MongoDB insert, update, and exception branches using mongomock.""" """Exercise MongoDB insert, update, and exception branches using mongomock."""
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client: with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
# Setup mongomock # Setup mongomock
mock_client = mongomock.MongoClient() mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client mock_mongo_client.return_value = mock_client
@@ -593,7 +604,7 @@ def test_init_mongodb_success_and_failure(monkeypatch):
"""MongoDB init should succeed with mongomock and fail gracefully with errors.""" """MongoDB init should succeed with mongomock and fail gracefully with errors."""
# Success path with mongomock # Success path with mongomock
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client: with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
mock_client = mongomock.MongoClient() mock_client = mongomock.MongoClient()
mock_mongo_client.return_value = mock_client mock_mongo_client.return_value = mock_client
@@ -601,7 +612,7 @@ def test_init_mongodb_success_and_failure(monkeypatch):
assert manager.mongo_db is not None assert manager.mongo_db is not None
# Failure path # Failure path
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client: with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
mock_mongo_client.side_effect = RuntimeError("Connection failed") mock_mongo_client.side_effect = RuntimeError("Connection failed")
manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL) manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL)

View File

@@ -2,23 +2,21 @@
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import pytest import pytest
from langchain_core.messages import ( from langchain_core.messages import (
AIMessageChunk, AIMessageChunk,
ChatMessageChunk,
FunctionMessageChunk,
HumanMessageChunk, HumanMessageChunk,
SystemMessageChunk, SystemMessageChunk,
FunctionMessageChunk,
ToolMessageChunk, ToolMessageChunk,
) )
from src.llms import llm as llm_module from src.llms import llm as llm_module
from langchain_core.messages import ChatMessageChunk
from src.llms.providers import dashscope as dashscope_module from src.llms.providers import dashscope as dashscope_module
from src.llms.providers.dashscope import ( from src.llms.providers.dashscope import (
ChatDashscope, ChatDashscope,
_convert_delta_to_message_chunk,
_convert_chunk_to_generation_chunk, _convert_chunk_to_generation_chunk,
_convert_delta_to_message_chunk,
) )

154
tests/unit/rag/test_dify.py Normal file
View File

@@ -0,0 +1,154 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import MagicMock, patch
import pytest
from src.rag.dify import DifyProvider, parse_uri
# Dummy classes to mock dependencies
class DummyResource:
    """Lightweight stand-in for the project's Resource type."""

    def __init__(self, uri, title="", description=""):
        # Mirror the rag.Resource fields verbatim.
        self.uri, self.title, self.description = uri, title, description
class DummyChunk:
    """Lightweight stand-in for the project's Chunk type."""

    def __init__(self, content, similarity):
        # Mirror the rag.Chunk fields verbatim.
        self.content, self.similarity = content, similarity
class DummyDocument:
    """Lightweight stand-in for the project's Document type."""

    def __init__(self, id, title, chunks=None):
        self.id, self.title = id, title
        # Falsy chunks (None or []) become a fresh list, as before.
        self.chunks = chunks or []
# Patch imports in dify.py to use dummy classes
@pytest.fixture(autouse=True)
def patch_imports(monkeypatch):
    """Replace the rag dataclasses in src.rag.dify with dummies per test.

    BUGFIX: the original assigned module attributes directly and never
    restored them, leaking the dummy classes into every later test module.
    monkeypatch.setattr undoes each patch automatically at teardown.
    """
    import src.rag.dify as dify

    monkeypatch.setattr(dify, "Resource", DummyResource)
    monkeypatch.setattr(dify, "Chunk", DummyChunk)
    monkeypatch.setattr(dify, "Document", DummyDocument)
    yield
def test_parse_uri_valid():
    """A well-formed rag:// URI yields (dataset_id, document_id)."""
    dataset_id, document_id = parse_uri("rag://dataset/123#abc")
    assert (dataset_id, document_id) == ("123", "abc")
def test_parse_uri_invalid():
    """parse_uri rejects any URI whose scheme is not "rag"."""
    bad_uri = "http://dataset/123#abc"
    with pytest.raises(ValueError):
        parse_uri(bad_uri)
def test_init_env_vars(monkeypatch):
    """The provider picks up both environment variables on construction."""
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.setenv("DIFY_API_KEY", "key")
    provider = DifyProvider()
    assert (provider.api_url, provider.api_key) == ("http://api", "key")
def test_init_missing_env(monkeypatch):
    """Construction fails when either environment variable is absent."""
    # URL missing, key present.
    monkeypatch.delenv("DIFY_API_URL", raising=False)
    monkeypatch.setenv("DIFY_API_KEY", "key")
    with pytest.raises(ValueError):
        DifyProvider()
    # URL present, key missing.
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.delenv("DIFY_API_KEY", raising=False)
    with pytest.raises(ValueError):
        DifyProvider()
@patch("src.rag.dify.requests.post")
def test_query_relevant_documents_success(mock_post, monkeypatch):
    """A 200 retrieval response is mapped to Document/Chunk objects."""
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.setenv("DIFY_API_KEY", "key")
    provider = DifyProvider()

    api_response = MagicMock()
    api_response.status_code = 200
    api_response.json.return_value = {
        "records": [
            {
                "segment": {
                    "content": "chunk text",
                    "document": {"id": "doc456", "name": "Doc Title"},
                },
                "score": 0.9,
            }
        ]
    }
    mock_post.return_value = api_response

    docs = provider.query_relevant_documents(
        "query", [DummyResource("rag://dataset/123#doc456")]
    )

    assert len(docs) == 1
    document = docs[0]
    assert (document.id, document.title) == ("doc456", "Doc Title")
    (chunk,) = document.chunks
    assert (chunk.content, chunk.similarity) == ("chunk text", 0.9)
@patch("src.rag.dify.requests.post")
def test_query_relevant_documents_error(mock_post, monkeypatch):
    """A non-200 retrieval response surfaces as an exception."""
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.setenv("DIFY_API_KEY", "key")
    provider = DifyProvider()

    failed = MagicMock(status_code=400, text="error")
    mock_post.return_value = failed

    with pytest.raises(Exception):
        provider.query_relevant_documents(
            "query", [DummyResource("rag://dataset/123#doc456")]
        )
@patch("src.rag.dify.requests.get")
def test_list_resources_success(mock_get, monkeypatch):
    """Datasets returned by the API are mapped to rag:// resources."""
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.setenv("DIFY_API_KEY", "key")
    provider = DifyProvider()

    ok = MagicMock(status_code=200)
    ok.json.return_value = {
        "data": [
            {"id": "123", "name": "Dataset1", "description": "desc1"},
            {"id": "456", "name": "Dataset2", "description": "desc2"},
        ]
    }
    mock_get.return_value = ok

    resources = provider.list_resources()

    expected = [
        ("rag://dataset/123", "Dataset1", "desc1"),
        ("rag://dataset/456", "Dataset2", "desc2"),
    ]
    assert [(r.uri, r.title, r.description) for r in resources] == expected
@patch("src.rag.dify.requests.get")
def test_list_resources_error(mock_get, monkeypatch):
    """A failing datasets call raises instead of returning partial data."""
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.setenv("DIFY_API_KEY", "key")
    provider = DifyProvider()

    mock_get.return_value = MagicMock(status_code=500, text="fail")

    with pytest.raises(Exception):
        provider.list_resources()

View File

@@ -2,9 +2,11 @@
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
from __future__ import annotations from __future__ import annotations
from uuid import uuid4
from types import SimpleNamespace
from pathlib import Path from pathlib import Path
from types import SimpleNamespace
from uuid import uuid4
import pytest import pytest
import src.rag.milvus as milvus_mod import src.rag.milvus as milvus_mod
@@ -13,7 +15,6 @@ from src.rag.retriever import Resource
class DummyEmbedding: class DummyEmbedding:
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.kwargs = kwargs self.kwargs = kwargs
@@ -369,9 +370,7 @@ def test_create_collection_lite(monkeypatch):
def list_collections(self): # noqa: D401 def list_collections(self): # noqa: D401
return [] # empty triggers creation return [] # empty triggers creation
def create_collection( def create_collection(self, collection_name, schema, index_params): # noqa: D401
self, collection_name, schema, index_params
): # noqa: D401
created["name"] = collection_name created["name"] = collection_name
created["schema"] = schema created["schema"] = schema
created["index"] = index_params created["index"] = index_params