mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
feat: integrate VikingDB Knowledge Base into rag retrieving tool (#381)
Co-authored-by: Henry Li <henry1943@163.com>
This commit is contained in:
@@ -14,6 +14,12 @@ TAVILY_API_KEY=tvly-xxx
|
||||
# JINA_API_KEY=jina_xxx # Optional, default is None
|
||||
|
||||
# Optional, RAG provider
|
||||
# RAG_PROVIDER=vikingdb_knowledge_base
|
||||
# VIKINGDB_KNOWLEDGE_BASE_API_URL="api-knowledgebase.mlp.cn-beijing.volces.com"
|
||||
# VIKINGDB_KNOWLEDGE_BASE_API_AK="AKxxx"
|
||||
# VIKINGDB_KNOWLEDGE_BASE_API_SK=""
|
||||
# VIKINGDB_KNOWLEDGE_BASE_RETRIEVAL_SIZE=15
|
||||
|
||||
# RAG_PROVIDER=ragflow
|
||||
# RAGFLOW_API_URL="http://localhost:9388"
|
||||
# RAGFLOW_API_KEY="ragflow-xxx"
|
||||
|
||||
@@ -33,6 +33,7 @@ dependencies = [
|
||||
"mcp>=1.6.0",
|
||||
"langchain-mcp-adapters>=0.0.9",
|
||||
"langchain-deepseek>=0.1.3",
|
||||
"volcengine>=1.0.191",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@@ -21,6 +21,7 @@ SELECTED_SEARCH_ENGINE = os.getenv("SEARCH_API", SearchEngine.TAVILY.value)
|
||||
|
||||
class RAGProvider(enum.Enum):
    """Identifiers of the supported RAG backends.

    Each member's value is the string that the ``RAG_PROVIDER`` environment
    variable is compared against when a retriever is selected.
    """

    RAGFLOW = "ragflow"
    VIKINGDB_KNOWLEDGE_BASE = "vikingdb_knowledge_base"
|
||||
|
||||
|
||||
SELECTED_RAG_PROVIDER = os.getenv("RAG_PROVIDER")
|
||||
|
||||
@@ -3,6 +3,15 @@
|
||||
|
||||
from .retriever import Retriever, Document, Resource, Chunk
|
||||
from .ragflow import RAGFlowProvider
|
||||
from .vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider
|
||||
from .builder import build_retriever
|
||||
|
||||
# Public names of this package. The entries must be strings: listing the
# objects themselves makes a star-import of the package raise TypeError.
__all__ = [
    "Retriever",
    "Document",
    "Resource",
    "RAGFlowProvider",
    "VikingDBKnowledgeBaseProvider",
    "Chunk",
    "build_retriever",
]
|
||||
|
||||
@@ -3,12 +3,15 @@
|
||||
|
||||
from src.config.tools import SELECTED_RAG_PROVIDER, RAGProvider
|
||||
from src.rag.ragflow import RAGFlowProvider
|
||||
from src.rag.vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider
|
||||
from src.rag.retriever import Retriever
|
||||
|
||||
|
||||
def build_retriever() -> Retriever | None:
    """Instantiate the retriever selected by ``SELECTED_RAG_PROVIDER``.

    Returns:
        A provider instance for the configured backend, or ``None`` when no
        provider is configured.

    Raises:
        ValueError: If a provider is configured but is not a supported one.
    """
    # Built on each call so the provider classes are looked up at call time.
    factories = {
        RAGProvider.RAGFLOW.value: RAGFlowProvider,
        RAGProvider.VIKINGDB_KNOWLEDGE_BASE.value: VikingDBKnowledgeBaseProvider,
    }

    factory = factories.get(SELECTED_RAG_PROVIDER)
    if factory is not None:
        return factory()

    if not SELECTED_RAG_PROVIDER:
        return None

    raise ValueError(f"Unsupported RAG provider: {SELECTED_RAG_PROVIDER}")
|
||||
|
||||
208
src/rag/vikingdb_knowledge_base.py
Normal file
208
src/rag/vikingdb_knowledge_base.py
Normal file
@@ -0,0 +1,208 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import requests
|
||||
import json
|
||||
from src.rag.retriever import Chunk, Document, Resource, Retriever
|
||||
from urllib.parse import urlparse
|
||||
from volcengine.auth.SignerV4 import SignerV4
|
||||
from volcengine.base.Request import Request
|
||||
from volcengine.Credentials import Credentials
|
||||
|
||||
|
||||
class VikingDBKnowledgeBaseProvider(Retriever):
    """Retriever that queries the VikingDB Knowledge Base HTTP API.

    Configuration is read from environment variables at construction time:

    * ``VIKINGDB_KNOWLEDGE_BASE_API_URL`` - API host (required).
    * ``VIKINGDB_KNOWLEDGE_BASE_API_AK`` - access key (required).
    * ``VIKINGDB_KNOWLEDGE_BASE_API_SK`` - secret key (required).
    * ``VIKINGDB_KNOWLEDGE_BASE_RETRIEVAL_SIZE`` - value sent as the
      ``limit`` of every search request (optional, defaults to 10).
    """

    api_url: str
    api_ak: str
    api_sk: str
    retrieval_size: int = 10

    # Seconds to wait for the HTTP call. Mirrors the 10s connection/socket
    # timeouts set on the signed request object, which are not passed on to
    # ``requests`` and therefore would otherwise not apply.
    _HTTP_TIMEOUT_SECONDS = 10

    def __init__(self):
        """Load the provider configuration from environment variables.

        Raises:
            ValueError: If a required variable is missing or empty, or if the
                retrieval size is not an integer.
        """
        self.api_url = self._require_env("VIKINGDB_KNOWLEDGE_BASE_API_URL")
        self.api_ak = self._require_env("VIKINGDB_KNOWLEDGE_BASE_API_AK")
        self.api_sk = self._require_env("VIKINGDB_KNOWLEDGE_BASE_API_SK")

        retrieval_size = os.getenv("VIKINGDB_KNOWLEDGE_BASE_RETRIEVAL_SIZE")
        if retrieval_size:
            try:
                self.retrieval_size = int(retrieval_size)
            except ValueError as e:
                raise ValueError(
                    "VIKINGDB_KNOWLEDGE_BASE_RETRIEVAL_SIZE must be an integer, "
                    f"got {retrieval_size!r}"
                ) from e

    @staticmethod
    def _require_env(name: str) -> str:
        """Return the value of environment variable *name* or raise ValueError."""
        value = os.getenv(name)
        if not value:
            raise ValueError(f"{name} is not set")
        return value

    @staticmethod
    def _normalize_param(value, doseq):
        """Convert one query-parameter value to the form handed to the request.

        Numbers and booleans become strings; lists are comma-joined unless
        *doseq* is truthy; every other value is returned unchanged.
        """
        if isinstance(value, (bool, int, float)):
            return str(value)
        if isinstance(value, list) and not doseq:
            # str() each item so lists of numbers do not break the join.
            return ",".join(str(item) for item in value)
        return value

    def prepare_request(self, method, path, params=None, data=None, doseq=0):
        """Build and sign a volcengine request.

        Args:
            method: HTTP method name.
            path: Request path on the API host.
            params: Optional query parameters. The mapping passed in is not
                modified; a normalized copy is attached to the request.
            data: Optional JSON-serializable body.
            doseq: When truthy, list-valued parameters are kept as lists
                instead of being comma-joined.

        Returns:
            The signed ``Request`` object.
        """
        if params:
            params = {
                key: self._normalize_param(value, doseq)
                for key, value in params.items()
            }

        r = Request()
        # "set_shema" is the method name this code (and its tests) call on the
        # volcengine Request object; the spelling is intentional.
        r.set_shema("https")
        r.set_method(method)
        r.set_connection_timeout(10)
        r.set_socket_timeout(10)
        r.set_headers(
            {
                "Accept": "application/json",
                "Content-Type": "application/json",
            }
        )
        if params:
            r.set_query(params)
        r.set_path(path)
        if data is not None:
            r.set_body(json.dumps(data))

        credentials = Credentials(self.api_ak, self.api_sk, "air", "cn-north-1")
        SignerV4.sign(r, credentials)
        return r

    def _post(self, path: str, data: dict | None = None) -> dict:
        """Send a signed POST to *path* and return the decoded JSON object.

        Raises:
            ValueError: If the response body is not valid JSON or is not a
                JSON object.
        """
        info_req = self.prepare_request(method="POST", path=path, data=data)
        # NOTE(review): the request is signed with scheme "https" but sent to
        # an "http://" URL - confirm which scheme the service expects.
        rsp = requests.request(
            method=info_req.method,
            url="http://{}{}".format(self.api_url, info_req.path),
            headers=info_req.headers,
            data=info_req.body,
            timeout=self._HTTP_TIMEOUT_SECONDS,
        )

        try:
            response = json.loads(rsp.text)
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse JSON response: {e}") from e

        if not isinstance(response, dict):
            raise ValueError(
                "Failed to parse JSON response: expected a JSON object, "
                f"got {type(response).__name__}"
            )
        return response

    def query_relevant_documents(
        self, query: str, resources: list[Resource] | None = None
    ) -> list[Document]:
        """Search every given resource and group the hits by document.

        Args:
            query: Text to search for.
            resources: Resources to search. Each URI is parsed with
                ``parse_uri``; a URI fragment restricts the search to that
                document id.

        Returns:
            One ``Document`` per distinct ``doc_id`` found, each holding the
            chunks returned for it across all resources. Empty when no
            resources are given.

        Raises:
            ValueError: If a URI is invalid, the response cannot be parsed,
                or the API reports a non-zero ``code``.
        """
        if not resources:
            return []

        all_documents: dict[str, Document] = {}
        for resource in resources:
            resource_id, document_id = parse_uri(resource.uri)
            request_params = {
                "resource_id": resource_id,
                "query": query,
                "limit": self.retrieval_size,
                "dense_weight": 0.5,
                "pre_processing": {
                    "need_instruction": True,
                    "rewrite": False,
                    "return_token_usage": True,
                },
                "post_processing": {
                    "rerank_switch": True,
                    "chunk_diffusion_count": 0,
                    "chunk_group": True,
                    "get_attachment_link": True,
                },
            }
            if document_id:
                request_params["query_param"] = {
                    "doc_filter": {
                        "op": "must",
                        "field": "doc_id",
                        "conds": [document_id],
                    }
                }

            response = self._post(
                "/api/knowledge/collection/search_knowledge", data=request_params
            )
            if response.get("code") != 0:
                raise ValueError(
                    f"Failed to query documents from resource: {response.get('message')}"
                )

            # "or" guards against explicit JSON nulls as well as missing keys.
            rsp_data = response.get("data") or {}
            for item in rsp_data.get("result_list") or []:
                doc_info = item.get("doc_info") or {}
                doc_id = doc_info.get("doc_id")
                if not doc_id:
                    continue

                if doc_id not in all_documents:
                    all_documents[doc_id] = Document(
                        id=doc_id, title=doc_info.get("doc_name"), chunks=[]
                    )

                all_documents[doc_id].chunks.append(
                    Chunk(
                        content=item.get("content", ""),
                        similarity=item.get("score", 0.0),
                    )
                )

        return list(all_documents.values())

    def list_resources(self, query: str | None = None) -> list[Resource]:
        """List the collections exposed by the knowledge base service.

        Args:
            query: Optional case-insensitive substring; only collections
                whose name contains it are returned.

        Returns:
            One ``Resource`` per matching collection, with a
            ``rag://dataset/<resource_id>`` URI.

        Raises:
            ValueError: If the response cannot be parsed or the API reports
                a non-zero ``code``.
        """
        response = self._post("/api/knowledge/collection/list")
        if response.get("code") != 0:
            raise ValueError(f"Failed to list resources: {response.get('message')}")

        needle = query.lower() if query else None
        rsp_data = response.get("data") or {}

        resources = []
        for item in rsp_data.get("collection_list") or []:
            collection_name = item.get("collection_name") or ""
            if needle and needle not in collection_name.lower():
                continue

            resource_id = item.get("resource_id", "")
            resources.append(
                Resource(
                    uri=f"rag://dataset/{resource_id}",
                    title=collection_name,
                    description=item.get("description", ""),
                )
            )

        return resources
|
||||
|
||||
|
||||
def parse_uri(uri: str) -> tuple[str, str]:
    """Split a ``rag://`` resource URI into resource id and document id.

    The resource id is the first segment of the URI path and the document id
    is the URI fragment, e.g. ``rag://dataset/123#doc456`` gives
    ``("123", "doc456")``.

    Args:
        uri: URI to parse.

    Returns:
        ``(resource_id, document_id)``; ``document_id`` is ``""`` when the
        URI has no fragment.

    Raises:
        ValueError: If the scheme is not ``rag`` or the URI carries no
            resource id.
    """
    parsed = urlparse(uri)
    if parsed.scheme != "rag":
        raise ValueError(f"Invalid URI: {uri}")

    # A URI such as "rag://dataset" has an empty path; report it as an invalid
    # URI rather than letting the index lookup raise IndexError.
    segments = parsed.path.split("/")
    if len(segments) < 2 or not segments[1]:
        raise ValueError(f"Invalid URI: {uri}")

    return segments[1], parsed.fragment
|
||||
503
tests/unit/rag/test_vikingdb_knowledge_base.py
Normal file
503
tests/unit/rag/test_vikingdb_knowledge_base.py
Normal file
@@ -0,0 +1,503 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import patch, MagicMock
|
||||
from src.rag.vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider, parse_uri
|
||||
|
||||
|
||||
# Dummy classes to mock dependencies
|
||||
class MockResource:
    """Plain stand-in for ``Resource`` that just stores its arguments."""

    def __init__(self, uri, title="", description=""):
        self.uri, self.title, self.description = uri, title, description
|
||||
|
||||
|
||||
class MockChunk:
    """Plain stand-in for ``Chunk`` that just stores its arguments."""

    def __init__(self, content, similarity):
        self.content, self.similarity = content, similarity
|
||||
|
||||
|
||||
class MockDocument:
    """Plain stand-in for ``Document``; a falsy ``chunks`` becomes a new list."""

    def __init__(self, id, title, chunks=None):
        self.id, self.title = id, title
        self.chunks = chunks if chunks else []
|
||||
|
||||
|
||||
# Patch the imports to use mock classes
|
||||
@pytest.fixture(autouse=True)
def patch_imports():
    """Swap the module's Resource/Chunk/Document for the mock classes in every test."""
    with patch("src.rag.vikingdb_knowledge_base.Resource", MockResource):
        with patch("src.rag.vikingdb_knowledge_base.Chunk", MockChunk):
            with patch("src.rag.vikingdb_knowledge_base.Document", MockDocument):
                yield
|
||||
|
||||
|
||||
@pytest.fixture
def env_vars():
    """Provide a complete set of provider environment variables for a test."""
    provider_env = {
        "VIKINGDB_KNOWLEDGE_BASE_API_URL": "api-test.example.com",
        "VIKINGDB_KNOWLEDGE_BASE_API_AK": "test_ak",
        "VIKINGDB_KNOWLEDGE_BASE_API_SK": "test_sk",
        "VIKINGDB_KNOWLEDGE_BASE_RETRIEVAL_SIZE": "10",
    }
    with patch.dict(os.environ, provider_env):
        yield
|
||||
|
||||
|
||||
class TestParseUri:
    """Behaviour of ``parse_uri`` for accepted and rejected URIs."""

    def test_parse_uri_valid_with_fragment(self):
        """The fragment is returned as the document id."""
        resource_id, document_id = parse_uri("rag://dataset/123#doc456")
        assert (resource_id, document_id) == ("123", "doc456")

    def test_parse_uri_valid_without_fragment(self):
        """Without a fragment the document id is an empty string."""
        resource_id, document_id = parse_uri("rag://dataset/123")
        assert (resource_id, document_id) == ("123", "")

    def test_parse_uri_invalid_scheme(self):
        """A scheme other than ``rag`` is rejected."""
        bad_uri = "http://dataset/123#abc"
        with pytest.raises(ValueError, match="Invalid URI"):
            parse_uri(bad_uri)

    def test_parse_uri_malformed(self):
        """A string with no scheme at all is rejected."""
        bad_uri = "invalid_uri"
        with pytest.raises(ValueError, match="Invalid URI"):
            parse_uri(bad_uri)
|
||||
|
||||
|
||||
class TestVikingDBKnowledgeBaseProviderInit:
    """Constructor behaviour as driven by environment variables."""

    # The three variables the constructor requires.
    REQUIRED_ENV = {
        "VIKINGDB_KNOWLEDGE_BASE_API_URL": "api-test.example.com",
        "VIKINGDB_KNOWLEDGE_BASE_API_AK": "test_ak",
        "VIKINGDB_KNOWLEDGE_BASE_API_SK": "test_sk",
    }

    def _required_env_without(self, missing):
        """Return the required variables with *missing* left out."""
        return {
            name: value
            for name, value in self.REQUIRED_ENV.items()
            if name != missing
        }

    def test_init_success_with_all_env_vars(self, env_vars):
        """All variables set: every attribute reflects its variable."""
        provider = VikingDBKnowledgeBaseProvider()
        observed = (
            provider.api_url,
            provider.api_ak,
            provider.api_sk,
            provider.retrieval_size,
        )
        assert observed == ("api-test.example.com", "test_ak", "test_sk", 10)

    def test_init_success_without_retrieval_size(self):
        """Retrieval size unset: the default of 10 is kept."""
        with patch.dict(os.environ, dict(self.REQUIRED_ENV), clear=True):
            provider = VikingDBKnowledgeBaseProvider()
        assert provider.retrieval_size == 10

    def test_init_custom_retrieval_size(self):
        """Retrieval size set: the value is parsed as an integer."""
        env = dict(self.REQUIRED_ENV)
        env["VIKINGDB_KNOWLEDGE_BASE_RETRIEVAL_SIZE"] = "5"
        with patch.dict(os.environ, env):
            provider = VikingDBKnowledgeBaseProvider()
        assert provider.retrieval_size == 5

    def test_init_missing_api_url(self):
        """API URL unset: construction raises ValueError."""
        env = self._required_env_without("VIKINGDB_KNOWLEDGE_BASE_API_URL")
        with patch.dict(os.environ, env, clear=True):
            with pytest.raises(
                ValueError, match="VIKINGDB_KNOWLEDGE_BASE_API_URL is not set"
            ):
                VikingDBKnowledgeBaseProvider()

    def test_init_missing_api_ak(self):
        """Access key unset: construction raises ValueError."""
        env = self._required_env_without("VIKINGDB_KNOWLEDGE_BASE_API_AK")
        with patch.dict(os.environ, env, clear=True):
            with pytest.raises(
                ValueError, match="VIKINGDB_KNOWLEDGE_BASE_API_AK is not set"
            ):
                VikingDBKnowledgeBaseProvider()

    def test_init_missing_api_sk(self):
        """Secret key unset: construction raises ValueError."""
        env = self._required_env_without("VIKINGDB_KNOWLEDGE_BASE_API_SK")
        with patch.dict(os.environ, env, clear=True):
            with pytest.raises(
                ValueError, match="VIKINGDB_KNOWLEDGE_BASE_API_SK is not set"
            ):
                VikingDBKnowledgeBaseProvider()
|
||||
|
||||
|
||||
class TestVikingDBKnowledgeBaseProviderPrepareRequest:
    """Tests for ``prepare_request`` with the volcengine classes mocked out."""

    @pytest.fixture
    def provider(self, env_vars):
        return VikingDBKnowledgeBaseProvider()

    def test_prepare_request_basic(self, provider):
        """Scheme, method and path are set and the request object is returned."""
        with patch("src.rag.vikingdb_knowledge_base.Request") as request_cls:
            with patch("src.rag.vikingdb_knowledge_base.Credentials"):
                with patch("src.rag.vikingdb_knowledge_base.SignerV4.sign"):
                    fake_request = MagicMock()
                    request_cls.return_value = fake_request
                    built = provider.prepare_request("POST", "/test/path")

        assert built == fake_request
        fake_request.set_shema.assert_called_once_with("https")
        fake_request.set_method.assert_called_once_with("POST")
        fake_request.set_path.assert_called_once_with("/test/path")

    def test_prepare_request_with_params(self, provider):
        """Numeric and boolean query parameters are stringified."""
        with patch("src.rag.vikingdb_knowledge_base.Request") as request_cls:
            with patch("src.rag.vikingdb_knowledge_base.Credentials"):
                with patch("src.rag.vikingdb_knowledge_base.SignerV4.sign"):
                    fake_request = MagicMock()
                    request_cls.return_value = fake_request
                    provider.prepare_request(
                        "GET",
                        "/test",
                        params={"key": "value", "number": 123, "boolean": True},
                    )

        fake_request.set_query.assert_called_once_with(
            {"key": "value", "number": "123", "boolean": "True"}
        )

    def test_prepare_request_with_data(self, provider):
        """A body is serialized to JSON before being attached."""
        body = {"test": "data"}
        with patch("src.rag.vikingdb_knowledge_base.Request") as request_cls:
            with patch("src.rag.vikingdb_knowledge_base.Credentials"):
                with patch("src.rag.vikingdb_knowledge_base.SignerV4.sign"):
                    fake_request = MagicMock()
                    request_cls.return_value = fake_request
                    provider.prepare_request("POST", "/test", data=body)

        fake_request.set_body.assert_called_once_with(json.dumps(body))
|
||||
|
||||
|
||||
class TestVikingDBKnowledgeBaseProviderQueryRelevantDocuments:
    """Tests for ``query_relevant_documents`` with the HTTP layer mocked."""

    @pytest.fixture
    def provider(self, env_vars):
        return VikingDBKnowledgeBaseProvider()

    @staticmethod
    def _reply(body):
        """Return a fake HTTP response whose ``text`` is *body*."""
        reply = MagicMock()
        reply.text = body
        return reply

    @staticmethod
    def _hit(doc_id, doc_name, content, score):
        """Return one ``result_list`` entry in the shape the tests feed in."""
        return {
            "doc_info": {"doc_id": doc_id, "doc_name": doc_name},
            "content": content,
            "score": score,
        }

    def test_query_relevant_documents_empty_resources(self, provider):
        """No resources: nothing is queried and the result is empty."""
        assert provider.query_relevant_documents("test query", []) == []

    @patch("src.rag.vikingdb_knowledge_base.requests.request")
    def test_query_relevant_documents_success(self, mock_request, provider):
        """A single hit becomes one document holding one chunk."""
        payload = {
            "code": 0,
            "data": {
                "result_list": [
                    self._hit("doc123", "Test Document", "Test content", 0.95)
                ]
            },
        }
        mock_request.return_value = self._reply(json.dumps(payload))

        signed = MagicMock()
        signed.method = "POST"
        signed.path = "/api/knowledge/collection/search_knowledge"
        signed.headers = {}
        signed.body = "{}"

        with patch.object(provider, "prepare_request", return_value=signed):
            documents = provider.query_relevant_documents(
                "test query", [MockResource("rag://dataset/123")]
            )

        assert len(documents) == 1
        document = documents[0]
        assert document.id == "doc123"
        assert document.title == "Test Document"
        assert len(document.chunks) == 1
        first_chunk = document.chunks[0]
        assert first_chunk.content == "Test content"
        assert first_chunk.similarity == 0.95

    @patch("src.rag.vikingdb_knowledge_base.requests.request")
    def test_query_relevant_documents_with_document_filter(
        self, mock_request, provider
    ):
        """A URI fragment is sent as a ``doc_id`` filter in the request body."""
        mock_request.return_value = self._reply(
            json.dumps({"code": 0, "data": {"result_list": []}})
        )

        with patch.object(provider, "prepare_request") as mock_prepare:
            mock_prepare.return_value = MagicMock()
            provider.query_relevant_documents(
                "test query", [MockResource("rag://dataset/123#doc456")]
            )
            sent_body = mock_prepare.call_args.kwargs["data"]

        assert "query_param" in sent_body
        assert "doc_filter" in sent_body["query_param"]

        doc_filter = sent_body["query_param"]["doc_filter"]
        assert doc_filter["op"] == "must"
        assert doc_filter["field"] == "doc_id"
        assert doc_filter["conds"] == ["doc456"]

    @patch("src.rag.vikingdb_knowledge_base.requests.request")
    def test_query_relevant_documents_api_error(self, mock_request, provider):
        """A non-zero ``code`` is raised as ValueError carrying the message."""
        mock_request.return_value = self._reply(
            json.dumps({"code": 1, "message": "API Error"})
        )

        with patch.object(provider, "prepare_request"):
            with pytest.raises(
                ValueError, match="Failed to query documents from resource: API Error"
            ):
                provider.query_relevant_documents(
                    "test query", [MockResource("rag://dataset/123")]
                )

    @patch("src.rag.vikingdb_knowledge_base.requests.request")
    def test_query_relevant_documents_json_decode_error(self, mock_request, provider):
        """A body that is not JSON is raised as ValueError."""
        mock_request.return_value = self._reply("invalid json")

        with patch.object(provider, "prepare_request"):
            with pytest.raises(ValueError, match="Failed to parse JSON response"):
                provider.query_relevant_documents(
                    "test query", [MockResource("rag://dataset/123")]
                )

    @patch("src.rag.vikingdb_knowledge_base.requests.request")
    def test_query_relevant_documents_multiple_resources(self, mock_request, provider):
        """Hits from several resources are merged per document id."""
        first_payload = {
            "code": 0,
            "data": {"result_list": [self._hit("doc1", "Document 1", "Content 1", 0.9)]},
        }
        second_payload = {
            "code": 0,
            "data": {
                "result_list": [
                    self._hit("doc1", "Document 1", "Content 2", 0.8),
                    self._hit("doc2", "Document 2", "Content 3", 0.7),
                ]
            },
        }
        # One canned reply per resource, consumed in order.
        mock_request.side_effect = [
            self._reply(json.dumps(first_payload)),
            self._reply(json.dumps(second_payload)),
        ]

        with patch.object(provider, "prepare_request"):
            documents = provider.query_relevant_documents(
                "test query",
                [MockResource("rag://dataset/123"), MockResource("rag://dataset/456")],
            )

        # doc1 collects a chunk from each resource; doc2 has a single chunk.
        assert len(documents) == 2
        doc1 = next(doc for doc in documents if doc.id == "doc1")
        doc2 = next(doc for doc in documents if doc.id == "doc2")
        assert len(doc1.chunks) == 2
        assert len(doc2.chunks) == 1
|
||||
|
||||
|
||||
class TestVikingDBKnowledgeBaseProviderListResources:
    """Tests for ``list_resources`` with the HTTP layer mocked."""

    @pytest.fixture
    def provider(self, env_vars):
        return VikingDBKnowledgeBaseProvider()

    @staticmethod
    def _reply(body):
        """Return a fake HTTP response whose ``text`` is *body*."""
        reply = MagicMock()
        reply.text = body
        return reply

    @staticmethod
    def _collection(resource_id, collection_name, description):
        """Return one ``collection_list`` entry in the shape the tests feed in."""
        return {
            "resource_id": resource_id,
            "collection_name": collection_name,
            "description": description,
        }

    @patch("src.rag.vikingdb_knowledge_base.requests.request")
    def test_list_resources_success(self, mock_request, provider):
        """Every collection in the response becomes a resource, in order."""
        payload = {
            "code": 0,
            "data": {
                "collection_list": [
                    self._collection("123", "Dataset 1", "Description 1"),
                    self._collection("456", "Dataset 2", "Description 2"),
                ]
            },
        }
        mock_request.return_value = self._reply(json.dumps(payload))

        with patch.object(provider, "prepare_request", return_value=MagicMock()):
            listed = provider.list_resources()

        assert len(listed) == 2
        first, second = listed
        assert first.uri == "rag://dataset/123"
        assert first.title == "Dataset 1"
        assert first.description == "Description 1"
        assert second.uri == "rag://dataset/456"
        assert second.title == "Dataset 2"
        assert second.description == "Description 2"

    @patch("src.rag.vikingdb_knowledge_base.requests.request")
    def test_list_resources_with_query_filter(self, mock_request, provider):
        """Only collections whose name contains the query are returned."""
        payload = {
            "code": 0,
            "data": {
                "collection_list": [
                    self._collection("123", "Test Dataset", "Description"),
                    self._collection("456", "Other Dataset", "Description"),
                ]
            },
        }
        mock_request.return_value = self._reply(json.dumps(payload))

        with patch.object(provider, "prepare_request"):
            listed = provider.list_resources("test")

        # "Other Dataset" does not contain "test" and is filtered out.
        assert len(listed) == 1
        assert listed[0].title == "Test Dataset"

    @patch("src.rag.vikingdb_knowledge_base.requests.request")
    def test_list_resources_api_error(self, mock_request, provider):
        """A non-zero ``code`` is raised with the API message."""
        mock_request.return_value = self._reply(
            json.dumps({"code": 1, "message": "API Error"})
        )

        with patch.object(provider, "prepare_request"):
            with pytest.raises(Exception, match="Failed to list resources: API Error"):
                provider.list_resources()

    @patch("src.rag.vikingdb_knowledge_base.requests.request")
    def test_list_resources_json_decode_error(self, mock_request, provider):
        """A body that is not JSON is raised as ValueError."""
        mock_request.return_value = self._reply("invalid json")

        with patch.object(provider, "prepare_request"):
            with pytest.raises(ValueError, match="Failed to parse JSON response"):
                provider.list_resources()

    @patch("src.rag.vikingdb_knowledge_base.requests.request")
    def test_list_resources_empty_response(self, mock_request, provider):
        """An empty collection list yields an empty result."""
        mock_request.return_value = self._reply(
            json.dumps({"code": 0, "data": {"collection_list": []}})
        )

        with patch.object(provider, "prepare_request"):
            listed = provider.list_resources()

        assert listed == []
|
||||
84
uv.lock
generated
84
uv.lock
generated
@@ -365,6 +365,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c3/be/d0d44e092656fe7a06b55e6103cbce807cdbdee17884a5367c68c9860853/dataclasses_json-0.6.7-py3-none-any.whl", hash = "sha256:0dbf33f26c8d5305befd61b39d2b3414e8a407bedc2834dea9b8d642666fb40a", size = 28686, upload-time = "2024-06-09T16:20:16.715Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "decorator"
|
||||
version = "5.2.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/43/fa/6d96a0978d19e17b68d634497769987b16c8f4cd0a7a05048bec693caa6b/decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360", size = 56711, upload-time = "2025-02-24T04:41:34.073Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190, upload-time = "2025-02-24T04:41:32.565Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deer-flow"
|
||||
version = "0.1.0"
|
||||
@@ -393,6 +402,7 @@ dependencies = [
|
||||
{ name = "socksio" },
|
||||
{ name = "sse-starlette" },
|
||||
{ name = "uvicorn" },
|
||||
{ name = "volcengine" },
|
||||
{ name = "yfinance" },
|
||||
]
|
||||
|
||||
@@ -437,6 +447,7 @@ requires-dist = [
|
||||
{ name = "socksio", specifier = ">=1.0.0" },
|
||||
{ name = "sse-starlette", specifier = ">=1.6.5" },
|
||||
{ name = "uvicorn", specifier = ">=0.27.1" },
|
||||
{ name = "volcengine", specifier = ">=1.0.191" },
|
||||
{ name = "yfinance", specifier = ">=0.2.54" },
|
||||
]
|
||||
provides-extras = ["dev", "test"]
|
||||
@@ -564,6 +575,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/56/53/eb690efa8513166adef3e0669afd31e95ffde69fb3c52ec2ac7223ed6018/fsspec-2025.3.0-py3-none-any.whl", hash = "sha256:efb87af3efa9103f94ca91a7f8cb7a4df91af9f74fc106c9c7ea0efd7277c1b3", size = 193615, upload-time = "2025-03-07T21:47:54.809Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "google"
|
||||
version = "3.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "beautifulsoup4" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/89/97/b49c69893cddea912c7a660a4b6102c6b02cd268f8c7162dd70b7c16f753/google-3.0.0.tar.gz", hash = "sha256:143530122ee5130509ad5e989f0512f7cb218b2d4eddbafbad40fd10e8d8ccbe", size = 44978, upload-time = "2020-07-11T14:50:45.678Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ac/35/17c9141c4ae21e9a29a43acdfd848e3e468a810517f862cad07977bf8fe9/google-3.0.0-py2.py3-none-any.whl", hash = "sha256:889cf695f84e4ae2c55fbc0cfdaf4c1e729417fa52ab1db0485202ba173e4935", size = 45258, upload-time = "2020-07-11T14:49:58.287Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "greenlet"
|
||||
version = "3.1.1"
|
||||
@@ -1595,6 +1618,29 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b5/35/6c4c6fc8774a9e3629cd750dc24a7a4fb090a25ccd5c3246d127b70f9e22/propcache-0.3.0-py3-none-any.whl", hash = "sha256:67dda3c7325691c2081510e92c561f465ba61b975f481735aefdfc845d2cd043", size = 12101, upload-time = "2025-02-20T19:03:27.202Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "protobuf"
|
||||
version = "6.31.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/52/f3/b9655a711b32c19720253f6f06326faf90580834e2e83f840472d752bc8b/protobuf-6.31.1.tar.gz", hash = "sha256:d8cac4c982f0b957a4dc73a80e2ea24fab08e679c0de9deb835f4a12d69aca9a", size = 441797, upload-time = "2025-05-28T19:25:54.947Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f3/6f/6ab8e4bf962fd5570d3deaa2d5c38f0a363f57b4501047b5ebeb83ab1125/protobuf-6.31.1-cp310-abi3-win32.whl", hash = "sha256:7fa17d5a29c2e04b7d90e5e32388b8bfd0e7107cd8e616feef7ed3fa6bdab5c9", size = 423603, upload-time = "2025-05-28T19:25:41.198Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/44/3a/b15c4347dd4bf3a1b0ee882f384623e2063bb5cf9fa9d57990a4f7df2fb6/protobuf-6.31.1-cp310-abi3-win_amd64.whl", hash = "sha256:426f59d2964864a1a366254fa703b8632dcec0790d8862d30034d8245e1cd447", size = 435283, upload-time = "2025-05-28T19:25:44.275Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/c9/b9689a2a250264a84e66c46d8862ba788ee7a641cdca39bccf64f59284b7/protobuf-6.31.1-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:6f1227473dc43d44ed644425268eb7c2e488ae245d51c6866d19fe158e207402", size = 425604, upload-time = "2025-05-28T19:25:45.702Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/76/a1/7a5a94032c83375e4fe7e7f56e3976ea6ac90c5e85fac8576409e25c39c3/protobuf-6.31.1-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:a40fc12b84c154884d7d4c4ebd675d5b3b5283e155f324049ae396b95ddebc39", size = 322115, upload-time = "2025-05-28T19:25:47.128Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fa/b1/b59d405d64d31999244643d88c45c8241c58f17cc887e73bcb90602327f8/protobuf-6.31.1-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:4ee898bf66f7a8b0bd21bce523814e6fbd8c6add948045ce958b73af7e8878c6", size = 321070, upload-time = "2025-05-28T19:25:50.036Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f7/af/ab3c51ab7507a7325e98ffe691d9495ee3d3aa5f589afad65ec920d39821/protobuf-6.31.1-py3-none-any.whl", hash = "sha256:720a6c7e6b77288b85063569baae8536671b39f15cc22037ec7045658d80489e", size = 168724, upload-time = "2025-05-28T19:25:53.926Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "py"
|
||||
version = "1.11.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/98/ff/fec109ceb715d2a6b4c4a85a61af3b40c723a961e8828319fbcb15b868dc/py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719", size = 207796, upload-time = "2021-11-04T17:17:01.377Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f6/f0/10642828a8dfb741e5f3fbaac830550a518a775c7fff6f04a007259b0548/py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378", size = 98708, upload-time = "2021-11-04T17:17:00.152Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pycparser"
|
||||
version = "2.22"
|
||||
@@ -1604,6 +1650,12 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/13/a3/a812df4e2dd5696d1f351d58b8fe16a405b234ad2886a0dab9183fb78109/pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc", size = 117552, upload-time = "2024-03-30T13:22:20.476Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pycryptodome"
|
||||
version = "3.9.9"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c4/3a/5bca2cb1648b171afd6b7d29a11c6bca8b305bb75b7e2d78a0f5c61ff95e/pycryptodome-3.9.9.tar.gz", hash = "sha256:910e202a557e1131b1c1b3f17a63914d57aac55cf9fb9b51644962841c3995c4", size = 15488528, upload-time = "2020-11-03T13:15:26.723Z" }
|
||||
|
||||
[[package]]
|
||||
name = "pydantic"
|
||||
version = "2.10.6"
|
||||
@@ -1701,9 +1753,9 @@ source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pytest" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d0/d4/14f53324cb1a6381bef29d698987625d80052bb33932d8e7cbf9b337b17c/pytest_asyncio-1.0.0.tar.gz", hash = "sha256:d15463d13f4456e1ead2594520216b225a16f781e144f8fdf6c5bb4667c48b3f", size = 46960 }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d0/d4/14f53324cb1a6381bef29d698987625d80052bb33932d8e7cbf9b337b17c/pytest_asyncio-1.0.0.tar.gz", hash = "sha256:d15463d13f4456e1ead2594520216b225a16f781e144f8fdf6c5bb4667c48b3f", size = 46960, upload-time = "2025-05-26T04:54:40.484Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/30/05/ce271016e351fddc8399e546f6e23761967ee09c8c568bbfbecb0c150171/pytest_asyncio-1.0.0-py3-none-any.whl", hash = "sha256:4f024da9f1ef945e680dc68610b52550e36590a67fd31bb3b4943979a1f90ef3", size = 15976 },
|
||||
{ url = "https://files.pythonhosted.org/packages/30/05/ce271016e351fddc8399e546f6e23761967ee09c8c568bbfbecb0c150171/pytest_asyncio-1.0.0-py3-none-any.whl", hash = "sha256:4f024da9f1ef945e680dc68610b52550e36590a67fd31bb3b4943979a1f90ef3", size = 15976, upload-time = "2025-05-26T04:54:39.035Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1869,6 +1921,19 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/3f/51/d4db610ef29373b879047326cbf6fa98b6c1969d6f6dc423279de2b1be2c/requests_toolbelt-1.0.0-py2.py3-none-any.whl", hash = "sha256:cccfdd665f0a24fcf4726e690f65639d272bb0637b9b92dfd91a5568ccf6bd06", size = 54481, upload-time = "2023-05-01T04:11:28.427Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "retry"
|
||||
version = "0.9.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "decorator" },
|
||||
{ name = "py" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/9d/72/75d0b85443fbc8d9f38d08d2b1b67cc184ce35280e4a3813cda2f445f3a4/retry-0.9.2.tar.gz", hash = "sha256:f8bfa8b99b69c4506d6f5bd3b0aabf77f98cdb17f3c9fc3f5ca820033336fba4", size = 6448, upload-time = "2016-05-11T13:58:51.541Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/4b/0d/53aea75710af4528a25ed6837d71d117602b01946b307a3912cb3cfcbcba/retry-0.9.2-py2.py3-none-any.whl", hash = "sha256:ccddf89761fa2c726ab29391837d4327f819ea14d244c232a1d24c67a2f98606", size = 7986, upload-time = "2016-05-11T13:58:39.925Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rpds-py"
|
||||
version = "0.23.1"
|
||||
@@ -2154,6 +2219,21 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/61/14/33a3a1352cfa71812a3a21e8c9bfb83f60b0011f5e36f2b1399d51928209/uvicorn-0.34.0-py3-none-any.whl", hash = "sha256:023dc038422502fa28a09c7a30bf2b6991512da7dcdb8fd35fe57cfc154126f4", size = 62315, upload-time = "2024-12-15T13:33:27.467Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "volcengine"
|
||||
version = "1.0.191"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "google" },
|
||||
{ name = "protobuf" },
|
||||
{ name = "pycryptodome" },
|
||||
{ name = "pytz" },
|
||||
{ name = "requests" },
|
||||
{ name = "retry" },
|
||||
{ name = "six" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f4/d8/0ea9b18f216808af709306084d10369f712b98cb5381381d44115dfa6536/volcengine-1.0.191.tar.gz", hash = "sha256:cf3c2dc118c92a7a47f1ab8a48f4789d47e84d17778a1717e1afe9cbce90c986", size = 356076, upload-time = "2025-06-26T12:25:16.353Z" }
|
||||
|
||||
[[package]]
|
||||
name = "watchfiles"
|
||||
version = "1.0.5"
|
||||
|
||||
Reference in New Issue
Block a user