mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-17 03:34:45 +08:00
remove volcengine package (#464)
This commit is contained in:
@@ -4,11 +4,12 @@
|
||||
import os
|
||||
import requests
|
||||
import json
|
||||
import hashlib
|
||||
import hmac
|
||||
import urllib.parse
|
||||
from datetime import datetime
|
||||
from src.rag.retriever import Chunk, Document, Resource, Retriever
|
||||
from urllib.parse import urlparse
|
||||
from volcengine.auth.SignerV4 import SignerV4
|
||||
from volcengine.base.Request import Request
|
||||
from volcengine.Credentials import Credentials
|
||||
|
||||
|
||||
class VikingDBKnowledgeBaseProvider(Retriever):
|
||||
@@ -20,6 +21,8 @@ class VikingDBKnowledgeBaseProvider(Retriever):
|
||||
api_ak: str
|
||||
api_sk: str
|
||||
retrieval_size: int = 10
|
||||
region: str = "cn-north-1"
|
||||
service: str = "air"
|
||||
|
||||
def __init__(self):
|
||||
api_url = os.getenv("VIKINGDB_KNOWLEDGE_BASE_API_URL")
|
||||
@@ -41,41 +44,137 @@ class VikingDBKnowledgeBaseProvider(Retriever):
|
||||
if retrieval_size:
|
||||
self.retrieval_size = int(retrieval_size)
|
||||
|
||||
def prepare_request(self, method, path, params=None, data=None, doseq=0):
|
||||
"""
|
||||
Prepare signed request using volcengine auth
|
||||
"""
|
||||
if params:
|
||||
for key in params:
|
||||
if (
|
||||
type(params[key]) is int
|
||||
or type(params[key]) is float
|
||||
or type(params[key]) is bool
|
||||
):
|
||||
params[key] = str(params[key])
|
||||
elif type(params[key]) is list:
|
||||
if not doseq:
|
||||
params[key] = ",".join(params[key])
|
||||
# 设置region,如果需要可以从环境变量获取
|
||||
region = os.getenv("VIKINGDB_KNOWLEDGE_BASE_REGION", "cn-north-1")
|
||||
self.region = region
|
||||
|
||||
r = Request()
|
||||
r.set_shema("https")
|
||||
r.set_method(method)
|
||||
r.set_connection_timeout(10)
|
||||
r.set_socket_timeout(10)
|
||||
mheaders = {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
r.set_headers(mheaders)
|
||||
if params:
|
||||
r.set_query(params)
|
||||
r.set_path(path)
|
||||
if data is not None:
|
||||
r.set_body(json.dumps(data))
|
||||
def _hmac_sha256(self, key: bytes, content: str) -> bytes:
|
||||
return hmac.new(key, content.encode("utf-8"), hashlib.sha256).digest()
|
||||
|
||||
credentials = Credentials(self.api_ak, self.api_sk, "air", "cn-north-1")
|
||||
SignerV4.sign(r, credentials)
|
||||
return r
|
||||
def _hash_sha256(self, data: bytes) -> bytes:
|
||||
return hashlib.sha256(data).digest()
|
||||
|
||||
def _get_signed_key(
|
||||
self, secret_key: str, date: str, region: str, service: str
|
||||
) -> bytes:
|
||||
k_date = self._hmac_sha256(secret_key.encode("utf-8"), date)
|
||||
k_region = self._hmac_sha256(k_date, region)
|
||||
k_service = self._hmac_sha256(k_region, service)
|
||||
k_signing = self._hmac_sha256(k_service, "request")
|
||||
return k_signing
|
||||
|
||||
def _create_canonical_request(
|
||||
self, method: str, path: str, query_params: dict, headers: dict, payload: bytes
|
||||
) -> str:
|
||||
canonical_method = method.upper()
|
||||
canonical_uri = path if path else "/"
|
||||
if query_params:
|
||||
encoded_params = []
|
||||
for key in sorted(query_params.keys()):
|
||||
value = query_params[key]
|
||||
encoded_key = urllib.parse.quote(str(key), safe="")
|
||||
encoded_value = urllib.parse.quote(str(value), safe="")
|
||||
encoded_params.append(f"{encoded_key}={encoded_value}")
|
||||
canonical_query_string = "&".join(encoded_params)
|
||||
else:
|
||||
canonical_query_string = ""
|
||||
|
||||
canonical_headers_list = []
|
||||
signed_headers_list = []
|
||||
for header_name in sorted(headers.keys(), key=str.lower):
|
||||
header_name_lower = header_name.lower()
|
||||
header_value = str(headers[header_name]).strip()
|
||||
canonical_headers_list.append(f"{header_name_lower}:{header_value}")
|
||||
signed_headers_list.append(header_name_lower)
|
||||
|
||||
canonical_headers = "\n".join(canonical_headers_list) + "\n"
|
||||
signed_headers = ";".join(signed_headers_list)
|
||||
|
||||
payload_hash = self._hash_sha256(payload).hex()
|
||||
|
||||
canonical_request = "\n".join(
|
||||
[
|
||||
canonical_method,
|
||||
canonical_uri,
|
||||
canonical_query_string,
|
||||
canonical_headers,
|
||||
signed_headers,
|
||||
payload_hash,
|
||||
]
|
||||
)
|
||||
|
||||
return canonical_request, signed_headers
|
||||
|
||||
def _create_signature(
    self, method: str, path: str, query_params: dict, headers: dict, payload: bytes
) -> dict:
    """
    Sign a request and return the headers that must be sent with it.

    The ``headers`` mapping is updated in place with "X-Date", "Host",
    "X-Content-Sha256", "Content-Type" and finally "Authorization"; the
    same mapping is returned for convenience.

    Args:
        method: HTTP verb of the request.
        path: Request path.
        query_params: Query parameters included in the canonical request.
        headers: Mutable header mapping; every entry present when the
            canonical request is built is part of the signed headers.
        payload: Raw request body bytes.

    Returns:
        The ``headers`` mapping, including the "Authorization" entry.
    """
    # Local import: the module-level import only brings in `datetime`.
    from datetime import timezone

    # Aware UTC timestamp; formats to the same string as the deprecated
    # naive datetime.utcnow().
    now = datetime.now(timezone.utc)
    date_stamp = now.strftime("%Y%m%dT%H%M%SZ")
    auth_date = date_stamp[:8]  # YYYYMMDD part used in the credential scope

    headers["X-Date"] = date_stamp
    # Host is api_url with any http(s):// prefix removed.
    headers["Host"] = self.api_url.replace("https://", "").replace("http://", "")
    headers["X-Content-Sha256"] = self._hash_sha256(payload).hex()
    headers["Content-Type"] = "application/json"

    # Authorization is added only afterwards, so it is never signed itself.
    canonical_request, signed_headers = self._create_canonical_request(
        method, path, query_params, headers, payload
    )

    algorithm = "HMAC-SHA256"
    credential_scope = f"{auth_date}/{self.region}/{self.service}/request"
    canonical_request_hash = self._hash_sha256(
        canonical_request.encode("utf-8")
    ).hex()

    string_to_sign = "\n".join(
        [algorithm, date_stamp, credential_scope, canonical_request_hash]
    )

    signing_key = self._get_signed_key(
        self.api_sk, auth_date, self.region, self.service
    )
    signature = hmac.new(
        signing_key, string_to_sign.encode("utf-8"), hashlib.sha256
    ).hexdigest()

    headers["Authorization"] = (
        f"{algorithm} "
        f"Credential={self.api_ak}/{credential_scope}, "
        f"SignedHeaders={signed_headers}, "
        f"Signature={signature}"
    )

    return headers
|
||||
|
||||
def _make_signed_request(
    self, method: str, path: str, params: dict = None, data: dict = None
):
    """
    Send a signed HTTPS request to the knowledge base API.

    Args:
        method: HTTP verb.
        path: Request path appended to the API host.
        params: Optional query parameters; treated as empty when omitted.
        data: Optional JSON-serializable body; no body is sent when omitted.

    Returns:
        The ``requests`` response object, unmodified (status is not checked).

    Raises:
        ValueError: If sending the request fails; the original exception is
            attached as ``__cause__``.
    """
    payload = b"" if data is None else json.dumps(data).encode("utf-8")
    if params is None:
        params = {}

    # Strip an optional scheme so the URL matches the Host header computed
    # in _create_signature; a plain host name is left untouched.
    host = self.api_url.replace("https://", "").replace("http://", "")
    url = f"https://{host}{path}"

    signed_headers = self._create_signature(method, path, params, {}, payload)

    # NOTE(review): `requests` encodes `params` itself; confirm its encoding
    # matches the percent-encoding used for the signed canonical query
    # string when non-trivial query values are sent.
    try:
        return requests.request(
            method=method,
            url=url,
            headers=signed_headers,
            params=params,
            data=payload if payload else None,
            timeout=30,
        )
    except Exception as e:
        # Keep the ValueError contract, but chain the cause for debugging.
        raise ValueError(f"Request failed: {e}") from e
|
||||
|
||||
def query_relevant_documents(
|
||||
self, query: str, resources: list[Resource] = []
|
||||
@@ -111,29 +210,24 @@ class VikingDBKnowledgeBaseProvider(Retriever):
|
||||
query_param = {"doc_filter": doc_filter}
|
||||
request_params["query_param"] = query_param
|
||||
|
||||
method = "POST"
|
||||
path = "/api/knowledge/collection/search_knowledge"
|
||||
info_req = self.prepare_request(
|
||||
method=method, path=path, data=request_params
|
||||
)
|
||||
rsp = requests.request(
|
||||
method=info_req.method,
|
||||
url="http://{}{}".format(self.api_url, info_req.path),
|
||||
headers=info_req.headers,
|
||||
data=info_req.body,
|
||||
|
||||
# 使用新的签名请求方法
|
||||
response = self._make_signed_request(
|
||||
method="POST", path=path, data=request_params
|
||||
)
|
||||
|
||||
try:
|
||||
response = json.loads(rsp.text)
|
||||
response_data = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to parse JSON response: {e}")
|
||||
|
||||
if response["code"] != 0:
|
||||
if response_data["code"] != 0:
|
||||
raise ValueError(
|
||||
f"Failed to query documents from resource: {response['message']}"
|
||||
f"Failed to query documents from resource: {response_data['message']}"
|
||||
)
|
||||
|
||||
rsp_data = response.get("data", {})
|
||||
rsp_data = response_data.get("data", {})
|
||||
|
||||
if "result_list" not in rsp_data:
|
||||
continue
|
||||
@@ -163,25 +257,20 @@ class VikingDBKnowledgeBaseProvider(Retriever):
|
||||
"""
|
||||
List resources (knowledge bases) from the knowledge base service
|
||||
"""
|
||||
method = "POST"
|
||||
path = "/api/knowledge/collection/list"
|
||||
info_req = self.prepare_request(method=method, path=path)
|
||||
rsp = requests.request(
|
||||
method=info_req.method,
|
||||
url="http://{}{}".format(self.api_url, info_req.path),
|
||||
headers=info_req.headers,
|
||||
data=info_req.body,
|
||||
)
|
||||
|
||||
response = self._make_signed_request(method="POST", path=path)
|
||||
|
||||
try:
|
||||
response = json.loads(rsp.text)
|
||||
response_data = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to parse JSON response: {e}")
|
||||
|
||||
if response["code"] != 0:
|
||||
raise Exception(f"Failed to list resources: {response["message"]}")
|
||||
if response_data["code"] != 0:
|
||||
raise Exception(f"Failed to list resources: {response_data['message']}")
|
||||
|
||||
resources = []
|
||||
rsp_data = response.get("data", {})
|
||||
rsp_data = response_data.get("data", {})
|
||||
collection_list = rsp_data.get("collection_list", [])
|
||||
for item in collection_list:
|
||||
collection_name = item.get("collection_name", "")
|
||||
|
||||
Reference in New Issue
Block a user