diff --git a/.env.example b/.env.example
index 8fb6ee6..a4cd0b2 100644
--- a/.env.example
+++ b/.env.example
@@ -41,13 +41,18 @@ TAVILY_API_KEY=tvly-xxx
 # RAGFLOW_RETRIEVAL_SIZE=10
 # RAGFLOW_CROSS_LANGUAGES=English,Chinese,Spanish,French,German,Japanese,Korean # Optional. To use RAGFlow's cross-language search, please separate each language with a single comma
+# RAG_PROVIDER=dify
+# DIFY_API_URL="https://api.dify.ai/v1"
+# DIFY_API_KEY="dataset-xxx"
+
 # MOI is a hybrid database that mainly serves enterprise users (https://www.matrixorigin.io/matrixone-intelligence)
 # RAG_PROVIDER=moi
-# MOI_API_URL="https://freetier-01.cn-hangzhou.cluster.matrixonecloud.cn"
+# MOI_API_URL="https://cluster.matrixonecloud.cn"
 # MOI_API_KEY="xxx-xxx-xxx-xxx"
 # MOI_RETRIEVAL_SIZE=10
 # MOI_LIST_LIMIT=10
+
 # RAG_PROVIDER: milvus (using free milvus instance on zilliz cloud: https://docs.zilliz.com/docs/quick-start )
 # RAG_PROVIDER=milvus
 # MILVUS_URI=
diff --git a/server.py b/server.py
index f965b54..71848ec 100644
--- a/server.py
+++ b/server.py
@@ -4,10 +4,11 @@
 """
 Server script for running the DeerFlow API.
 """
-import os
-import asyncio
+
 import argparse
+import asyncio
 import logging
+import os
 import signal
 import sys
diff --git a/src/config/configuration.py b/src/config/configuration.py
index 5c570f4..e7845d5 100644
--- a/src/config/configuration.py
+++ b/src/config/configuration.py
@@ -8,9 +8,9 @@ from typing import Any, Optional
 
 from langchain_core.runnables import RunnableConfig
 
+from src.config.loader import get_bool_env, get_int_env, get_str_env
 from src.config.report_style import ReportStyle
 from src.rag.retriever import Resource
-from src.config.loader import get_str_env, get_int_env, get_bool_env
 
 logger = logging.getLogger(__name__)
diff --git a/src/config/tools.py b/src/config/tools.py
index be5c9f5..26e3dc7 100644
--- a/src/config/tools.py
+++ b/src/config/tools.py
@@ -22,6 +22,7 @@ SELECTED_SEARCH_ENGINE = os.getenv("SEARCH_API", SearchEngine.TAVILY.value)
 
 
 class RAGProvider(enum.Enum):
+    DIFY = "dify"
     RAGFLOW = "ragflow"
     VIKINGDB_KNOWLEDGE_BASE = "vikingdb_knowledge_base"
     MOI = "moi"
diff --git a/src/graph/checkpoint.py b/src/graph/checkpoint.py
index 4a28091..9cec30a 100644
--- a/src/graph/checkpoint.py
+++ b/src/graph/checkpoint.py
@@ -6,10 +6,12 @@ import logging
 import uuid
 from datetime import datetime
 from typing import List, Optional, Tuple
+
 import psycopg
+from langgraph.store.memory import InMemoryStore
 from psycopg.rows import dict_row
 from pymongo import MongoClient
-from langgraph.store.memory import InMemoryStore
+
 from src.config.loader import get_bool_env, get_str_env
diff --git a/src/llms/llm.py b/src/llms/llm.py
index 291796c..7fd1f54 100644
--- a/src/llms/llm.py
+++ b/src/llms/llm.py
@@ -10,7 +10,6 @@ from langchain_core.language_models import BaseChatModel
 from langchain_deepseek import ChatDeepSeek
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_openai import AzureChatOpenAI, ChatOpenAI
-from typing import get_args
 
 from src.config import load_yaml_config
 from src.config.agents import LLMType
diff --git a/src/llms/providers/dashscope.py b/src/llms/providers/dashscope.py
index 888359f..edfa1a3 100644
--- a/src/llms/providers/dashscope.py
+++ b/src/llms/providers/dashscope.py
@@ -211,7 +211,6 @@ class ChatDashscope(ChatOpenAI):
             and hasattr(response.choices[0], "message")
             and hasattr(response.choices[0].message, "reasoning_content")
         ):
-
             reasoning_content = response.choices[0].message.reasoning_content
             if reasoning_content and chat_result.generations:
                 chat_result.generations[0].message.additional_kwargs[
diff --git a/src/rag/__init__.py b/src/rag/__init__.py
index 4451543..e1c71c9 100644
--- a/src/rag/__init__.py
+++ b/src/rag/__init__.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: MIT
 
 from .builder import build_retriever
+from .dify import DifyProvider
 from .ragflow import RAGFlowProvider
 from .moi import MOIProvider
 from .retriever import Chunk, Document, Resource, Retriever
@@ -11,6 +12,7 @@ __all__ = [
     Retriever,
     Document,
     Resource,
+    DifyProvider,
     RAGFlowProvider,
     MOIProvider,
     VikingDBKnowledgeBaseProvider,
diff --git a/src/rag/builder.py b/src/rag/builder.py
index d3e2f15..a29142f 100644
--- a/src/rag/builder.py
+++ b/src/rag/builder.py
@@ -2,14 +2,17 @@
 # SPDX-License-Identifier: MIT
 
 from src.config.tools import SELECTED_RAG_PROVIDER, RAGProvider
+from src.rag.dify import DifyProvider
+from src.rag.milvus import MilvusProvider
 from src.rag.ragflow import RAGFlowProvider
 from src.rag.moi import MOIProvider
 from src.rag.retriever import Retriever
 from src.rag.vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider
-from src.rag.milvus import MilvusProvider
 
 
 def build_retriever() -> Retriever | None:
+    if SELECTED_RAG_PROVIDER == RAGProvider.DIFY.value:
+        return DifyProvider()
     if SELECTED_RAG_PROVIDER == RAGProvider.RAGFLOW.value:
         return RAGFlowProvider()
     elif SELECTED_RAG_PROVIDER == RAGProvider.MOI.value:
diff --git a/src/rag/dify.py b/src/rag/dify.py
new file mode 100644
index 0000000..527e16e
--- /dev/null
+++ b/src/rag/dify.py
@@ -0,0 +1,132 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# SPDX-License-Identifier: MIT
+
+import os
+from urllib.parse import urlparse
+
+import requests
+
+from src.rag.retriever import Chunk, Document, Resource, Retriever
+
+
+class DifyProvider(Retriever):
+    """
+    DifyProvider is a provider that uses dify to retrieve documents.
+ """ + + api_url: str + api_key: str + + def __init__(self): + api_url = os.getenv("DIFY_API_URL") + if not api_url: + raise ValueError("DIFY_API_URL is not set") + self.api_url = api_url + + api_key = os.getenv("DIFY_API_KEY") + if not api_key: + raise ValueError("DIFY_API_KEY is not set") + self.api_key = api_key + + def query_relevant_documents( + self, query: str, resources: list[Resource] = [] + ) -> list[Document]: + if not resources: + return [] + + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + all_documents = {} + for resource in resources: + dataset_id, _ = parse_uri(resource.uri) + payload = { + "query": query, + "retrieval_model": { + "search_method": "hybrid_search", + "reranking_enable": False, + "weights": { + "weight_type": "customized", + "keyword_setting": {"keyword_weight": 0.3}, + "vector_setting": {"vector_weight": 0.7}, + }, + "top_k": 3, + "score_threshold_enabled": True, + "score_threshold": 0.5, + }, + } + + response = requests.post( + f"{self.api_url}/datasets/{dataset_id}/retrieve", + headers=headers, + json=payload, + ) + + if response.status_code != 200: + raise Exception(f"Failed to query documents: {response.text}") + + result = response.json() + records = result.get("records", {}) + for record in records: + segment = record.get("segment") + if not segment: + continue + document_info = segment.get("document") + if not document_info: + continue + doc_id = document_info.get("id") + doc_name = document_info.get("name") + if not doc_id or not doc_name: + continue + + if doc_id not in all_documents: + all_documents[doc_id] = Document( + id=doc_id, title=doc_name, chunks=[] + ) + + chunk = Chunk( + content=segment.get("content", ""), + similarity=record.get("score", 0.0), + ) + all_documents[doc_id].chunks.append(chunk) + + return list(all_documents.values()) + + def list_resources(self, query: str | None = None) -> list[Resource]: + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + params = {} + if query: + params["keyword"] = query + + response = requests.get( + f"{self.api_url}/datasets", headers=headers, params=params + ) + + if response.status_code != 200: + raise Exception(f"Failed to list resources: {response.text}") + + result = response.json() + resources = [] + + for item in result.get("data", []): + item = Resource( + uri=f"rag://dataset/{item.get('id')}", + title=item.get("name", ""), + description=item.get("description", ""), + ) + resources.append(item) + + return resources + + +def parse_uri(uri: str) -> tuple[str, str]: + parsed = urlparse(uri) + if parsed.scheme != "rag": + raise ValueError(f"Invalid URI: {uri}") + return parsed.path.split("/")[1], parsed.fragment diff --git a/src/rag/milvus.py b/src/rag/milvus.py index 7003ad9..de589d4 100644 --- a/src/rag/milvus.py +++ b/src/rag/milvus.py @@ -7,11 +7,12 @@ from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Sequence, Set from langchain_milvus.vectorstores import Milvus as LangchainMilvus -from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType from langchain_openai import OpenAIEmbeddings from openai import OpenAI +from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient + +from src.config.loader import get_bool_env, get_int_env, get_str_env from src.rag.retriever import Chunk, Document, Resource, Retriever -from src.config.loader import get_bool_env, get_str_env, get_int_env logger = logging.getLogger(__name__) @@ -466,7 
+467,7 @@ class MilvusRetriever(Retriever): resources.append( Resource( uri=r.get(self.url_field, "") - or f"milvus://{r.get(self.id_field,'')}", + or f"milvus://{r.get(self.id_field, '')}", title=r.get(self.title_field, "") or r.get(self.id_field, "Unnamed"), description="Stored Milvus document", @@ -476,21 +477,23 @@ class MilvusRetriever(Retriever): # Use similarity_search_by_vector for lightweight listing. # If a query is provided embed it; else use a zero vector. docs: Iterable[Any] = self.client.similarity_search( - query, k=100, expr="source == 'examples'" # Limit to 100 results + query, + k=100, + expr="source == 'examples'", # Limit to 100 results ) for d in docs: meta = getattr(d, "metadata", {}) or {} # check if the resource is in the list of resources if resources and any( r.uri == meta.get(self.url_field, "") - or r.uri == f"milvus://{meta.get(self.id_field,'')}" + or r.uri == f"milvus://{meta.get(self.id_field, '')}" for r in resources ): continue resources.append( Resource( uri=meta.get(self.url_field, "") - or f"milvus://{meta.get(self.id_field,'')}", + or f"milvus://{meta.get(self.id_field, '')}", title=meta.get(self.title_field, "") or meta.get(self.id_field, "Unnamed"), description="Stored Milvus document", diff --git a/src/server/app.py b/src/server/app.py index b7faf80..d68412d 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -11,10 +11,10 @@ from fastapi import FastAPI, HTTPException, Query from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import Response, StreamingResponse from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage -from langgraph.types import Command -from langgraph.store.memory import InMemoryStore from langgraph.checkpoint.mongodb import AsyncMongoDBSaver from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver +from langgraph.store.memory import InMemoryStore +from langgraph.types import Command from psycopg_pool import AsyncConnectionPool from src.config.configuration import get_recursion_limit @@ -22,6 +22,7 @@ from src.config.loader import get_bool_env, get_str_env from src.config.report_style import ReportStyle from src.config.tools import SELECTED_RAG_PROVIDER from src.graph.builder import build_graph_with_memory +from src.graph.checkpoint import chat_stream_message from src.llms.llm import get_configured_llm_models from src.podcast.graph.builder import build_graph as build_podcast_graph from src.ppt.graph.builder import build_graph as build_ppt_graph @@ -47,7 +48,6 @@ from src.server.rag_request import ( RAGResourcesResponse, ) from src.tools import VolcengineTTS -from src.graph.checkpoint import chat_stream_message from src.utils.json_utils import sanitize_args logger = logging.getLogger(__name__) diff --git a/src/server/mcp_utils.py b/src/server/mcp_utils.py index ce2bf19..c4513ce 100644 --- a/src/server/mcp_utils.py +++ b/src/server/mcp_utils.py @@ -97,7 +97,8 @@ async def load_mcp_tools( ) return await _get_tools_from_client_session( - sse_client(url=url, headers=headers, timeout=timeout_seconds), timeout_seconds + sse_client(url=url, headers=headers, timeout=timeout_seconds), + timeout_seconds, ) elif server_type == "streamable_http": @@ -107,7 +108,10 @@ async def load_mcp_tools( ) return await _get_tools_from_client_session( - streamablehttp_client(url=url, headers=headers, timeout=timeout_seconds), timeout_seconds, + streamablehttp_client( + url=url, headers=headers, timeout=timeout_seconds + ), + timeout_seconds, ) else: diff --git a/src/utils/json_utils.py 
b/src/utils/json_utils.py index a8a2257..0d7e175 100644 --- a/src/utils/json_utils.py +++ b/src/utils/json_utils.py @@ -4,6 +4,7 @@ import json import logging from typing import Any + import json_repair logger = logging.getLogger(__name__) diff --git a/tests/unit/checkpoint/postgres_mock_utils.py b/tests/unit/checkpoint/postgres_mock_utils.py index ab32a4b..d4c7763 100644 --- a/tests/unit/checkpoint/postgres_mock_utils.py +++ b/tests/unit/checkpoint/postgres_mock_utils.py @@ -1,81 +1,85 @@ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # SPDX-License-Identifier: MIT -import pytest -import tempfile import shutil +import tempfile from pathlib import Path +from typing import Any, Dict, Optional from unittest.mock import MagicMock, patch -from typing import Dict, Any, Optional + import psycopg +import pytest + class PostgreSQLMockInstance: """Utility class for managing PostgreSQL mock instances.""" - + def __init__(self, database_name: str = "test_db"): self.database_name = database_name self.temp_dir: Optional[Path] = None self.mock_connection: Optional[MagicMock] = None self.mock_data: Dict[str, Any] = {} self._setup_mock_data() - + def _setup_mock_data(self): """Initialize mock data storage.""" self.mock_data = { "chat_streams": {}, # thread_id -> record "table_exists": False, - "connection_active": True + "connection_active": True, } - + def connect(self) -> MagicMock: """Create a mock PostgreSQL connection.""" self.mock_connection = MagicMock() self._setup_mock_methods() return self.mock_connection - + def _setup_mock_methods(self): """Setup mock methods for PostgreSQL operations.""" if not self.mock_connection: return - + # Mock cursor context manager mock_cursor = MagicMock() mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) mock_cursor.__exit__ = MagicMock(return_value=False) - + # Setup cursor operations mock_cursor.execute = MagicMock(side_effect=self._mock_execute) mock_cursor.fetchone = MagicMock(side_effect=self._mock_fetchone) mock_cursor.rowcount = 0 - + # Setup connection operations self.mock_connection.cursor = MagicMock(return_value=mock_cursor) self.mock_connection.commit = MagicMock() self.mock_connection.rollback = MagicMock() self.mock_connection.close = MagicMock() - + # Store cursor for external access self._mock_cursor = mock_cursor - + def _mock_execute(self, sql: str, params=None): """Mock SQL execution.""" sql_upper = sql.upper().strip() - + if "CREATE TABLE" in sql_upper: self.mock_data["table_exists"] = True self._mock_cursor.rowcount = 0 - + elif "SELECT" in sql_upper and "chat_streams" in sql_upper: # Mock SELECT query if params and len(params) > 0: thread_id = params[0] if thread_id in self.mock_data["chat_streams"]: - self._mock_cursor._fetch_result = self.mock_data["chat_streams"][thread_id] + self._mock_cursor._fetch_result = self.mock_data["chat_streams"][ + thread_id + ] else: self._mock_cursor._fetch_result = None else: self._mock_cursor._fetch_result = None - + elif "UPDATE" in sql_upper and "chat_streams" in sql_upper: # Mock UPDATE query if params and len(params) >= 2: @@ -84,12 +88,12 @@ class PostgreSQLMockInstance: self.mock_data["chat_streams"][thread_id] = { "id": thread_id, "thread_id": thread_id, - "messages": messages + "messages": messages, } self._mock_cursor.rowcount = 1 else: self._mock_cursor.rowcount = 0 - + elif "INSERT" in sql_upper and "chat_streams" in sql_upper: # Mock INSERT query if params and len(params) >= 2: @@ -97,30 +101,30 @@ class PostgreSQLMockInstance: self.mock_data["chat_streams"][thread_id] = { 
"id": thread_id, "thread_id": thread_id, - "messages": messages + "messages": messages, } self._mock_cursor.rowcount = 1 - + def _mock_fetchone(self): """Mock fetchone operation.""" - return getattr(self._mock_cursor, '_fetch_result', None) - + return getattr(self._mock_cursor, "_fetch_result", None) + def disconnect(self): """Cleanup mock connection.""" if self.mock_connection: self.mock_connection.close() self._setup_mock_data() # Reset data - + def reset_data(self): """Reset all mock data.""" self._setup_mock_data() - + def get_table_count(self, table_name: str) -> int: """Get record count in a table.""" if table_name == "chat_streams": return len(self.mock_data["chat_streams"]) return 0 - + def create_test_data(self, table_name: str, records: list): """Insert test data into a table.""" if table_name == "chat_streams": @@ -129,6 +133,7 @@ class PostgreSQLMockInstance: if thread_id: self.mock_data["chat_streams"][thread_id] = record + @pytest.fixture def mock_postgresql(): """Create a PostgreSQL mock instance.""" @@ -137,6 +142,7 @@ def mock_postgresql(): yield instance instance.disconnect() + @pytest.fixture def clean_mock_postgresql(): """Create a clean PostgreSQL mock instance that resets between tests.""" @@ -144,4 +150,4 @@ def clean_mock_postgresql(): instance.connect() instance.reset_data() yield instance - instance.disconnect() \ No newline at end of file + instance.disconnect() diff --git a/tests/unit/checkpoint/test_checkpoint.py b/tests/unit/checkpoint/test_checkpoint.py index 1dbbc92..11c4cbf 100644 --- a/tests/unit/checkpoint/test_checkpoint.py +++ b/tests/unit/checkpoint/test_checkpoint.py @@ -2,15 +2,18 @@ # SPDX-License-Identifier: MIT import os -import pytest +from unittest.mock import MagicMock, patch + import mongomock -from unittest.mock import patch, MagicMock -import src.graph.checkpoint as checkpoint +import pytest from postgres_mock_utils import PostgreSQLMockInstance +import src.graph.checkpoint as checkpoint + POSTGRES_URL = "postgresql://postgres:postgres@localhost:5432/checkpointing_db" MONGO_URL = "mongodb://admin:admin@localhost:27017/checkpointing_db?authSource=admin" + def has_real_db_connection(): # Check the environment if the MongoDB server is available enabled = os.getenv("DB_TESTS_ENABLED", "false") @@ -18,27 +21,28 @@ def has_real_db_connection(): return True return False + def test_with_local_postgres_db(): """Ensure the ChatStreamManager can be initialized with a local PostgreSQL DB.""" - with patch('psycopg.connect') as mock_connect: + with patch("psycopg.connect") as mock_connect: # Setup mock PostgreSQL connection pg_mock = PostgreSQLMockInstance() mock_connect.return_value = pg_mock.connect() manager = checkpoint.ChatStreamManager( checkpoint_saver=True, db_uri=POSTGRES_URL, - ) + ) assert manager.postgres_conn is not None assert manager.mongo_client is None def test_with_local_mongo_db(): """Ensure the ChatStreamManager can be initialized with a local MongoDB.""" - with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client: + with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client: # Setup mongomock mock_client = mongomock.MongoClient() mock_mongo_client.return_value = mock_client - + manager = checkpoint.ChatStreamManager( checkpoint_saver=True, db_uri=MONGO_URL, @@ -80,7 +84,7 @@ def test_process_stream_partial_buffer_postgres(monkeypatch): def test_process_stream_partial_buffer_mongo(): """Partial chunks should be buffered; Use mongomock instead of real MongoDB.""" - with patch('src.graph.checkpoint.MongoClient') as 
mock_mongo_client: + with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client: # Setup mongomock mock_client = mongomock.MongoClient() mock_mongo_client.return_value = mock_client @@ -96,7 +100,10 @@ def test_process_stream_partial_buffer_mongo(): values = [it.dict()["value"] for it in items] assert "hello" in values -@pytest.mark.skipif(not has_real_db_connection(), reason="PostgreSQL Server is not available") + +@pytest.mark.skipif( + not has_real_db_connection(), reason="PostgreSQL Server is not available" +) def test_persist_postgresql_local_db(): """Ensure that the ChatStreamManager can persist to a local PostgreSQL DB.""" manager = checkpoint.ChatStreamManager( @@ -116,7 +123,9 @@ def test_persist_postgresql_local_db(): assert result is True -@pytest.mark.skipif(not has_real_db_connection(), reason="PostgreSQL Server is not available") +@pytest.mark.skipif( + not has_real_db_connection(), reason="PostgreSQL Server is not available" +) def test_persist_postgresql_called_with_aggregated_chunks(): """On 'stop', aggregated chunks should be passed to PostgreSQL persist method.""" manager = checkpoint.ChatStreamManager( @@ -151,40 +160,42 @@ def test_persist_not_attempted_when_saver_disabled(): def test_persist_mongodb_local_db(): """Ensure that the ChatStreamManager can persist to a mocked MongoDB.""" - with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client: + with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client: # Setup mongomock mock_client = mongomock.MongoClient() mock_mongo_client.return_value = mock_client - + manager = checkpoint.ChatStreamManager( checkpoint_saver=True, db_uri=MONGO_URL, ) assert manager.mongo_db is not None assert manager.postgres_conn is None - + # Simulate a message to persist thread_id = "test_thread" messages = ["This is a test message."] result = manager._persist_to_mongodb(thread_id, messages) assert result is True - + # Verify data was persisted in mock collection = manager.mongo_db.chat_streams doc = collection.find_one({"thread_id": thread_id}) assert doc is not None assert doc["messages"] == messages - + # Simulate a message with existing thread result = manager._persist_to_mongodb(thread_id, ["Another message."]) assert result is True - + # Verify update worked doc = collection.find_one({"thread_id": thread_id}) assert doc["messages"] == ["Another message."] -@pytest.mark.skipif(not has_real_db_connection(), reason="MongoDB server is not available") +@pytest.mark.skipif( + not has_real_db_connection(), reason="MongoDB server is not available" +) def test_persist_mongodb_called_with_aggregated_chunks(): """On 'stop', aggregated chunks should be passed to MongoDB persist method.""" @@ -239,11 +250,11 @@ def test_unsupported_db_uri_scheme(): def test_process_stream_with_interrupt_finish_reason(): """Test that 'interrupt' finish_reason triggers persistence like 'stop'.""" - with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client: + with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client: # Setup mongomock mock_client = mongomock.MongoClient() mock_mongo_client.return_value = mock_client - + manager = checkpoint.ChatStreamManager( checkpoint_saver=True, db_uri=MONGO_URL, @@ -263,7 +274,7 @@ def test_process_stream_with_interrupt_finish_reason(): ) is True ) - + # Verify persistence occurred collection = manager.mongo_db.chat_streams doc = collection.find_one({"thread_id": "int_test"}) @@ -395,7 +406,7 @@ def test_multiple_threads_isolation(): def test_mongodb_insert_and_update_paths(): 
"""Exercise MongoDB insert, update, and exception branches using mongomock.""" - with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client: + with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client: # Setup mongomock mock_client = mongomock.MongoClient() mock_mongo_client.return_value = mock_client @@ -404,7 +415,7 @@ def test_mongodb_insert_and_update_paths(): # Insert success (new thread) assert manager._persist_to_mongodb("th1", ["message1"]) is True - + # Verify insert worked collection = manager.mongo_db.chat_streams doc = collection.find_one({"thread_id": "th1"}) @@ -413,7 +424,7 @@ def test_mongodb_insert_and_update_paths(): # Update success (existing thread) assert manager._persist_to_mongodb("th1", ["message2"]) is True - + # Verify update worked doc = collection.find_one({"thread_id": "th1"}) assert doc["messages"] == ["message2"] @@ -421,9 +432,9 @@ def test_mongodb_insert_and_update_paths(): # Test error case by mocking collection methods original_find_one = collection.find_one collection.find_one = MagicMock(side_effect=RuntimeError("Database error")) - + assert manager._persist_to_mongodb("th2", ["message"]) is False - + # Restore original method collection.find_one = original_find_one @@ -591,19 +602,19 @@ def test_context_manager_calls_close(monkeypatch): def test_init_mongodb_success_and_failure(monkeypatch): """MongoDB init should succeed with mongomock and fail gracefully with errors.""" - + # Success path with mongomock - with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client: + with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client: mock_client = mongomock.MongoClient() mock_mongo_client.return_value = mock_client - + manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL) assert manager.mongo_db is not None # Failure path - with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client: + with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client: mock_mongo_client.side_effect = RuntimeError("Connection failed") - + manager = checkpoint.ChatStreamManager(checkpoint_saver=True, db_uri=MONGO_URL) # Should have no mongo_db set on failure assert getattr(manager, "mongo_db", None) is None diff --git a/tests/unit/llms/test_dashscope.py b/tests/unit/llms/test_dashscope.py index fd1129c..be844be 100644 --- a/tests/unit/llms/test_dashscope.py +++ b/tests/unit/llms/test_dashscope.py @@ -2,23 +2,21 @@ # SPDX-License-Identifier: MIT import pytest - from langchain_core.messages import ( AIMessageChunk, + ChatMessageChunk, + FunctionMessageChunk, HumanMessageChunk, SystemMessageChunk, - FunctionMessageChunk, ToolMessageChunk, ) from src.llms import llm as llm_module -from langchain_core.messages import ChatMessageChunk from src.llms.providers import dashscope as dashscope_module - from src.llms.providers.dashscope import ( ChatDashscope, - _convert_delta_to_message_chunk, _convert_chunk_to_generation_chunk, + _convert_delta_to_message_chunk, ) diff --git a/tests/unit/rag/test_dify.py b/tests/unit/rag/test_dify.py new file mode 100644 index 0000000..4aa146b --- /dev/null +++ b/tests/unit/rag/test_dify.py @@ -0,0 +1,154 @@ +# Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates +# SPDX-License-Identifier: MIT + +from unittest.mock import MagicMock, patch + +import pytest + +from src.rag.dify import DifyProvider, parse_uri + + +# Dummy classes to mock dependencies +class DummyResource: + def __init__(self, uri, title="", description=""): + self.uri = uri + self.title = title + self.description = description + + +class DummyChunk: + def __init__(self, content, similarity): + self.content = content + self.similarity = similarity + + +class DummyDocument: + def __init__(self, id, title, chunks=None): + self.id = id + self.title = title + self.chunks = chunks or [] + + +# Patch imports in dify.py to use dummy classes +@pytest.fixture(autouse=True) +def patch_imports(monkeypatch): + import src.rag.dify as dify + + dify.Resource = DummyResource + dify.Chunk = DummyChunk + dify.Document = DummyDocument + yield + + +def test_parse_uri_valid(): + uri = "rag://dataset/123#abc" + dataset_id, document_id = parse_uri(uri) + assert dataset_id == "123" + assert document_id == "abc" + + +def test_parse_uri_invalid(): + with pytest.raises(ValueError): + parse_uri("http://dataset/123#abc") + + +def test_init_env_vars(monkeypatch): + monkeypatch.setenv("DIFY_API_URL", "http://api") + monkeypatch.setenv("DIFY_API_KEY", "key") + provider = DifyProvider() + assert provider.api_url == "http://api" + assert provider.api_key == "key" + + +def test_init_missing_env(monkeypatch): + monkeypatch.delenv("DIFY_API_URL", raising=False) + monkeypatch.setenv("DIFY_API_KEY", "key") + with pytest.raises(ValueError): + DifyProvider() + monkeypatch.setenv("DIFY_API_URL", "http://api") + monkeypatch.delenv("DIFY_API_KEY", raising=False) + with pytest.raises(ValueError): + DifyProvider() + + +@patch("src.rag.dify.requests.post") +def test_query_relevant_documents_success(mock_post, monkeypatch): + monkeypatch.setenv("DIFY_API_URL", "http://api") + monkeypatch.setenv("DIFY_API_KEY", "key") + provider = DifyProvider() + resource = DummyResource("rag://dataset/123#doc456") + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "records": [ + { + "segment": { + "content": "chunk text", + "document": { + "id": "doc456", + "name": "Doc Title", + }, + }, + "score": 0.9, + } + ] + } + mock_post.return_value = mock_response + docs = provider.query_relevant_documents("query", [resource]) + assert len(docs) == 1 + assert docs[0].id == "doc456" + assert docs[0].title == "Doc Title" + assert len(docs[0].chunks) == 1 + assert docs[0].chunks[0].content == "chunk text" + assert docs[0].chunks[0].similarity == 0.9 + + +@patch("src.rag.dify.requests.post") +def test_query_relevant_documents_error(mock_post, monkeypatch): + monkeypatch.setenv("DIFY_API_URL", "http://api") + monkeypatch.setenv("DIFY_API_KEY", "key") + provider = DifyProvider() + resource = DummyResource("rag://dataset/123#doc456") + mock_response = MagicMock() + mock_response.status_code = 400 + mock_response.text = "error" + mock_post.return_value = mock_response + with pytest.raises(Exception): + provider.query_relevant_documents("query", [resource]) + + +@patch("src.rag.dify.requests.get") +def test_list_resources_success(mock_get, monkeypatch): + monkeypatch.setenv("DIFY_API_URL", "http://api") + monkeypatch.setenv("DIFY_API_KEY", "key") + provider = DifyProvider() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + {"id": "123", "name": "Dataset1", "description": "desc1"}, + {"id": "456", "name": "Dataset2", 
"description": "desc2"}, + ] + } + mock_get.return_value = mock_response + resources = provider.list_resources() + assert len(resources) == 2 + assert resources[0].uri == "rag://dataset/123" + assert resources[0].title == "Dataset1" + assert resources[0].description == "desc1" + assert resources[1].uri == "rag://dataset/456" + assert resources[1].title == "Dataset2" + assert resources[1].description == "desc2" + + +@patch("src.rag.dify.requests.get") +def test_list_resources_error(mock_get, monkeypatch): + monkeypatch.setenv("DIFY_API_URL", "http://api") + monkeypatch.setenv("DIFY_API_KEY", "key") + provider = DifyProvider() + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.text = "fail" + mock_get.return_value = mock_response + with pytest.raises(Exception): + provider.list_resources() diff --git a/tests/unit/rag/test_milvus.py b/tests/unit/rag/test_milvus.py index d55b950..51b7557 100644 --- a/tests/unit/rag/test_milvus.py +++ b/tests/unit/rag/test_milvus.py @@ -2,9 +2,11 @@ # SPDX-License-Identifier: MIT from __future__ import annotations -from uuid import uuid4 -from types import SimpleNamespace + from pathlib import Path +from types import SimpleNamespace +from uuid import uuid4 + import pytest import src.rag.milvus as milvus_mod @@ -13,7 +15,6 @@ from src.rag.retriever import Resource class DummyEmbedding: - def __init__(self, **kwargs): self.kwargs = kwargs @@ -369,9 +370,7 @@ def test_create_collection_lite(monkeypatch): def list_collections(self): # noqa: D401 return [] # empty triggers creation - def create_collection( - self, collection_name, schema, index_params - ): # noqa: D401 + def create_collection(self, collection_name, schema, index_params): # noqa: D401 created["name"] = collection_name created["schema"] = schema created["index"] = index_params