mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-28 16:24:47 +08:00
feat: support dify in rag module (#550)
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -41,13 +41,18 @@ TAVILY_API_KEY=tvly-xxx
|
|||||||
# RAGFLOW_RETRIEVAL_SIZE=10
|
# RAGFLOW_RETRIEVAL_SIZE=10
|
||||||
# RAGFLOW_CROSS_LANGUAGES=English,Chinese,Spanish,French,German,Japanese,Korean # Optional. To use RAGFlow's cross-language search, please separate each language with a single comma
|
# RAGFLOW_CROSS_LANGUAGES=English,Chinese,Spanish,French,German,Japanese,Korean # Optional. To use RAGFlow's cross-language search, please separate each language with a single comma
|
||||||
|
|
||||||
|
# RAG_PROVIDER=dify
|
||||||
|
# DIFY_API_URL="https://api.dify.ai/v1"
|
||||||
|
# DIFY_API_KEY="dataset-xxx"
|
||||||
|
|
||||||
# MOI is a hybrid database that mainly serves enterprise users (https://www.matrixorigin.io/matrixone-intelligence)
|
# MOI is a hybrid database that mainly serves enterprise users (https://www.matrixorigin.io/matrixone-intelligence)
|
||||||
# RAG_PROVIDER=moi
|
# RAG_PROVIDER=moi
|
||||||
# MOI_API_URL="https://freetier-01.cn-hangzhou.cluster.matrixonecloud.cn"
|
# MOI_API_URL="https://cluster.matrixonecloud.cn"
|
||||||
# MOI_API_KEY="xxx-xxx-xxx-xxx"
|
# MOI_API_KEY="xxx-xxx-xxx-xxx"
|
||||||
# MOI_RETRIEVAL_SIZE=10
|
# MOI_RETRIEVAL_SIZE=10
|
||||||
# MOI_LIST_LIMIT=10
|
# MOI_LIST_LIMIT=10
|
||||||
|
|
||||||
|
|
||||||
# RAG_PROVIDER: milvus (using free milvus instance on zilliz cloud: https://docs.zilliz.com/docs/quick-start )
|
# RAG_PROVIDER: milvus (using free milvus instance on zilliz cloud: https://docs.zilliz.com/docs/quick-start )
|
||||||
# RAG_PROVIDER=milvus
|
# RAG_PROVIDER=milvus
|
||||||
# MILVUS_URI=<endpoint_of_self_hosted_milvus_or_zilliz_cloud>
|
# MILVUS_URI=<endpoint_of_self_hosted_milvus_or_zilliz_cloud>
|
||||||
|
|||||||
@@ -4,10 +4,11 @@
|
|||||||
"""
|
"""
|
||||||
Server script for running the DeerFlow API.
|
Server script for running the DeerFlow API.
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
import asyncio
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|||||||
@@ -8,9 +8,9 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
|
from src.config.loader import get_bool_env, get_int_env, get_str_env
|
||||||
from src.config.report_style import ReportStyle
|
from src.config.report_style import ReportStyle
|
||||||
from src.rag.retriever import Resource
|
from src.rag.retriever import Resource
|
||||||
from src.config.loader import get_str_env, get_int_env, get_bool_env
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ SELECTED_SEARCH_ENGINE = os.getenv("SEARCH_API", SearchEngine.TAVILY.value)
|
|||||||
|
|
||||||
|
|
||||||
class RAGProvider(enum.Enum):
|
class RAGProvider(enum.Enum):
|
||||||
|
DIFY = "dify"
|
||||||
RAGFLOW = "ragflow"
|
RAGFLOW = "ragflow"
|
||||||
VIKINGDB_KNOWLEDGE_BASE = "vikingdb_knowledge_base"
|
VIKINGDB_KNOWLEDGE_BASE = "vikingdb_knowledge_base"
|
||||||
MOI = "moi"
|
MOI = "moi"
|
||||||
|
|||||||
@@ -6,10 +6,12 @@ import logging
|
|||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import psycopg
|
import psycopg
|
||||||
|
from langgraph.store.memory import InMemoryStore
|
||||||
from psycopg.rows import dict_row
|
from psycopg.rows import dict_row
|
||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
from langgraph.store.memory import InMemoryStore
|
|
||||||
from src.config.loader import get_bool_env, get_str_env
|
from src.config.loader import get_bool_env, get_str_env
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from langchain_core.language_models import BaseChatModel
|
|||||||
from langchain_deepseek import ChatDeepSeek
|
from langchain_deepseek import ChatDeepSeek
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
from langchain_openai import AzureChatOpenAI, ChatOpenAI
|
from langchain_openai import AzureChatOpenAI, ChatOpenAI
|
||||||
from typing import get_args
|
|
||||||
|
|
||||||
from src.config import load_yaml_config
|
from src.config import load_yaml_config
|
||||||
from src.config.agents import LLMType
|
from src.config.agents import LLMType
|
||||||
|
|||||||
@@ -211,7 +211,6 @@ class ChatDashscope(ChatOpenAI):
|
|||||||
and hasattr(response.choices[0], "message")
|
and hasattr(response.choices[0], "message")
|
||||||
and hasattr(response.choices[0].message, "reasoning_content")
|
and hasattr(response.choices[0].message, "reasoning_content")
|
||||||
):
|
):
|
||||||
|
|
||||||
reasoning_content = response.choices[0].message.reasoning_content
|
reasoning_content = response.choices[0].message.reasoning_content
|
||||||
if reasoning_content and chat_result.generations:
|
if reasoning_content and chat_result.generations:
|
||||||
chat_result.generations[0].message.additional_kwargs[
|
chat_result.generations[0].message.additional_kwargs[
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
from .builder import build_retriever
|
from .builder import build_retriever
|
||||||
|
from .dify import DifyProvider
|
||||||
from .ragflow import RAGFlowProvider
|
from .ragflow import RAGFlowProvider
|
||||||
from .moi import MOIProvider
|
from .moi import MOIProvider
|
||||||
from .retriever import Chunk, Document, Resource, Retriever
|
from .retriever import Chunk, Document, Resource, Retriever
|
||||||
@@ -11,6 +12,7 @@ __all__ = [
|
|||||||
Retriever,
|
Retriever,
|
||||||
Document,
|
Document,
|
||||||
Resource,
|
Resource,
|
||||||
|
DifyProvider,
|
||||||
RAGFlowProvider,
|
RAGFlowProvider,
|
||||||
MOIProvider,
|
MOIProvider,
|
||||||
VikingDBKnowledgeBaseProvider,
|
VikingDBKnowledgeBaseProvider,
|
||||||
|
|||||||
@@ -2,14 +2,17 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
from src.config.tools import SELECTED_RAG_PROVIDER, RAGProvider
|
from src.config.tools import SELECTED_RAG_PROVIDER, RAGProvider
|
||||||
|
from src.rag.dify import DifyProvider
|
||||||
|
from src.rag.milvus import MilvusProvider
|
||||||
from src.rag.ragflow import RAGFlowProvider
|
from src.rag.ragflow import RAGFlowProvider
|
||||||
from src.rag.moi import MOIProvider
|
from src.rag.moi import MOIProvider
|
||||||
from src.rag.retriever import Retriever
|
from src.rag.retriever import Retriever
|
||||||
from src.rag.vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider
|
from src.rag.vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider
|
||||||
from src.rag.milvus import MilvusProvider
|
|
||||||
|
|
||||||
|
|
||||||
def build_retriever() -> Retriever | None:
|
def build_retriever() -> Retriever | None:
|
||||||
|
if SELECTED_RAG_PROVIDER == RAGProvider.DIFY.value:
|
||||||
|
return DifyProvider()
|
||||||
if SELECTED_RAG_PROVIDER == RAGProvider.RAGFLOW.value:
|
if SELECTED_RAG_PROVIDER == RAGProvider.RAGFLOW.value:
|
||||||
return RAGFlowProvider()
|
return RAGFlowProvider()
|
||||||
elif SELECTED_RAG_PROVIDER == RAGProvider.MOI.value:
|
elif SELECTED_RAG_PROVIDER == RAGProvider.MOI.value:
|
||||||
|
|||||||
132
src/rag/dify.py
Normal file
132
src/rag/dify.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import os
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from src.rag.retriever import Chunk, Document, Resource, Retriever
|
||||||
|
|
||||||
|
|
||||||
|
class DifyProvider(Retriever):
    """
    Retriever backed by the Dify dataset (knowledge base) HTTP API.

    Configuration is read from the environment when the provider is created:
    ``DIFY_API_URL`` is the API base URL and ``DIFY_API_KEY`` is sent as a
    bearer token on every request.
    """

    # Upper bound (seconds) for every HTTP call. Without an explicit timeout
    # `requests` waits forever on an unresponsive server.
    REQUEST_TIMEOUT_SECONDS = 30

    api_url: str
    api_key: str

    def __init__(self):
        """
        Load the Dify endpoint and credentials from the environment.

        Raises:
            ValueError: if ``DIFY_API_URL`` or ``DIFY_API_KEY`` is unset or empty.
        """
        api_url = os.getenv("DIFY_API_URL")
        if not api_url:
            raise ValueError("DIFY_API_URL is not set")
        self.api_url = api_url

        api_key = os.getenv("DIFY_API_KEY")
        if not api_key:
            raise ValueError("DIFY_API_KEY is not set")
        self.api_key = api_key

    def _headers(self) -> dict[str, str]:
        """Build the authentication / content-type headers shared by all calls."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _endpoint(self, path: str) -> str:
        """Join *path* onto the base URL, tolerating a trailing slash in it."""
        return f"{self.api_url.rstrip('/')}/{path}"

    def query_relevant_documents(
        self, query: str, resources: list[Resource] | None = None
    ) -> list[Document]:
        """
        Retrieve chunks matching *query* from each dataset in *resources*.

        Args:
            query: Text sent to Dify as the retrieval query.
            resources: Datasets to search; each ``uri`` is decoded with
                ``parse_uri``. ``None`` or an empty list yields no results.

        Returns:
            One ``Document`` per distinct Dify document id, each holding the
            chunks returned for it (across all queried datasets).

        Raises:
            ValueError: if a resource URI cannot be parsed.
            Exception: if Dify answers with a non-200 status code.
        """
        if not resources:
            return []

        headers = self._headers()

        # Keyed by Dify document id so chunks of the same document are grouped.
        documents_by_id: dict = {}
        for resource in resources:
            dataset_id, _ = parse_uri(resource.uri)
            payload = {
                "query": query,
                "retrieval_model": {
                    "search_method": "hybrid_search",
                    "reranking_enable": False,
                    "weights": {
                        "weight_type": "customized",
                        "keyword_setting": {"keyword_weight": 0.3},
                        "vector_setting": {"vector_weight": 0.7},
                    },
                    "top_k": 3,
                    "score_threshold_enabled": True,
                    "score_threshold": 0.5,
                },
            }

            response = requests.post(
                self._endpoint(f"datasets/{dataset_id}/retrieve"),
                headers=headers,
                json=payload,
                timeout=self.REQUEST_TIMEOUT_SECONDS,
            )

            if response.status_code != 200:
                raise Exception(f"Failed to query documents: {response.text}")

            # `or []` also covers an explicit JSON null for "records".
            records = response.json().get("records") or []
            for record in records:
                segment = record.get("segment")
                if not segment:
                    continue
                document_info = segment.get("document")
                if not document_info:
                    continue
                doc_id = document_info.get("id")
                doc_name = document_info.get("name")
                # A record without both id and name cannot be attributed.
                if not doc_id or not doc_name:
                    continue

                if doc_id not in documents_by_id:
                    documents_by_id[doc_id] = Document(
                        id=doc_id, title=doc_name, chunks=[]
                    )

                documents_by_id[doc_id].chunks.append(
                    Chunk(
                        content=segment.get("content", ""),
                        similarity=record.get("score", 0.0),
                    )
                )

        return list(documents_by_id.values())

    def list_resources(self, query: str | None = None) -> list[Resource]:
        """
        List the Dify datasets visible to the configured API key.

        Args:
            query: Optional keyword forwarded to Dify as the ``keyword``
                request parameter.

        Returns:
            One ``Resource`` per dataset, addressed as ``rag://dataset/<id>``.

        Raises:
            Exception: if Dify answers with a non-200 status code.
        """
        params = {}
        if query:
            params["keyword"] = query

        response = requests.get(
            self._endpoint("datasets"),
            headers=self._headers(),
            params=params,
            timeout=self.REQUEST_TIMEOUT_SECONDS,
        )

        if response.status_code != 200:
            raise Exception(f"Failed to list resources: {response.text}")

        resources = []
        for dataset in response.json().get("data") or []:
            dataset_id = dataset.get("id")
            if not dataset_id:
                # Without an id the URI would be "rag://dataset/None", which
                # could never be queried successfully.
                continue
            resources.append(
                Resource(
                    uri=f"rag://dataset/{dataset_id}",
                    title=dataset.get("name", ""),
                    description=dataset.get("description", ""),
                )
            )

        return resources
|
||||||
|
|
||||||
|
|
||||||
|
def parse_uri(uri: str) -> tuple[str, str]:
    """
    Split a ``rag://`` resource URI into its dataset id and fragment.

    The dataset id is the first path segment, e.g. ``rag://dataset/123#abc``
    yields ``("123", "abc")``. The fragment is ``""`` when the URI has none.

    Args:
        uri: Resource URI to decode.

    Returns:
        A ``(dataset_id, fragment)`` tuple.

    Raises:
        ValueError: if the scheme is not ``rag`` or the dataset id is missing.
    """
    parsed = urlparse(uri)
    if parsed.scheme != "rag":
        raise ValueError(f"Invalid URI: {uri}")

    # path looks like "/<dataset_id>[/...]"; index 0 is the empty string
    # before the leading slash.
    segments = parsed.path.split("/")
    if len(segments) < 2 or not segments[1]:
        # Report a missing id the same way as a bad scheme instead of leaking
        # an IndexError or returning an empty id.
        raise ValueError(f"Invalid URI: {uri}")
    return segments[1], parsed.fragment
|
||||||
@@ -7,11 +7,12 @@ from pathlib import Path
|
|||||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set
|
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set
|
||||||
|
|
||||||
from langchain_milvus.vectorstores import Milvus as LangchainMilvus
|
from langchain_milvus.vectorstores import Milvus as LangchainMilvus
|
||||||
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
|
|
||||||
from langchain_openai import OpenAIEmbeddings
|
from langchain_openai import OpenAIEmbeddings
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
|
||||||
|
|
||||||
|
from src.config.loader import get_bool_env, get_int_env, get_str_env
|
||||||
from src.rag.retriever import Chunk, Document, Resource, Retriever
|
from src.rag.retriever import Chunk, Document, Resource, Retriever
|
||||||
from src.config.loader import get_bool_env, get_str_env, get_int_env
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -466,7 +467,7 @@ class MilvusRetriever(Retriever):
|
|||||||
resources.append(
|
resources.append(
|
||||||
Resource(
|
Resource(
|
||||||
uri=r.get(self.url_field, "")
|
uri=r.get(self.url_field, "")
|
||||||
or f"milvus://{r.get(self.id_field,'')}",
|
or f"milvus://{r.get(self.id_field, '')}",
|
||||||
title=r.get(self.title_field, "")
|
title=r.get(self.title_field, "")
|
||||||
or r.get(self.id_field, "Unnamed"),
|
or r.get(self.id_field, "Unnamed"),
|
||||||
description="Stored Milvus document",
|
description="Stored Milvus document",
|
||||||
@@ -476,21 +477,23 @@ class MilvusRetriever(Retriever):
|
|||||||
# Use similarity_search_by_vector for lightweight listing.
|
# Use similarity_search_by_vector for lightweight listing.
|
||||||
# If a query is provided embed it; else use a zero vector.
|
# If a query is provided embed it; else use a zero vector.
|
||||||
docs: Iterable[Any] = self.client.similarity_search(
|
docs: Iterable[Any] = self.client.similarity_search(
|
||||||
query, k=100, expr="source == 'examples'" # Limit to 100 results
|
query,
|
||||||
|
k=100,
|
||||||
|
expr="source == 'examples'", # Limit to 100 results
|
||||||
)
|
)
|
||||||
for d in docs:
|
for d in docs:
|
||||||
meta = getattr(d, "metadata", {}) or {}
|
meta = getattr(d, "metadata", {}) or {}
|
||||||
# check if the resource is in the list of resources
|
# check if the resource is in the list of resources
|
||||||
if resources and any(
|
if resources and any(
|
||||||
r.uri == meta.get(self.url_field, "")
|
r.uri == meta.get(self.url_field, "")
|
||||||
or r.uri == f"milvus://{meta.get(self.id_field,'')}"
|
or r.uri == f"milvus://{meta.get(self.id_field, '')}"
|
||||||
for r in resources
|
for r in resources
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
resources.append(
|
resources.append(
|
||||||
Resource(
|
Resource(
|
||||||
uri=meta.get(self.url_field, "")
|
uri=meta.get(self.url_field, "")
|
||||||
or f"milvus://{meta.get(self.id_field,'')}",
|
or f"milvus://{meta.get(self.id_field, '')}",
|
||||||
title=meta.get(self.title_field, "")
|
title=meta.get(self.title_field, "")
|
||||||
or meta.get(self.id_field, "Unnamed"),
|
or meta.get(self.id_field, "Unnamed"),
|
||||||
description="Stored Milvus document",
|
description="Stored Milvus document",
|
||||||
|
|||||||
@@ -11,10 +11,10 @@ from fastapi import FastAPI, HTTPException, Query
|
|||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import Response, StreamingResponse
|
from fastapi.responses import Response, StreamingResponse
|
||||||
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage
|
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage
|
||||||
from langgraph.types import Command
|
|
||||||
from langgraph.store.memory import InMemoryStore
|
|
||||||
from langgraph.checkpoint.mongodb import AsyncMongoDBSaver
|
from langgraph.checkpoint.mongodb import AsyncMongoDBSaver
|
||||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||||
|
from langgraph.store.memory import InMemoryStore
|
||||||
|
from langgraph.types import Command
|
||||||
from psycopg_pool import AsyncConnectionPool
|
from psycopg_pool import AsyncConnectionPool
|
||||||
|
|
||||||
from src.config.configuration import get_recursion_limit
|
from src.config.configuration import get_recursion_limit
|
||||||
@@ -22,6 +22,7 @@ from src.config.loader import get_bool_env, get_str_env
|
|||||||
from src.config.report_style import ReportStyle
|
from src.config.report_style import ReportStyle
|
||||||
from src.config.tools import SELECTED_RAG_PROVIDER
|
from src.config.tools import SELECTED_RAG_PROVIDER
|
||||||
from src.graph.builder import build_graph_with_memory
|
from src.graph.builder import build_graph_with_memory
|
||||||
|
from src.graph.checkpoint import chat_stream_message
|
||||||
from src.llms.llm import get_configured_llm_models
|
from src.llms.llm import get_configured_llm_models
|
||||||
from src.podcast.graph.builder import build_graph as build_podcast_graph
|
from src.podcast.graph.builder import build_graph as build_podcast_graph
|
||||||
from src.ppt.graph.builder import build_graph as build_ppt_graph
|
from src.ppt.graph.builder import build_graph as build_ppt_graph
|
||||||
@@ -47,7 +48,6 @@ from src.server.rag_request import (
|
|||||||
RAGResourcesResponse,
|
RAGResourcesResponse,
|
||||||
)
|
)
|
||||||
from src.tools import VolcengineTTS
|
from src.tools import VolcengineTTS
|
||||||
from src.graph.checkpoint import chat_stream_message
|
|
||||||
from src.utils.json_utils import sanitize_args
|
from src.utils.json_utils import sanitize_args
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -97,7 +97,8 @@ async def load_mcp_tools(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return await _get_tools_from_client_session(
|
return await _get_tools_from_client_session(
|
||||||
sse_client(url=url, headers=headers, timeout=timeout_seconds), timeout_seconds
|
sse_client(url=url, headers=headers, timeout=timeout_seconds),
|
||||||
|
timeout_seconds,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif server_type == "streamable_http":
|
elif server_type == "streamable_http":
|
||||||
@@ -107,7 +108,10 @@ async def load_mcp_tools(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return await _get_tools_from_client_session(
|
return await _get_tools_from_client_session(
|
||||||
streamablehttp_client(url=url, headers=headers, timeout=timeout_seconds), timeout_seconds,
|
streamablehttp_client(
|
||||||
|
url=url, headers=headers, timeout=timeout_seconds
|
||||||
|
),
|
||||||
|
timeout_seconds,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import json_repair
|
import json_repair
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -1,81 +1,85 @@
|
|||||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
import pytest
|
|
||||||
import tempfile
|
|
||||||
import shutil
|
import shutil
|
||||||
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
from typing import Dict, Any, Optional
|
|
||||||
import psycopg
|
import psycopg
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
class PostgreSQLMockInstance:
|
class PostgreSQLMockInstance:
|
||||||
"""Utility class for managing PostgreSQL mock instances."""
|
"""Utility class for managing PostgreSQL mock instances."""
|
||||||
|
|
||||||
def __init__(self, database_name: str = "test_db"):
|
def __init__(self, database_name: str = "test_db"):
|
||||||
self.database_name = database_name
|
self.database_name = database_name
|
||||||
self.temp_dir: Optional[Path] = None
|
self.temp_dir: Optional[Path] = None
|
||||||
self.mock_connection: Optional[MagicMock] = None
|
self.mock_connection: Optional[MagicMock] = None
|
||||||
self.mock_data: Dict[str, Any] = {}
|
self.mock_data: Dict[str, Any] = {}
|
||||||
self._setup_mock_data()
|
self._setup_mock_data()
|
||||||
|
|
||||||
def _setup_mock_data(self):
|
def _setup_mock_data(self):
|
||||||
"""Initialize mock data storage."""
|
"""Initialize mock data storage."""
|
||||||
self.mock_data = {
|
self.mock_data = {
|
||||||
"chat_streams": {}, # thread_id -> record
|
"chat_streams": {}, # thread_id -> record
|
||||||
"table_exists": False,
|
"table_exists": False,
|
||||||
"connection_active": True
|
"connection_active": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
def connect(self) -> MagicMock:
|
def connect(self) -> MagicMock:
|
||||||
"""Create a mock PostgreSQL connection."""
|
"""Create a mock PostgreSQL connection."""
|
||||||
self.mock_connection = MagicMock()
|
self.mock_connection = MagicMock()
|
||||||
self._setup_mock_methods()
|
self._setup_mock_methods()
|
||||||
return self.mock_connection
|
return self.mock_connection
|
||||||
|
|
||||||
def _setup_mock_methods(self):
|
def _setup_mock_methods(self):
|
||||||
"""Setup mock methods for PostgreSQL operations."""
|
"""Setup mock methods for PostgreSQL operations."""
|
||||||
if not self.mock_connection:
|
if not self.mock_connection:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Mock cursor context manager
|
# Mock cursor context manager
|
||||||
mock_cursor = MagicMock()
|
mock_cursor = MagicMock()
|
||||||
mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
|
mock_cursor.__enter__ = MagicMock(return_value=mock_cursor)
|
||||||
mock_cursor.__exit__ = MagicMock(return_value=False)
|
mock_cursor.__exit__ = MagicMock(return_value=False)
|
||||||
|
|
||||||
# Setup cursor operations
|
# Setup cursor operations
|
||||||
mock_cursor.execute = MagicMock(side_effect=self._mock_execute)
|
mock_cursor.execute = MagicMock(side_effect=self._mock_execute)
|
||||||
mock_cursor.fetchone = MagicMock(side_effect=self._mock_fetchone)
|
mock_cursor.fetchone = MagicMock(side_effect=self._mock_fetchone)
|
||||||
mock_cursor.rowcount = 0
|
mock_cursor.rowcount = 0
|
||||||
|
|
||||||
# Setup connection operations
|
# Setup connection operations
|
||||||
self.mock_connection.cursor = MagicMock(return_value=mock_cursor)
|
self.mock_connection.cursor = MagicMock(return_value=mock_cursor)
|
||||||
self.mock_connection.commit = MagicMock()
|
self.mock_connection.commit = MagicMock()
|
||||||
self.mock_connection.rollback = MagicMock()
|
self.mock_connection.rollback = MagicMock()
|
||||||
self.mock_connection.close = MagicMock()
|
self.mock_connection.close = MagicMock()
|
||||||
|
|
||||||
# Store cursor for external access
|
# Store cursor for external access
|
||||||
self._mock_cursor = mock_cursor
|
self._mock_cursor = mock_cursor
|
||||||
|
|
||||||
def _mock_execute(self, sql: str, params=None):
|
def _mock_execute(self, sql: str, params=None):
|
||||||
"""Mock SQL execution."""
|
"""Mock SQL execution."""
|
||||||
sql_upper = sql.upper().strip()
|
sql_upper = sql.upper().strip()
|
||||||
|
|
||||||
if "CREATE TABLE" in sql_upper:
|
if "CREATE TABLE" in sql_upper:
|
||||||
self.mock_data["table_exists"] = True
|
self.mock_data["table_exists"] = True
|
||||||
self._mock_cursor.rowcount = 0
|
self._mock_cursor.rowcount = 0
|
||||||
|
|
||||||
elif "SELECT" in sql_upper and "chat_streams" in sql_upper:
|
elif "SELECT" in sql_upper and "chat_streams" in sql_upper:
|
||||||
# Mock SELECT query
|
# Mock SELECT query
|
||||||
if params and len(params) > 0:
|
if params and len(params) > 0:
|
||||||
thread_id = params[0]
|
thread_id = params[0]
|
||||||
if thread_id in self.mock_data["chat_streams"]:
|
if thread_id in self.mock_data["chat_streams"]:
|
||||||
self._mock_cursor._fetch_result = self.mock_data["chat_streams"][thread_id]
|
self._mock_cursor._fetch_result = self.mock_data["chat_streams"][
|
||||||
|
thread_id
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
self._mock_cursor._fetch_result = None
|
self._mock_cursor._fetch_result = None
|
||||||
else:
|
else:
|
||||||
self._mock_cursor._fetch_result = None
|
self._mock_cursor._fetch_result = None
|
||||||
|
|
||||||
elif "UPDATE" in sql_upper and "chat_streams" in sql_upper:
|
elif "UPDATE" in sql_upper and "chat_streams" in sql_upper:
|
||||||
# Mock UPDATE query
|
# Mock UPDATE query
|
||||||
if params and len(params) >= 2:
|
if params and len(params) >= 2:
|
||||||
@@ -84,12 +88,12 @@ class PostgreSQLMockInstance:
|
|||||||
self.mock_data["chat_streams"][thread_id] = {
|
self.mock_data["chat_streams"][thread_id] = {
|
||||||
"id": thread_id,
|
"id": thread_id,
|
||||||
"thread_id": thread_id,
|
"thread_id": thread_id,
|
||||||
"messages": messages
|
"messages": messages,
|
||||||
}
|
}
|
||||||
self._mock_cursor.rowcount = 1
|
self._mock_cursor.rowcount = 1
|
||||||
else:
|
else:
|
||||||
self._mock_cursor.rowcount = 0
|
self._mock_cursor.rowcount = 0
|
||||||
|
|
||||||
elif "INSERT" in sql_upper and "chat_streams" in sql_upper:
|
elif "INSERT" in sql_upper and "chat_streams" in sql_upper:
|
||||||
# Mock INSERT query
|
# Mock INSERT query
|
||||||
if params and len(params) >= 2:
|
if params and len(params) >= 2:
|
||||||
@@ -97,30 +101,30 @@ class PostgreSQLMockInstance:
|
|||||||
self.mock_data["chat_streams"][thread_id] = {
|
self.mock_data["chat_streams"][thread_id] = {
|
||||||
"id": thread_id,
|
"id": thread_id,
|
||||||
"thread_id": thread_id,
|
"thread_id": thread_id,
|
||||||
"messages": messages
|
"messages": messages,
|
||||||
}
|
}
|
||||||
self._mock_cursor.rowcount = 1
|
self._mock_cursor.rowcount = 1
|
||||||
|
|
||||||
def _mock_fetchone(self):
|
def _mock_fetchone(self):
|
||||||
"""Mock fetchone operation."""
|
"""Mock fetchone operation."""
|
||||||
return getattr(self._mock_cursor, '_fetch_result', None)
|
return getattr(self._mock_cursor, "_fetch_result", None)
|
||||||
|
|
||||||
def disconnect(self):
|
def disconnect(self):
|
||||||
"""Cleanup mock connection."""
|
"""Cleanup mock connection."""
|
||||||
if self.mock_connection:
|
if self.mock_connection:
|
||||||
self.mock_connection.close()
|
self.mock_connection.close()
|
||||||
self._setup_mock_data() # Reset data
|
self._setup_mock_data() # Reset data
|
||||||
|
|
||||||
def reset_data(self):
|
def reset_data(self):
|
||||||
"""Reset all mock data."""
|
"""Reset all mock data."""
|
||||||
self._setup_mock_data()
|
self._setup_mock_data()
|
||||||
|
|
||||||
def get_table_count(self, table_name: str) -> int:
|
def get_table_count(self, table_name: str) -> int:
|
||||||
"""Get record count in a table."""
|
"""Get record count in a table."""
|
||||||
if table_name == "chat_streams":
|
if table_name == "chat_streams":
|
||||||
return len(self.mock_data["chat_streams"])
|
return len(self.mock_data["chat_streams"])
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def create_test_data(self, table_name: str, records: list):
|
def create_test_data(self, table_name: str, records: list):
|
||||||
"""Insert test data into a table."""
|
"""Insert test data into a table."""
|
||||||
if table_name == "chat_streams":
|
if table_name == "chat_streams":
|
||||||
@@ -129,6 +133,7 @@ class PostgreSQLMockInstance:
|
|||||||
if thread_id:
|
if thread_id:
|
||||||
self.mock_data["chat_streams"][thread_id] = record
|
self.mock_data["chat_streams"][thread_id] = record
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_postgresql():
|
def mock_postgresql():
|
||||||
"""Create a PostgreSQL mock instance."""
|
"""Create a PostgreSQL mock instance."""
|
||||||
@@ -137,6 +142,7 @@ def mock_postgresql():
|
|||||||
yield instance
|
yield instance
|
||||||
instance.disconnect()
|
instance.disconnect()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def clean_mock_postgresql():
|
def clean_mock_postgresql():
|
||||||
"""Create a clean PostgreSQL mock instance that resets between tests."""
|
"""Create a clean PostgreSQL mock instance that resets between tests."""
|
||||||
@@ -144,4 +150,4 @@ def clean_mock_postgresql():
|
|||||||
instance.connect()
|
instance.connect()
|
||||||
instance.reset_data()
|
instance.reset_data()
|
||||||
yield instance
|
yield instance
|
||||||
instance.disconnect()
|
instance.disconnect()
|
||||||
|
|||||||
@@ -2,15 +2,18 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import pytest
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import mongomock
|
import mongomock
|
||||||
from unittest.mock import patch, MagicMock
|
import pytest
|
||||||
import src.graph.checkpoint as checkpoint
|
|
||||||
from postgres_mock_utils import PostgreSQLMockInstance
|
from postgres_mock_utils import PostgreSQLMockInstance
|
||||||
|
|
||||||
|
import src.graph.checkpoint as checkpoint
|
||||||
|
|
||||||
POSTGRES_URL = "postgresql://postgres:postgres@localhost:5432/checkpointing_db"
|
POSTGRES_URL = "postgresql://postgres:postgres@localhost:5432/checkpointing_db"
|
||||||
MONGO_URL = "mongodb://admin:admin@localhost:27017/checkpointing_db?authSource=admin"
|
MONGO_URL = "mongodb://admin:admin@localhost:27017/checkpointing_db?authSource=admin"
|
||||||
|
|
||||||
|
|
||||||
def has_real_db_connection():
|
def has_real_db_connection():
|
||||||
# Check the environment if the MongoDB server is available
|
# Check the environment if the MongoDB server is available
|
||||||
enabled = os.getenv("DB_TESTS_ENABLED", "false")
|
enabled = os.getenv("DB_TESTS_ENABLED", "false")
|
||||||
@@ -18,27 +21,28 @@ def has_real_db_connection():
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def test_with_local_postgres_db():
|
def test_with_local_postgres_db():
|
||||||
"""Ensure the ChatStreamManager can be initialized with a local PostgreSQL DB."""
|
"""Ensure the ChatStreamManager can be initialized with a local PostgreSQL DB."""
|
||||||
with patch('psycopg.connect') as mock_connect:
|
with patch("psycopg.connect") as mock_connect:
|
||||||
# Setup mock PostgreSQL connection
|
# Setup mock PostgreSQL connection
|
||||||
pg_mock = PostgreSQLMockInstance()
|
pg_mock = PostgreSQLMockInstance()
|
||||||
mock_connect.return_value = pg_mock.connect()
|
mock_connect.return_value = pg_mock.connect()
|
||||||
manager = checkpoint.ChatStreamManager(
|
manager = checkpoint.ChatStreamManager(
|
||||||
checkpoint_saver=True,
|
checkpoint_saver=True,
|
||||||
db_uri=POSTGRES_URL,
|
db_uri=POSTGRES_URL,
|
||||||
)
|
)
|
||||||
assert manager.postgres_conn is not None
|
assert manager.postgres_conn is not None
|
||||||
assert manager.mongo_client is None
|
assert manager.mongo_client is None
|
||||||
|
|
||||||
|
|
||||||
def test_with_local_mongo_db():
|
def test_with_local_mongo_db():
|
||||||
"""Ensure the ChatStreamManager can be initialized with a local MongoDB."""
|
"""Ensure the ChatStreamManager can be initialized with a local MongoDB."""
|
||||||
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
|
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
|
||||||
# Setup mongomock
|
# Setup mongomock
|
||||||
mock_client = mongomock.MongoClient()
|
mock_client = mongomock.MongoClient()
|
||||||
mock_mongo_client.return_value = mock_client
|
mock_mongo_client.return_value = mock_client
|
||||||
|
|
||||||
manager = checkpoint.ChatStreamManager(
|
manager = checkpoint.ChatStreamManager(
|
||||||
checkpoint_saver=True,
|
checkpoint_saver=True,
|
||||||
db_uri=MONGO_URL,
|
db_uri=MONGO_URL,
|
||||||
@@ -80,7 +84,7 @@ def test_process_stream_partial_buffer_postgres(monkeypatch):
|
|||||||
|
|
||||||
def test_process_stream_partial_buffer_mongo():
|
def test_process_stream_partial_buffer_mongo():
|
||||||
"""Partial chunks should be buffered; Use mongomock instead of real MongoDB."""
|
"""Partial chunks should be buffered; Use mongomock instead of real MongoDB."""
|
||||||
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
|
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
|
||||||
# Setup mongomock
|
# Setup mongomock
|
||||||
mock_client = mongomock.MongoClient()
|
mock_client = mongomock.MongoClient()
|
||||||
mock_mongo_client.return_value = mock_client
|
mock_mongo_client.return_value = mock_client
|
||||||
@@ -96,7 +100,10 @@ def test_process_stream_partial_buffer_mongo():
|
|||||||
values = [it.dict()["value"] for it in items]
|
values = [it.dict()["value"] for it in items]
|
||||||
assert "hello" in values
|
assert "hello" in values
|
||||||
|
|
||||||
@pytest.mark.skipif(not has_real_db_connection(), reason="PostgreSQL Server is not available")
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not has_real_db_connection(), reason="PostgreSQL Server is not available"
|
||||||
|
)
|
||||||
def test_persist_postgresql_local_db():
|
def test_persist_postgresql_local_db():
|
||||||
"""Ensure that the ChatStreamManager can persist to a local PostgreSQL DB."""
|
"""Ensure that the ChatStreamManager can persist to a local PostgreSQL DB."""
|
||||||
manager = checkpoint.ChatStreamManager(
|
manager = checkpoint.ChatStreamManager(
|
||||||
@@ -116,7 +123,9 @@ def test_persist_postgresql_local_db():
|
|||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not has_real_db_connection(), reason="PostgreSQL Server is not available")
|
@pytest.mark.skipif(
|
||||||
|
not has_real_db_connection(), reason="PostgreSQL Server is not available"
|
||||||
|
)
|
||||||
def test_persist_postgresql_called_with_aggregated_chunks():
|
def test_persist_postgresql_called_with_aggregated_chunks():
|
||||||
"""On 'stop', aggregated chunks should be passed to PostgreSQL persist method."""
|
"""On 'stop', aggregated chunks should be passed to PostgreSQL persist method."""
|
||||||
manager = checkpoint.ChatStreamManager(
|
manager = checkpoint.ChatStreamManager(
|
||||||
@@ -151,40 +160,42 @@ def test_persist_not_attempted_when_saver_disabled():
|
|||||||
|
|
||||||
def test_persist_mongodb_local_db():
    """Persist messages through a mongomock-backed manager and read them back."""
    with patch("src.graph.checkpoint.MongoClient") as client_factory:
        # Any MongoClient(...) call inside checkpoint now returns mongomock.
        in_memory_client = mongomock.MongoClient()
        client_factory.return_value = in_memory_client

        stream_manager = checkpoint.ChatStreamManager(
            db_uri=MONGO_URL,
            checkpoint_saver=True,
        )
        assert stream_manager.mongo_db is not None
        assert stream_manager.postgres_conn is None

        # First write for a thread that has no stored document yet.
        thread = "test_thread"
        first_payload = ["This is a test message."]
        outcome = stream_manager._persist_to_mongodb(thread, first_payload)
        assert outcome is True

        # The document must be retrievable from the mocked collection.
        streams = stream_manager.mongo_db.chat_streams
        stored = streams.find_one({"thread_id": thread})
        assert stored is not None
        assert stored["messages"] == first_payload

        # Second write to the same thread.
        outcome = stream_manager._persist_to_mongodb(thread, ["Another message."])
        assert outcome is True

        # The stored messages must now equal the second payload only.
        stored = streams.find_one({"thread_id": thread})
        assert stored["messages"] == ["Another message."]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not has_real_db_connection(), reason="MongoDB server is not available")
|
@pytest.mark.skipif(
|
||||||
|
not has_real_db_connection(), reason="MongoDB server is not available"
|
||||||
|
)
|
||||||
def test_persist_mongodb_called_with_aggregated_chunks():
|
def test_persist_mongodb_called_with_aggregated_chunks():
|
||||||
"""On 'stop', aggregated chunks should be passed to MongoDB persist method."""
|
"""On 'stop', aggregated chunks should be passed to MongoDB persist method."""
|
||||||
|
|
||||||
@@ -239,11 +250,11 @@ def test_unsupported_db_uri_scheme():
|
|||||||
|
|
||||||
def test_process_stream_with_interrupt_finish_reason():
|
def test_process_stream_with_interrupt_finish_reason():
|
||||||
"""Test that 'interrupt' finish_reason triggers persistence like 'stop'."""
|
"""Test that 'interrupt' finish_reason triggers persistence like 'stop'."""
|
||||||
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
|
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
|
||||||
# Setup mongomock
|
# Setup mongomock
|
||||||
mock_client = mongomock.MongoClient()
|
mock_client = mongomock.MongoClient()
|
||||||
mock_mongo_client.return_value = mock_client
|
mock_mongo_client.return_value = mock_client
|
||||||
|
|
||||||
manager = checkpoint.ChatStreamManager(
|
manager = checkpoint.ChatStreamManager(
|
||||||
checkpoint_saver=True,
|
checkpoint_saver=True,
|
||||||
db_uri=MONGO_URL,
|
db_uri=MONGO_URL,
|
||||||
@@ -263,7 +274,7 @@ def test_process_stream_with_interrupt_finish_reason():
|
|||||||
)
|
)
|
||||||
is True
|
is True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify persistence occurred
|
# Verify persistence occurred
|
||||||
collection = manager.mongo_db.chat_streams
|
collection = manager.mongo_db.chat_streams
|
||||||
doc = collection.find_one({"thread_id": "int_test"})
|
doc = collection.find_one({"thread_id": "int_test"})
|
||||||
@@ -395,7 +406,7 @@ def test_multiple_threads_isolation():
|
|||||||
|
|
||||||
def test_mongodb_insert_and_update_paths():
|
def test_mongodb_insert_and_update_paths():
|
||||||
"""Exercise MongoDB insert, update, and exception branches using mongomock."""
|
"""Exercise MongoDB insert, update, and exception branches using mongomock."""
|
||||||
with patch('src.graph.checkpoint.MongoClient') as mock_mongo_client:
|
with patch("src.graph.checkpoint.MongoClient") as mock_mongo_client:
|
||||||
# Setup mongomock
|
# Setup mongomock
|
||||||
mock_client = mongomock.MongoClient()
|
mock_client = mongomock.MongoClient()
|
||||||
mock_mongo_client.return_value = mock_client
|
mock_mongo_client.return_value = mock_client
|
||||||
@@ -404,7 +415,7 @@ def test_mongodb_insert_and_update_paths():
|
|||||||
|
|
||||||
# Insert success (new thread)
|
# Insert success (new thread)
|
||||||
assert manager._persist_to_mongodb("th1", ["message1"]) is True
|
assert manager._persist_to_mongodb("th1", ["message1"]) is True
|
||||||
|
|
||||||
# Verify insert worked
|
# Verify insert worked
|
||||||
collection = manager.mongo_db.chat_streams
|
collection = manager.mongo_db.chat_streams
|
||||||
doc = collection.find_one({"thread_id": "th1"})
|
doc = collection.find_one({"thread_id": "th1"})
|
||||||
@@ -413,7 +424,7 @@ def test_mongodb_insert_and_update_paths():
|
|||||||
|
|
||||||
# Update success (existing thread)
|
# Update success (existing thread)
|
||||||
assert manager._persist_to_mongodb("th1", ["message2"]) is True
|
assert manager._persist_to_mongodb("th1", ["message2"]) is True
|
||||||
|
|
||||||
# Verify update worked
|
# Verify update worked
|
||||||
doc = collection.find_one({"thread_id": "th1"})
|
doc = collection.find_one({"thread_id": "th1"})
|
||||||
assert doc["messages"] == ["message2"]
|
assert doc["messages"] == ["message2"]
|
||||||
@@ -421,9 +432,9 @@ def test_mongodb_insert_and_update_paths():
|
|||||||
# Test error case by mocking collection methods
|
# Test error case by mocking collection methods
|
||||||
original_find_one = collection.find_one
|
original_find_one = collection.find_one
|
||||||
collection.find_one = MagicMock(side_effect=RuntimeError("Database error"))
|
collection.find_one = MagicMock(side_effect=RuntimeError("Database error"))
|
||||||
|
|
||||||
assert manager._persist_to_mongodb("th2", ["message"]) is False
|
assert manager._persist_to_mongodb("th2", ["message"]) is False
|
||||||
|
|
||||||
# Restore original method
|
# Restore original method
|
||||||
collection.find_one = original_find_one
|
collection.find_one = original_find_one
|
||||||
|
|
||||||
@@ -591,19 +602,19 @@ def test_context_manager_calls_close(monkeypatch):
|
|||||||
|
|
||||||
def test_init_mongodb_success_and_failure(monkeypatch):
    """Cover both outcomes of MongoDB initialisation in ChatStreamManager.

    With a mongomock client ``mongo_db`` is populated; when ``MongoClient``
    raises, the manager is still constructed and ``mongo_db`` is absent or None.
    """
    mongo_client_target = "src.graph.checkpoint.MongoClient"

    # Success: MongoClient(...) returns an in-memory mongomock client.
    with patch(mongo_client_target) as client_factory:
        client_factory.return_value = mongomock.MongoClient()

        healthy = checkpoint.ChatStreamManager(db_uri=MONGO_URL, checkpoint_saver=True)
        assert healthy.mongo_db is not None

    # Failure: MongoClient(...) raises while the manager is being built.
    with patch(mongo_client_target) as client_factory:
        client_factory.side_effect = RuntimeError("Connection failed")

        broken = checkpoint.ChatStreamManager(db_uri=MONGO_URL, checkpoint_saver=True)
        # getattr with a default tolerates the attribute never being set.
        assert getattr(broken, "mongo_db", None) is None
|
||||||
|
|||||||
@@ -2,23 +2,21 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
AIMessageChunk,
|
AIMessageChunk,
|
||||||
|
ChatMessageChunk,
|
||||||
|
FunctionMessageChunk,
|
||||||
HumanMessageChunk,
|
HumanMessageChunk,
|
||||||
SystemMessageChunk,
|
SystemMessageChunk,
|
||||||
FunctionMessageChunk,
|
|
||||||
ToolMessageChunk,
|
ToolMessageChunk,
|
||||||
)
|
)
|
||||||
|
|
||||||
from src.llms import llm as llm_module
|
from src.llms import llm as llm_module
|
||||||
from langchain_core.messages import ChatMessageChunk
|
|
||||||
from src.llms.providers import dashscope as dashscope_module
|
from src.llms.providers import dashscope as dashscope_module
|
||||||
|
|
||||||
from src.llms.providers.dashscope import (
|
from src.llms.providers.dashscope import (
|
||||||
ChatDashscope,
|
ChatDashscope,
|
||||||
_convert_delta_to_message_chunk,
|
|
||||||
_convert_chunk_to_generation_chunk,
|
_convert_chunk_to_generation_chunk,
|
||||||
|
_convert_delta_to_message_chunk,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
154
tests/unit/rag/test_dify.py
Normal file
154
tests/unit/rag/test_dify.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.rag.dify import DifyProvider, parse_uri
|
||||||
|
|
||||||
|
|
||||||
|
# Dummy classes to mock dependencies
|
||||||
|
class DummyResource:
    """Lightweight stand-in for ``Resource``: a URI plus optional display text."""

    def __init__(self, uri, title="", description=""):
        # Plain attribute bag; title and description default to empty strings.
        self.uri, self.title, self.description = uri, title, description
|
||||||
|
|
||||||
|
|
||||||
|
class DummyChunk:
    """Lightweight stand-in for ``Chunk``: text content with a similarity score."""

    def __init__(self, content, similarity):
        # Both values are stored untouched so tests can compare them directly.
        self.content, self.similarity = content, similarity
|
||||||
|
|
||||||
|
|
||||||
|
class DummyDocument:
    """Lightweight stand-in for ``Document``: an id, a title and its chunks."""

    def __init__(self, id, title, chunks=None):
        self.id, self.title = id, title
        # A falsy ``chunks`` (None or empty) becomes a fresh list per instance.
        self.chunks = chunks if chunks else []
|
||||||
|
|
||||||
|
|
||||||
|
# Patch imports in dify.py to use dummy classes
|
||||||
|
@pytest.fixture(autouse=True)
def patch_imports(monkeypatch):
    """Swap the model classes used by ``src.rag.dify`` for this module's dummies.

    Runs automatically for every test in this file. The replacements go
    through ``monkeypatch`` so the original ``Resource``, ``Chunk`` and
    ``Document`` attributes are put back once each test finishes.
    """
    import src.rag.dify as dify

    # Plain assignment (``dify.Resource = ...``) would leave the dummies on the
    # shared module for every test that runs afterwards; monkeypatch.setattr
    # records the previous value and restores it at teardown.
    # raising=False keeps the earlier tolerance for a name the module lacks.
    monkeypatch.setattr(dify, "Resource", DummyResource, raising=False)
    monkeypatch.setattr(dify, "Chunk", DummyChunk, raising=False)
    monkeypatch.setattr(dify, "Document", DummyDocument, raising=False)
    yield
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_uri_valid():
    """A ``rag://dataset/<id>#<fragment>`` URI yields the dataset id and the fragment."""
    dataset_id, document_id = parse_uri("rag://dataset/123#abc")

    assert (dataset_id, document_id) == ("123", "abc")
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_uri_invalid():
    """An ``http://`` URI must make ``parse_uri`` raise ValueError."""
    unsupported_uri = "http://dataset/123#abc"

    with pytest.raises(ValueError):
        parse_uri(unsupported_uri)
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_env_vars(monkeypatch):
    """DifyProvider exposes the URL and key taken from the environment."""
    environment = {"DIFY_API_URL": "http://api", "DIFY_API_KEY": "key"}
    for name, value in environment.items():
        monkeypatch.setenv(name, value)

    provider = DifyProvider()

    assert (provider.api_url, provider.api_key) == ("http://api", "key")
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_missing_env(monkeypatch):
    """Constructing DifyProvider raises ValueError when either variable is absent."""
    environment = {"DIFY_API_URL": "http://api", "DIFY_API_KEY": "key"}

    # Drop one variable at a time while the other one stays defined.
    for missing in environment:
        for name, value in environment.items():
            if name == missing:
                monkeypatch.delenv(name, raising=False)
            else:
                monkeypatch.setenv(name, value)

        with pytest.raises(ValueError):
            DifyProvider()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("src.rag.dify.requests.post")
def test_query_relevant_documents_success(mock_post, monkeypatch):
    """A 200 response with one record becomes one document holding one scored chunk."""
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.setenv("DIFY_API_KEY", "key")
    provider = DifyProvider()
    target = DummyResource("rag://dataset/123#doc456")

    # One record: the segment carries the chunk text and its parent document,
    # the score sits next to the segment.
    record = {
        "segment": {
            "content": "chunk text",
            "document": {"id": "doc456", "name": "Doc Title"},
        },
        "score": 0.9,
    }
    mock_post.return_value = MagicMock(
        status_code=200,
        **{"json.return_value": {"records": [record]}},
    )

    found = provider.query_relevant_documents("query", [target])

    assert len(found) == 1
    document = found[0]
    assert document.id == "doc456"
    assert document.title == "Doc Title"
    assert len(document.chunks) == 1
    chunk = document.chunks[0]
    assert chunk.content == "chunk text"
    assert chunk.similarity == 0.9
|
||||||
|
|
||||||
|
|
||||||
|
@patch("src.rag.dify.requests.post")
def test_query_relevant_documents_error(mock_post, monkeypatch):
    """A 400 response from the POST call makes the query raise."""
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.setenv("DIFY_API_KEY", "key")
    provider = DifyProvider()
    target = DummyResource("rag://dataset/123#doc456")

    # Failing HTTP reply: status 400 with a short body.
    mock_post.return_value = MagicMock(status_code=400, text="error")

    with pytest.raises(Exception):
        provider.query_relevant_documents("query", [target])
|
||||||
|
|
||||||
|
|
||||||
|
@patch("src.rag.dify.requests.get")
def test_list_resources_success(mock_get, monkeypatch):
    """Each dataset in a 200 response maps to a resource with a ``rag://dataset/<id>`` URI."""
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.setenv("DIFY_API_KEY", "key")
    provider = DifyProvider()

    datasets = [
        {"id": "123", "name": "Dataset1", "description": "desc1"},
        {"id": "456", "name": "Dataset2", "description": "desc2"},
    ]
    mock_get.return_value = MagicMock(
        status_code=200,
        **{"json.return_value": {"data": datasets}},
    )

    resources = provider.list_resources()

    assert len(resources) == 2
    # Compare every field of both resources in order.
    listed = [(item.uri, item.title, item.description) for item in resources]
    assert listed == [
        ("rag://dataset/123", "Dataset1", "desc1"),
        ("rag://dataset/456", "Dataset2", "desc2"),
    ]
|
||||||
|
|
||||||
|
|
||||||
|
@patch("src.rag.dify.requests.get")
def test_list_resources_error(mock_get, monkeypatch):
    """A 500 response from the GET call makes ``list_resources`` raise."""
    monkeypatch.setenv("DIFY_API_URL", "http://api")
    monkeypatch.setenv("DIFY_API_KEY", "key")
    provider = DifyProvider()

    # Failing HTTP reply: status 500 with a short body.
    mock_get.return_value = MagicMock(status_code=500, text="fail")

    with pytest.raises(Exception):
        provider.list_resources()
|
||||||
@@ -2,9 +2,11 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from uuid import uuid4
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import src.rag.milvus as milvus_mod
|
import src.rag.milvus as milvus_mod
|
||||||
@@ -13,7 +15,6 @@ from src.rag.retriever import Resource
|
|||||||
|
|
||||||
|
|
||||||
class DummyEmbedding:
|
class DummyEmbedding:
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
||||||
@@ -369,9 +370,7 @@ def test_create_collection_lite(monkeypatch):
|
|||||||
def list_collections(self): # noqa: D401
|
def list_collections(self): # noqa: D401
|
||||||
return [] # empty triggers creation
|
return [] # empty triggers creation
|
||||||
|
|
||||||
def create_collection(
|
def create_collection(self, collection_name, schema, index_params): # noqa: D401
|
||||||
self, collection_name, schema, index_params
|
|
||||||
): # noqa: D401
|
|
||||||
created["name"] = collection_name
|
created["name"] = collection_name
|
||||||
created["schema"] = schema
|
created["schema"] = schema
|
||||||
created["index"] = index_params
|
created["index"] = index_params
|
||||||
|
|||||||
Reference in New Issue
Block a user