feat: support dify in rag module (#550)

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
Chayton Bai
2025-09-16 20:30:45 +08:00
committed by GitHub
parent 5085bf8ee9
commit 7694bb5d72
19 changed files with 407 additions and 87 deletions

View File

@@ -8,9 +8,9 @@ from typing import Any, Optional
from langchain_core.runnables import RunnableConfig
from src.config.loader import get_bool_env, get_int_env, get_str_env
from src.config.report_style import ReportStyle
from src.rag.retriever import Resource
from src.config.loader import get_str_env, get_int_env, get_bool_env
logger = logging.getLogger(__name__)

View File

@@ -22,6 +22,7 @@ SELECTED_SEARCH_ENGINE = os.getenv("SEARCH_API", SearchEngine.TAVILY.value)
class RAGProvider(enum.Enum):
DIFY = "dify"
RAGFLOW = "ragflow"
VIKINGDB_KNOWLEDGE_BASE = "vikingdb_knowledge_base"
MOI = "moi"

View File

@@ -6,10 +6,12 @@ import logging
import uuid
from datetime import datetime
from typing import List, Optional, Tuple
import psycopg
from langgraph.store.memory import InMemoryStore
from psycopg.rows import dict_row
from pymongo import MongoClient
from langgraph.store.memory import InMemoryStore
from src.config.loader import get_bool_env, get_str_env

View File

@@ -10,7 +10,6 @@ from langchain_core.language_models import BaseChatModel
from langchain_deepseek import ChatDeepSeek
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from typing import get_args
from src.config import load_yaml_config
from src.config.agents import LLMType

View File

@@ -211,7 +211,6 @@ class ChatDashscope(ChatOpenAI):
and hasattr(response.choices[0], "message")
and hasattr(response.choices[0].message, "reasoning_content")
):
reasoning_content = response.choices[0].message.reasoning_content
if reasoning_content and chat_result.generations:
chat_result.generations[0].message.additional_kwargs[

View File

@@ -2,6 +2,7 @@
# SPDX-License-Identifier: MIT
from .builder import build_retriever
from .dify import DifyProvider
from .ragflow import RAGFlowProvider
from .moi import MOIProvider
from .retriever import Chunk, Document, Resource, Retriever
@@ -11,6 +12,7 @@ __all__ = [
Retriever,
Document,
Resource,
DifyProvider,
RAGFlowProvider,
MOIProvider,
VikingDBKnowledgeBaseProvider,

View File

@@ -2,14 +2,17 @@
# SPDX-License-Identifier: MIT
from src.config.tools import SELECTED_RAG_PROVIDER, RAGProvider
from src.rag.dify import DifyProvider
from src.rag.milvus import MilvusProvider
from src.rag.ragflow import RAGFlowProvider
from src.rag.moi import MOIProvider
from src.rag.retriever import Retriever
from src.rag.vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider
from src.rag.milvus import MilvusProvider
def build_retriever() -> Retriever | None:
if SELECTED_RAG_PROVIDER == RAGProvider.DIFY.value:
return DifyProvider()
if SELECTED_RAG_PROVIDER == RAGProvider.RAGFLOW.value:
return RAGFlowProvider()
elif SELECTED_RAG_PROVIDER == RAGProvider.MOI.value:

132
src/rag/dify.py Normal file
View File

@@ -0,0 +1,132 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import os
from urllib.parse import urlparse
import requests
from src.rag.retriever import Chunk, Document, Resource, Retriever
class DifyProvider(Retriever):
    """
    Retriever backed by the Dify knowledge-base REST API.

    Reads its configuration from environment variables:
      - ``DIFY_API_URL``: base URL of the Dify API (e.g. ``https://api.dify.ai/v1``)
      - ``DIFY_API_KEY``: dataset API key, sent as a Bearer token
    """

    api_url: str
    api_key: str

    # Seconds before a Dify API call is aborted; without a timeout a
    # requests call can hang forever on an unresponsive server.
    _REQUEST_TIMEOUT: int = 30

    def __init__(self):
        """Load API URL/key from the environment.

        Raises:
            ValueError: if DIFY_API_URL or DIFY_API_KEY is unset or empty.
        """
        api_url = os.getenv("DIFY_API_URL")
        if not api_url:
            raise ValueError("DIFY_API_URL is not set")
        self.api_url = api_url

        api_key = os.getenv("DIFY_API_KEY")
        if not api_key:
            raise ValueError("DIFY_API_KEY is not set")
        self.api_key = api_key

    def query_relevant_documents(
        self, query: str, resources: list[Resource] | None = None
    ) -> list[Document]:
        """Run a hybrid-search retrieval against each resource's dataset.

        Args:
            query: the search query text.
            resources: datasets to search, identified by ``rag://dataset/<id>``
                URIs. When empty or None, no API call is made.

        Returns:
            Documents merged across all datasets, each carrying the chunks
            (content + similarity score) that matched the query.

        Raises:
            Exception: if any Dify API call returns a non-200 status.
        """
        # Note: the original signature used a mutable default (`= []`);
        # None is equivalent here because of the truthiness check below.
        if not resources:
            return []

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        # Keyed by document id so chunks from multiple records/datasets
        # that belong to the same document are merged.
        all_documents: dict[str, Document] = {}

        for resource in resources:
            dataset_id, _ = parse_uri(resource.uri)

            payload = {
                "query": query,
                "retrieval_model": {
                    "search_method": "hybrid_search",
                    "reranking_enable": False,
                    "weights": {
                        "weight_type": "customized",
                        "keyword_setting": {"keyword_weight": 0.3},
                        "vector_setting": {"vector_weight": 0.7},
                    },
                    "top_k": 3,
                    "score_threshold_enabled": True,
                    "score_threshold": 0.5,
                },
            }

            response = requests.post(
                f"{self.api_url}/datasets/{dataset_id}/retrieve",
                headers=headers,
                json=payload,
                timeout=self._REQUEST_TIMEOUT,
            )

            if response.status_code != 200:
                raise Exception(f"Failed to query documents: {response.text}")

            result = response.json()
            # "records" is a list in the Dify response; default to [] (the
            # original defaulted to {}, which iterates as empty but is the
            # wrong type).
            records = result.get("records", [])

            for record in records:
                segment = record.get("segment")
                if not segment:
                    continue

                document_info = segment.get("document")
                if not document_info:
                    continue

                doc_id = document_info.get("id")
                doc_name = document_info.get("name")
                # Skip records missing identifying metadata.
                if not doc_id or not doc_name:
                    continue

                if doc_id not in all_documents:
                    all_documents[doc_id] = Document(
                        id=doc_id, title=doc_name, chunks=[]
                    )

                chunk = Chunk(
                    content=segment.get("content", ""),
                    similarity=record.get("score", 0.0),
                )
                all_documents[doc_id].chunks.append(chunk)

        return list(all_documents.values())

    def list_resources(self, query: str | None = None) -> list[Resource]:
        """List the datasets available to this API key.

        Args:
            query: optional keyword filter passed to the Dify API.

        Returns:
            One Resource per dataset, with a ``rag://dataset/<id>`` URI.

        Raises:
            Exception: if the Dify API call returns a non-200 status.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        params = {}
        if query:
            params["keyword"] = query

        response = requests.get(
            f"{self.api_url}/datasets",
            headers=headers,
            params=params,
            timeout=self._REQUEST_TIMEOUT,
        )

        if response.status_code != 200:
            raise Exception(f"Failed to list resources: {response.text}")

        result = response.json()
        resources = []
        # Use a distinct name for the built Resource: the original rebound
        # the loop variable (`item = Resource(...)`), shadowing the raw
        # API item it had just read from.
        for dataset in result.get("data", []):
            resources.append(
                Resource(
                    uri=f"rag://dataset/{dataset.get('id')}",
                    title=dataset.get("name", ""),
                    description=dataset.get("description", ""),
                )
            )
        return resources
def parse_uri(uri: str) -> tuple[str, str]:
    """Parse a ``rag://dataset/<dataset_id>[#fragment]`` resource URI.

    Args:
        uri: the resource URI to parse.

    Returns:
        A ``(dataset_id, fragment)`` tuple; the fragment is "" when absent.

    Raises:
        ValueError: if the scheme is not ``rag`` or the dataset id is
            missing (the original raised an opaque IndexError for e.g.
            ``rag://dataset`` with no id segment).
    """
    parsed = urlparse(uri)
    if parsed.scheme != "rag":
        raise ValueError(f"Invalid URI: {uri}")
    # For rag://dataset/<id>, urlparse puts "dataset" in netloc and
    # "/<id>" in path, so the id is the second path component.
    parts = parsed.path.split("/")
    if len(parts) < 2 or not parts[1]:
        raise ValueError(f"Invalid URI: {uri}")
    return parts[1], parsed.fragment

View File

@@ -7,11 +7,12 @@ from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set
from langchain_milvus.vectorstores import Milvus as LangchainMilvus
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
from langchain_openai import OpenAIEmbeddings
from openai import OpenAI
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
from src.config.loader import get_bool_env, get_int_env, get_str_env
from src.rag.retriever import Chunk, Document, Resource, Retriever
from src.config.loader import get_bool_env, get_str_env, get_int_env
logger = logging.getLogger(__name__)
@@ -466,7 +467,7 @@ class MilvusRetriever(Retriever):
resources.append(
Resource(
uri=r.get(self.url_field, "")
or f"milvus://{r.get(self.id_field,'')}",
or f"milvus://{r.get(self.id_field, '')}",
title=r.get(self.title_field, "")
or r.get(self.id_field, "Unnamed"),
description="Stored Milvus document",
@@ -476,21 +477,23 @@ class MilvusRetriever(Retriever):
# Use similarity_search_by_vector for lightweight listing.
# If a query is provided embed it; else use a zero vector.
docs: Iterable[Any] = self.client.similarity_search(
query, k=100, expr="source == 'examples'" # Limit to 100 results
query,
k=100,
expr="source == 'examples'", # Limit to 100 results
)
for d in docs:
meta = getattr(d, "metadata", {}) or {}
# check if the resource is in the list of resources
if resources and any(
r.uri == meta.get(self.url_field, "")
or r.uri == f"milvus://{meta.get(self.id_field,'')}"
or r.uri == f"milvus://{meta.get(self.id_field, '')}"
for r in resources
):
continue
resources.append(
Resource(
uri=meta.get(self.url_field, "")
or f"milvus://{meta.get(self.id_field,'')}",
or f"milvus://{meta.get(self.id_field, '')}",
title=meta.get(self.title_field, "")
or meta.get(self.id_field, "Unnamed"),
description="Stored Milvus document",

View File

@@ -11,10 +11,10 @@ from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, StreamingResponse
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage
from langgraph.types import Command
from langgraph.store.memory import InMemoryStore
from langgraph.checkpoint.mongodb import AsyncMongoDBSaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.store.memory import InMemoryStore
from langgraph.types import Command
from psycopg_pool import AsyncConnectionPool
from src.config.configuration import get_recursion_limit
@@ -22,6 +22,7 @@ from src.config.loader import get_bool_env, get_str_env
from src.config.report_style import ReportStyle
from src.config.tools import SELECTED_RAG_PROVIDER
from src.graph.builder import build_graph_with_memory
from src.graph.checkpoint import chat_stream_message
from src.llms.llm import get_configured_llm_models
from src.podcast.graph.builder import build_graph as build_podcast_graph
from src.ppt.graph.builder import build_graph as build_ppt_graph
@@ -47,7 +48,6 @@ from src.server.rag_request import (
RAGResourcesResponse,
)
from src.tools import VolcengineTTS
from src.graph.checkpoint import chat_stream_message
from src.utils.json_utils import sanitize_args
logger = logging.getLogger(__name__)

View File

@@ -97,7 +97,8 @@ async def load_mcp_tools(
)
return await _get_tools_from_client_session(
sse_client(url=url, headers=headers, timeout=timeout_seconds), timeout_seconds
sse_client(url=url, headers=headers, timeout=timeout_seconds),
timeout_seconds,
)
elif server_type == "streamable_http":
@@ -107,7 +108,10 @@ async def load_mcp_tools(
)
return await _get_tools_from_client_session(
streamablehttp_client(url=url, headers=headers, timeout=timeout_seconds), timeout_seconds,
streamablehttp_client(
url=url, headers=headers, timeout=timeout_seconds
),
timeout_seconds,
)
else:

View File

@@ -4,6 +4,7 @@
import json
import logging
from typing import Any
import json_repair
logger = logging.getLogger(__name__)