mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-12 18:14:46 +08:00
feat: support dify in rag module (#550)
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -8,9 +8,9 @@ from typing import Any, Optional
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from src.config.loader import get_bool_env, get_int_env, get_str_env
|
||||
from src.config.report_style import ReportStyle
|
||||
from src.rag.retriever import Resource
|
||||
from src.config.loader import get_str_env, get_int_env, get_bool_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ SELECTED_SEARCH_ENGINE = os.getenv("SEARCH_API", SearchEngine.TAVILY.value)
|
||||
|
||||
|
||||
class RAGProvider(enum.Enum):
|
||||
DIFY = "dify"
|
||||
RAGFLOW = "ragflow"
|
||||
VIKINGDB_KNOWLEDGE_BASE = "vikingdb_knowledge_base"
|
||||
MOI = "moi"
|
||||
|
||||
@@ -6,10 +6,12 @@ import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import psycopg
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
from psycopg.rows import dict_row
|
||||
from pymongo import MongoClient
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
from src.config.loader import get_bool_env, get_str_env
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ from langchain_core.language_models import BaseChatModel
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_openai import AzureChatOpenAI, ChatOpenAI
|
||||
from typing import get_args
|
||||
|
||||
from src.config import load_yaml_config
|
||||
from src.config.agents import LLMType
|
||||
|
||||
@@ -211,7 +211,6 @@ class ChatDashscope(ChatOpenAI):
|
||||
and hasattr(response.choices[0], "message")
|
||||
and hasattr(response.choices[0].message, "reasoning_content")
|
||||
):
|
||||
|
||||
reasoning_content = response.choices[0].message.reasoning_content
|
||||
if reasoning_content and chat_result.generations:
|
||||
chat_result.generations[0].message.additional_kwargs[
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from .builder import build_retriever
|
||||
from .dify import DifyProvider
|
||||
from .ragflow import RAGFlowProvider
|
||||
from .moi import MOIProvider
|
||||
from .retriever import Chunk, Document, Resource, Retriever
|
||||
@@ -11,6 +12,7 @@ __all__ = [
|
||||
Retriever,
|
||||
Document,
|
||||
Resource,
|
||||
DifyProvider,
|
||||
RAGFlowProvider,
|
||||
MOIProvider,
|
||||
VikingDBKnowledgeBaseProvider,
|
||||
|
||||
@@ -2,14 +2,17 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from src.config.tools import SELECTED_RAG_PROVIDER, RAGProvider
|
||||
from src.rag.dify import DifyProvider
|
||||
from src.rag.milvus import MilvusProvider
|
||||
from src.rag.ragflow import RAGFlowProvider
|
||||
from src.rag.moi import MOIProvider
|
||||
from src.rag.retriever import Retriever
|
||||
from src.rag.vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider
|
||||
from src.rag.milvus import MilvusProvider
|
||||
|
||||
|
||||
def build_retriever() -> Retriever | None:
|
||||
if SELECTED_RAG_PROVIDER == RAGProvider.DIFY.value:
|
||||
return DifyProvider()
|
||||
if SELECTED_RAG_PROVIDER == RAGProvider.RAGFLOW.value:
|
||||
return RAGFlowProvider()
|
||||
elif SELECTED_RAG_PROVIDER == RAGProvider.MOI.value:
|
||||
|
||||
132
src/rag/dify.py
Normal file
132
src/rag/dify.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
|
||||
from src.rag.retriever import Chunk, Document, Resource, Retriever
|
||||
|
||||
|
||||
class DifyProvider(Retriever):
    """Retriever that queries Dify knowledge-base datasets over HTTP.

    The endpoint and credentials are read from the ``DIFY_API_URL`` and
    ``DIFY_API_KEY`` environment variables when the provider is created.
    """

    api_url: str  # base URL of the Dify API
    api_key: str  # bearer token sent with every request

    def __init__(self):
        """Load configuration from the environment.

        Raises:
            ValueError: If DIFY_API_URL or DIFY_API_KEY is missing or empty.
        """
        api_url = os.getenv("DIFY_API_URL")
        if not api_url:
            raise ValueError("DIFY_API_URL is not set")
        self.api_url = api_url

        api_key = os.getenv("DIFY_API_KEY")
        if not api_key:
            raise ValueError("DIFY_API_KEY is not set")
        self.api_key = api_key

    def query_relevant_documents(
        self, query: str, resources: list[Resource] | None = None
    ) -> list[Document]:
        """Search each dataset in *resources* and merge matching segments.

        Args:
            query: Free-text query sent to Dify's hybrid search.
            resources: Datasets to search, identified by
                ``rag://dataset/{id}`` URIs. Empty or ``None`` yields an
                empty result without any HTTP call.

        Returns:
            One Document per distinct Dify document id, carrying the
            retrieved chunks with their similarity scores.

        Raises:
            Exception: If any dataset query returns a non-200 status.
        """
        # Default changed from a shared mutable list ([]) to None to avoid
        # the mutable-default-argument pitfall; behavior is unchanged.
        if not resources:
            return []

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        # Accumulate chunks per document id so the same document hit via
        # several datasets/segments is reported once.
        all_documents: dict[str, Document] = {}
        for resource in resources:
            dataset_id, _ = parse_uri(resource.uri)
            payload = {
                "query": query,
                "retrieval_model": {
                    "search_method": "hybrid_search",
                    "reranking_enable": False,
                    "weights": {
                        "weight_type": "customized",
                        "keyword_setting": {"keyword_weight": 0.3},
                        "vector_setting": {"vector_weight": 0.7},
                    },
                    "top_k": 3,
                    "score_threshold_enabled": True,
                    "score_threshold": 0.5,
                },
            }

            response = requests.post(
                f"{self.api_url}/datasets/{dataset_id}/retrieve",
                headers=headers,
                json=payload,
            )

            if response.status_code != 200:
                raise Exception(f"Failed to query documents: {response.text}")

            result = response.json()
            # "records" is a list in the Dify response; default to an empty
            # list (the original's {} default only worked because iterating
            # an empty mapping yields nothing).
            records = result.get("records", [])
            for record in records:
                segment = record.get("segment")
                if not segment:
                    continue
                document_info = segment.get("document")
                if not document_info:
                    continue
                doc_id = document_info.get("id")
                doc_name = document_info.get("name")
                if not doc_id or not doc_name:
                    continue

                if doc_id not in all_documents:
                    all_documents[doc_id] = Document(
                        id=doc_id, title=doc_name, chunks=[]
                    )

                chunk = Chunk(
                    content=segment.get("content", ""),
                    similarity=record.get("score", 0.0),
                )
                all_documents[doc_id].chunks.append(chunk)

        return list(all_documents.values())

    def list_resources(self, query: str | None = None) -> list[Resource]:
        """List the available Dify datasets as Resource entries.

        Args:
            query: Optional keyword filter forwarded to the API.

        Returns:
            One Resource per dataset, with a ``rag://dataset/{id}`` URI.

        Raises:
            Exception: If the API returns a non-200 status.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        params = {}
        if query:
            params["keyword"] = query

        response = requests.get(
            f"{self.api_url}/datasets", headers=headers, params=params
        )

        if response.status_code != 200:
            raise Exception(f"Failed to list resources: {response.text}")

        result = response.json()
        resources = []

        # Build a distinctly named Resource instead of rebinding the loop
        # variable `item` (the original shadowed it mid-loop).
        for item in result.get("data", []):
            resources.append(
                Resource(
                    uri=f"rag://dataset/{item.get('id')}",
                    title=item.get("name", ""),
                    description=item.get("description", ""),
                )
            )

        return resources
|
||||
|
||||
|
||||
def parse_uri(uri: str) -> tuple[str, str]:
    """Split a ``rag://dataset/{id}#{fragment}`` URI into (dataset_id, fragment).

    Args:
        uri: Resource URI produced by ``DifyProvider.list_resources``.

    Returns:
        A ``(dataset_id, fragment)`` tuple; the fragment is ``""`` when absent.

    Raises:
        ValueError: If the scheme is not ``rag`` or no dataset id segment is
            present (the original raised an opaque IndexError for the latter).
    """
    parsed = urlparse(uri)
    if parsed.scheme != "rag":
        raise ValueError(f"Invalid URI: {uri}")
    # For rag://dataset/{id}, urlparse puts "dataset" in netloc and
    # "/{id}" in path, so the id is the second path segment.
    segments = parsed.path.split("/")
    if len(segments) < 2 or not segments[1]:
        raise ValueError(f"Invalid URI: {uri}")
    return segments[1], parsed.fragment
|
||||
@@ -7,11 +7,12 @@ from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set
|
||||
|
||||
from langchain_milvus.vectorstores import Milvus as LangchainMilvus
|
||||
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from openai import OpenAI
|
||||
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient
|
||||
|
||||
from src.config.loader import get_bool_env, get_int_env, get_str_env
|
||||
from src.rag.retriever import Chunk, Document, Resource, Retriever
|
||||
from src.config.loader import get_bool_env, get_str_env, get_int_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -466,7 +467,7 @@ class MilvusRetriever(Retriever):
|
||||
resources.append(
|
||||
Resource(
|
||||
uri=r.get(self.url_field, "")
|
||||
or f"milvus://{r.get(self.id_field,'')}",
|
||||
or f"milvus://{r.get(self.id_field, '')}",
|
||||
title=r.get(self.title_field, "")
|
||||
or r.get(self.id_field, "Unnamed"),
|
||||
description="Stored Milvus document",
|
||||
@@ -476,21 +477,23 @@ class MilvusRetriever(Retriever):
|
||||
# Use similarity_search_by_vector for lightweight listing.
|
||||
# If a query is provided embed it; else use a zero vector.
|
||||
docs: Iterable[Any] = self.client.similarity_search(
|
||||
query, k=100, expr="source == 'examples'" # Limit to 100 results
|
||||
query,
|
||||
k=100,
|
||||
expr="source == 'examples'", # Limit to 100 results
|
||||
)
|
||||
for d in docs:
|
||||
meta = getattr(d, "metadata", {}) or {}
|
||||
# check if the resource is in the list of resources
|
||||
if resources and any(
|
||||
r.uri == meta.get(self.url_field, "")
|
||||
or r.uri == f"milvus://{meta.get(self.id_field,'')}"
|
||||
or r.uri == f"milvus://{meta.get(self.id_field, '')}"
|
||||
for r in resources
|
||||
):
|
||||
continue
|
||||
resources.append(
|
||||
Resource(
|
||||
uri=meta.get(self.url_field, "")
|
||||
or f"milvus://{meta.get(self.id_field,'')}",
|
||||
or f"milvus://{meta.get(self.id_field, '')}",
|
||||
title=meta.get(self.title_field, "")
|
||||
or meta.get(self.id_field, "Unnamed"),
|
||||
description="Stored Milvus document",
|
||||
|
||||
@@ -11,10 +11,10 @@ from fastapi import FastAPI, HTTPException, Query
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage
|
||||
from langgraph.types import Command
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
from langgraph.checkpoint.mongodb import AsyncMongoDBSaver
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
from langgraph.types import Command
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
|
||||
from src.config.configuration import get_recursion_limit
|
||||
@@ -22,6 +22,7 @@ from src.config.loader import get_bool_env, get_str_env
|
||||
from src.config.report_style import ReportStyle
|
||||
from src.config.tools import SELECTED_RAG_PROVIDER
|
||||
from src.graph.builder import build_graph_with_memory
|
||||
from src.graph.checkpoint import chat_stream_message
|
||||
from src.llms.llm import get_configured_llm_models
|
||||
from src.podcast.graph.builder import build_graph as build_podcast_graph
|
||||
from src.ppt.graph.builder import build_graph as build_ppt_graph
|
||||
@@ -47,7 +48,6 @@ from src.server.rag_request import (
|
||||
RAGResourcesResponse,
|
||||
)
|
||||
from src.tools import VolcengineTTS
|
||||
from src.graph.checkpoint import chat_stream_message
|
||||
from src.utils.json_utils import sanitize_args
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -97,7 +97,8 @@ async def load_mcp_tools(
|
||||
)
|
||||
|
||||
return await _get_tools_from_client_session(
|
||||
sse_client(url=url, headers=headers, timeout=timeout_seconds), timeout_seconds
|
||||
sse_client(url=url, headers=headers, timeout=timeout_seconds),
|
||||
timeout_seconds,
|
||||
)
|
||||
|
||||
elif server_type == "streamable_http":
|
||||
@@ -107,7 +108,10 @@ async def load_mcp_tools(
|
||||
)
|
||||
|
||||
return await _get_tools_from_client_session(
|
||||
streamablehttp_client(url=url, headers=headers, timeout=timeout_seconds), timeout_seconds,
|
||||
streamablehttp_client(
|
||||
url=url, headers=headers, timeout=timeout_seconds
|
||||
),
|
||||
timeout_seconds,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import json_repair
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Reference in New Issue
Block a user