feat: integrate VikingDB Knowledge Base into rag retrieving tool (#381)

Co-authored-by: Henry Li <henry1943@163.com>
This commit is contained in:
殷逸维
2025-07-03 10:06:42 +08:00
committed by GitHub
parent 5977b4a03e
commit be893eae2b
8 changed files with 814 additions and 3 deletions

View File

@@ -21,6 +21,7 @@ SELECTED_SEARCH_ENGINE = os.getenv("SEARCH_API", SearchEngine.TAVILY.value)
class RAGProvider(enum.Enum):
RAGFLOW = "ragflow"
VIKINGDB_KNOWLEDGE_BASE = "vikingdb_knowledge_base"
SELECTED_RAG_PROVIDER = os.getenv("RAG_PROVIDER")

View File

@@ -3,6 +3,15 @@
from .retriever import Retriever, Document, Resource, Chunk
from .ragflow import RAGFlowProvider
from .vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider
from .builder import build_retriever
__all__ = [Retriever, Document, Resource, RAGFlowProvider, Chunk, build_retriever]
__all__ = [
Retriever,
Document,
Resource,
RAGFlowProvider,
VikingDBKnowledgeBaseProvider,
Chunk,
build_retriever,
]

View File

@@ -3,12 +3,15 @@
from src.config.tools import SELECTED_RAG_PROVIDER, RAGProvider
from src.rag.ragflow import RAGFlowProvider
from src.rag.vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider
from src.rag.retriever import Retriever
def build_retriever() -> Retriever | None:
if SELECTED_RAG_PROVIDER == RAGProvider.RAGFLOW.value:
return RAGFlowProvider()
elif SELECTED_RAG_PROVIDER == RAGProvider.VIKINGDB_KNOWLEDGE_BASE.value:
return VikingDBKnowledgeBaseProvider()
elif SELECTED_RAG_PROVIDER:
raise ValueError(f"Unsupported RAG provider: {SELECTED_RAG_PROVIDER}")
return None

View File

@@ -0,0 +1,208 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import os
import requests
import json
from src.rag.retriever import Chunk, Document, Resource, Retriever
from urllib.parse import urlparse
from volcengine.auth.SignerV4 import SignerV4
from volcengine.base.Request import Request
from volcengine.Credentials import Credentials
class VikingDBKnowledgeBaseProvider(Retriever):
"""
VikingDBKnowledgeBaseProvider is a provider that uses VikingDB Knowledge base API to retrieve documents.
"""
api_url: str
api_ak: str
api_sk: str
retrieval_size: int = 10
def __init__(self):
api_url = os.getenv("VIKINGDB_KNOWLEDGE_BASE_API_URL")
if not api_url:
raise ValueError("VIKINGDB_KNOWLEDGE_BASE_API_URL is not set")
self.api_url = api_url
api_ak = os.getenv("VIKINGDB_KNOWLEDGE_BASE_API_AK")
if not api_ak:
raise ValueError("VIKINGDB_KNOWLEDGE_BASE_API_AK is not set")
self.api_ak = api_ak
api_sk = os.getenv("VIKINGDB_KNOWLEDGE_BASE_API_SK")
if not api_sk:
raise ValueError("VIKINGDB_KNOWLEDGE_BASE_API_SK is not set")
self.api_sk = api_sk
retrieval_size = os.getenv("VIKINGDB_KNOWLEDGE_BASE_RETRIEVAL_SIZE")
if retrieval_size:
self.retrieval_size = int(retrieval_size)
def prepare_request(self, method, path, params=None, data=None, doseq=0):
"""
Prepare signed request using volcengine auth
"""
if params:
for key in params:
if (
type(params[key]) == int
or type(params[key]) == float
or type(params[key]) == bool
):
params[key] = str(params[key])
elif type(params[key]) == list:
if not doseq:
params[key] = ",".join(params[key])
r = Request()
r.set_shema("https")
r.set_method(method)
r.set_connection_timeout(10)
r.set_socket_timeout(10)
mheaders = {
"Accept": "application/json",
"Content-Type": "application/json",
}
r.set_headers(mheaders)
if params:
r.set_query(params)
r.set_path(path)
if data is not None:
r.set_body(json.dumps(data))
credentials = Credentials(self.api_ak, self.api_sk, "air", "cn-north-1")
SignerV4.sign(r, credentials)
return r
def query_relevant_documents(
self, query: str, resources: list[Resource] = []
) -> list[Document]:
"""
Query relevant documents from the knowledge base
"""
if not resources:
return []
all_documents = {}
for resource in resources:
resource_id, document_id = parse_uri(resource.uri)
request_params = {
"resource_id": resource_id,
"query": query,
"limit": self.retrieval_size,
"dense_weight": 0.5,
"pre_processing": {
"need_instruction": True,
"rewrite": False,
"return_token_usage": True,
},
"post_processing": {
"rerank_switch": True,
"chunk_diffusion_count": 0,
"chunk_group": True,
"get_attachment_link": True,
},
}
if document_id:
doc_filter = {"op": "must", "field": "doc_id", "conds": [document_id]}
query_param = {"doc_filter": doc_filter}
request_params["query_param"] = query_param
method = "POST"
path = "/api/knowledge/collection/search_knowledge"
info_req = self.prepare_request(
method=method, path=path, data=request_params
)
rsp = requests.request(
method=info_req.method,
url="http://{}{}".format(self.api_url, info_req.path),
headers=info_req.headers,
data=info_req.body,
)
try:
response = json.loads(rsp.text)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse JSON response: {e}")
if response["code"] != 0:
raise ValueError(
f"Failed to query documents from resource: {response['message']}"
)
rsp_data = response.get("data", {})
if "result_list" not in rsp_data:
continue
result_list = rsp_data["result_list"]
for item in result_list:
doc_info = item.get("doc_info", {})
doc_id = doc_info.get("doc_id")
if not doc_id:
continue
if doc_id not in all_documents:
all_documents[doc_id] = Document(
id=doc_id, title=doc_info.get("doc_name"), chunks=[]
)
chunk = Chunk(
content=item.get("content", ""), similarity=item.get("score", 0.0)
)
all_documents[doc_id].chunks.append(chunk)
return list(all_documents.values())
def list_resources(self, query: str | None = None) -> list[Resource]:
"""
List resources (knowledge bases) from the knowledge base service
"""
method = "POST"
path = "/api/knowledge/collection/list"
info_req = self.prepare_request(method=method, path=path)
rsp = requests.request(
method=info_req.method,
url="http://{}{}".format(self.api_url, info_req.path),
headers=info_req.headers,
data=info_req.body,
)
try:
response = json.loads(rsp.text)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse JSON response: {e}")
if response["code"] != 0:
raise Exception(f"Failed to list resources: {response["message"]}")
resources = []
rsp_data = response.get("data", {})
collection_list = rsp_data.get("collection_list", [])
for item in collection_list:
collection_name = item.get("collection_name", "")
description = item.get("description", "")
if query and query.lower() not in collection_name.lower():
continue
resource_id = item.get("resource_id", "")
resource = Resource(
uri=f"rag://dataset/{resource_id}",
title=collection_name,
description=description,
)
resources.append(resource)
return resources
def parse_uri(uri: str) -> tuple[str, str]:
parsed = urlparse(uri)
if parsed.scheme != "rag":
raise ValueError(f"Invalid URI: {uri}")
return parsed.path.split("/")[1], parsed.fragment