Files
deer-flow/tests/unit/rag/test_vikingdb_knowledge_base.py
zgjja 3b4e993531 feat: 1. replace black with ruff for formatting and import sorting (#489)
2. use tavily from `langchain-tavily` rather than the older one from `langchain-community`

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2025-08-17 22:57:23 +08:00

541 lines
20 KiB
Python

# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import hashlib
import hmac
import json
import os
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from src.rag.vikingdb_knowledge_base import VikingDBKnowledgeBaseProvider, parse_uri
# Dummy classes to mock dependencies
class MockResource:
    """Lightweight stand-in for the provider's ``Resource`` model.

    Stores exactly three attributes: ``uri``, ``title`` and ``description``;
    the latter two default to empty strings.
    """

    def __init__(self, uri, title="", description=""):
        self.uri, self.title, self.description = uri, title, description
class MockChunk:
    """Lightweight stand-in for the provider's ``Chunk`` model.

    Holds a piece of text (``content``) and its ``similarity`` value.
    """

    def __init__(self, content, similarity):
        self.content, self.similarity = content, similarity
class MockDocument:
    """Lightweight stand-in for the provider's ``Document`` model.

    The first parameter is deliberately named ``id`` (shadowing the builtin) so
    that keyword construction such as ``Document(id=..., title=..., chunks=...)``
    binds correctly; renaming it would break such callers.
    """

    def __init__(self, id, title, chunks=None):
        self.id = id
        self.title = title
        # Compare against None rather than relying on truthiness: a caller that
        # passes its own (possibly empty) list keeps that exact list, so later
        # appends through either reference stay visible. Only a missing
        # argument gets a fresh list (never a shared mutable default).
        self.chunks = chunks if chunks is not None else []
# Replace the provider module's Resource, Chunk and Document with the mock classes
@pytest.fixture(autouse=True)
def patch_imports():
    """Swap the provider module's model classes for the mocks, for every test.

    The three patches are entered in order (Resource, Chunk, Document) and are
    undone when the test finishes.
    """
    with patch("src.rag.vikingdb_knowledge_base.Resource", MockResource):
        with patch("src.rag.vikingdb_knowledge_base.Chunk", MockChunk):
            with patch("src.rag.vikingdb_knowledge_base.Document", MockDocument):
                yield
@pytest.fixture
def env_vars():
    """Provide a complete VikingDB configuration through ``os.environ``.

    ``patch.dict`` restores the original environment after the test. The
    environment is not cleared, so unrelated host variables stay visible.
    """
    fake_config = {
        "VIKINGDB_KNOWLEDGE_BASE_API_URL": "api-test.example.com",
        "VIKINGDB_KNOWLEDGE_BASE_API_AK": "test_ak",
        "VIKINGDB_KNOWLEDGE_BASE_API_SK": "test_sk",
        "VIKINGDB_KNOWLEDGE_BASE_RETRIEVAL_SIZE": "10",
        "VIKINGDB_KNOWLEDGE_BASE_REGION": "cn-north-1",
    }
    with patch.dict(os.environ, fake_config):
        yield
class TestParseUri:
    """Tests for ``parse_uri`` on ``rag://dataset/<id>[#<doc_id>]`` strings."""

    def test_parse_uri_valid_with_fragment(self):
        """The part after ``#`` is returned as the document id."""
        dataset, doc = parse_uri("rag://dataset/123#doc456")
        assert dataset == "123"
        assert doc == "doc456"

    def test_parse_uri_valid_without_fragment(self):
        """Without a fragment, the document id comes back as an empty string."""
        dataset, doc = parse_uri("rag://dataset/123")
        assert dataset == "123"
        assert doc == ""

    def test_parse_uri_invalid_scheme(self):
        """A non-``rag`` scheme (here ``http``) is rejected."""
        with pytest.raises(ValueError, match="Invalid URI"):
            parse_uri("http://dataset/123#abc")

    def test_parse_uri_malformed(self):
        """A string with no URI structure at all is rejected."""
        with pytest.raises(ValueError, match="Invalid URI"):
            parse_uri("invalid_uri")
class TestVikingDBKnowledgeBaseProviderInit:
    """Constructor tests: configuration is read from VIKINGDB_KNOWLEDGE_BASE_* env vars.

    Every test that builds its own environment passes ``clear=True`` to
    ``patch.dict``. Otherwise, variables already exported in the developer's
    shell (for example a real ``VIKINGDB_KNOWLEDGE_BASE_RETRIEVAL_SIZE`` or
    ``VIKINGDB_KNOWLEDGE_BASE_REGION``) would leak into the test and could
    change its outcome.
    """

    def test_init_success_with_all_env_vars(self, env_vars):
        """Test successful initialization with all environment variables"""
        provider = VikingDBKnowledgeBaseProvider()
        assert provider.api_url == "api-test.example.com"
        assert provider.api_ak == "test_ak"
        assert provider.api_sk == "test_sk"
        assert provider.retrieval_size == 10
        assert provider.region == "cn-north-1"
        assert provider.service == "air"

    def test_init_success_without_retrieval_size(self):
        """Test initialization without VIKINGDB_KNOWLEDGE_BASE_RETRIEVAL_SIZE (should use default)"""
        with patch.dict(
            os.environ,
            {
                "VIKINGDB_KNOWLEDGE_BASE_API_URL": "api-test.example.com",
                "VIKINGDB_KNOWLEDGE_BASE_API_AK": "test_ak",
                "VIKINGDB_KNOWLEDGE_BASE_API_SK": "test_sk",
            },
            clear=True,
        ):
            provider = VikingDBKnowledgeBaseProvider()
            assert provider.retrieval_size == 10

    def test_init_custom_retrieval_size(self):
        """Test initialization with custom retrieval size"""
        # clear=True: a host REGION (or any other VikingDB var) must not leak in.
        with patch.dict(
            os.environ,
            {
                "VIKINGDB_KNOWLEDGE_BASE_API_URL": "api-test.example.com",
                "VIKINGDB_KNOWLEDGE_BASE_API_AK": "test_ak",
                "VIKINGDB_KNOWLEDGE_BASE_API_SK": "test_sk",
                "VIKINGDB_KNOWLEDGE_BASE_RETRIEVAL_SIZE": "5",
            },
            clear=True,
        ):
            provider = VikingDBKnowledgeBaseProvider()
            assert provider.retrieval_size == 5

    def test_init_custom_region(self):
        """Test initialization with custom region"""
        # clear=True: a malformed host RETRIEVAL_SIZE would otherwise break this test.
        with patch.dict(
            os.environ,
            {
                "VIKINGDB_KNOWLEDGE_BASE_API_URL": "api-test.example.com",
                "VIKINGDB_KNOWLEDGE_BASE_API_AK": "test_ak",
                "VIKINGDB_KNOWLEDGE_BASE_API_SK": "test_sk",
                "VIKINGDB_KNOWLEDGE_BASE_REGION": "us-east-1",
            },
            clear=True,
        ):
            provider = VikingDBKnowledgeBaseProvider()
            assert provider.region == "us-east-1"

    def test_init_missing_api_url(self):
        """Test initialization fails when API URL is missing"""
        with patch.dict(
            os.environ,
            {
                "VIKINGDB_KNOWLEDGE_BASE_API_AK": "test_ak",
                "VIKINGDB_KNOWLEDGE_BASE_API_SK": "test_sk",
            },
            clear=True,
        ):
            with pytest.raises(
                ValueError, match="VIKINGDB_KNOWLEDGE_BASE_API_URL is not set"
            ):
                VikingDBKnowledgeBaseProvider()

    def test_init_missing_api_ak(self):
        """Test initialization fails when API AK is missing"""
        with patch.dict(
            os.environ,
            {
                "VIKINGDB_KNOWLEDGE_BASE_API_URL": "api-test.example.com",
                "VIKINGDB_KNOWLEDGE_BASE_API_SK": "test_sk",
            },
            clear=True,
        ):
            with pytest.raises(
                ValueError, match="VIKINGDB_KNOWLEDGE_BASE_API_AK is not set"
            ):
                VikingDBKnowledgeBaseProvider()

    def test_init_missing_api_sk(self):
        """Test initialization fails when API SK is missing"""
        with patch.dict(
            os.environ,
            {
                "VIKINGDB_KNOWLEDGE_BASE_API_URL": "api-test.example.com",
                "VIKINGDB_KNOWLEDGE_BASE_API_AK": "test_ak",
            },
            clear=True,
        ):
            with pytest.raises(
                ValueError, match="VIKINGDB_KNOWLEDGE_BASE_API_SK is not set"
            ):
                VikingDBKnowledgeBaseProvider()
class TestVikingDBKnowledgeBaseProviderSignature:
    """Tests for the request-signing helpers and ``_make_signed_request``."""

    @pytest.fixture
    def provider(self, env_vars):
        return VikingDBKnowledgeBaseProvider()

    def test_hmac_sha256(self, provider):
        """``_hmac_sha256`` equals ``hmac.new(key, utf8(text), sha256).digest()``."""
        secret = b"test_key"
        message = "test_content"
        got = provider._hmac_sha256(secret, message)
        want = hmac.new(secret, message.encode("utf-8"), hashlib.sha256).digest()
        assert got == want

    def test_hash_sha256(self, provider):
        """``_hash_sha256`` returns the raw SHA-256 digest of its bytes input."""
        blob = b"test_data"
        assert provider._hash_sha256(blob) == hashlib.sha256(blob).digest()

    def test_get_signed_key(self, provider):
        """The derived signing key is a bytes value of SHA-256 digest length."""
        signing_key = provider._get_signed_key(
            "test_secret", "20250722", "cn-north-1", "air"
        )
        assert isinstance(signing_key, bytes)
        assert len(signing_key) == 32  # SHA256 digest is 32 bytes

    def test_create_canonical_request(self, provider):
        """Method, path, sorted query, lower-cased headers and signed headers appear."""
        canonical, signed = provider._create_canonical_request(
            "POST",
            "/api/test",
            {"param1": "value1", "param2": "value2"},
            {"Content-Type": "application/json", "Host": "example.com"},
            b'{"test": "data"}',
        )
        for fragment in (
            "POST",
            "/api/test",
            "param1=value1&param2=value2",
            "content-type:application/json",
            "host:example.com",
        ):
            assert fragment in canonical
        assert signed == "content-type;host"

    @patch("src.rag.vikingdb_knowledge_base.datetime")
    def test_create_signature(self, mock_datetime, provider):
        """The returned header dict carries every header needed for the signed call."""
        # Pin datetime.utcnow() as seen from the provider module.
        mock_datetime.utcnow.return_value = datetime(2025, 7, 22, 10, 30, 45)
        headers = provider._create_signature(
            "POST", "/api/test", {}, {}, b'{"test": "data"}'
        )
        for name in (
            "X-Date",
            "Host",
            "X-Content-Sha256",
            "Content-Type",
            "Authorization",
        ):
            assert name in headers
        assert "HMAC-SHA256" in headers["Authorization"]

    @patch("src.rag.vikingdb_knowledge_base.requests.request")
    def test_make_signed_request_success(self, mock_request, provider):
        """The raw response is returned; the call uses HTTPS and a 30 s timeout."""
        fake_response = MagicMock()
        fake_response.json.return_value = {"code": 0, "data": {}}
        mock_request.return_value = fake_response

        returned = provider._make_signed_request(
            "POST", "/api/test", data={"test": "data"}
        )

        assert returned == fake_response
        mock_request.assert_called_once()
        sent = mock_request.call_args.kwargs
        assert sent["method"] == "POST"
        assert sent["url"] == f"https://{provider.api_url}/api/test"
        assert sent["timeout"] == 30

    @patch("src.rag.vikingdb_knowledge_base.requests.request")
    def test_make_signed_request_with_exception(self, mock_request, provider):
        """A transport error is re-raised as ``ValueError`` carrying its message."""
        mock_request.side_effect = Exception("Network error")
        with pytest.raises(ValueError, match="Request failed: Network error"):
            provider._make_signed_request("GET", "/api/test")
class TestVikingDBKnowledgeBaseProviderQueryRelevantDocuments:
    """Tests for ``query_relevant_documents`` with ``_make_signed_request`` mocked."""

    @pytest.fixture
    def provider(self, env_vars):
        return VikingDBKnowledgeBaseProvider()

    @patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
    def test_query_relevant_documents_empty_resources(self, mock_request, provider):
        """An empty resource list returns [] without contacting the API."""
        result = provider.query_relevant_documents("test query", [])
        assert result == []
        # Without this, a regression that issues a request would go unnoticed
        # whenever the mocked reply happens to produce no documents.
        mock_request.assert_not_called()

    @patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
    def test_query_relevant_documents_success(self, mock_request, provider):
        """One hit from one resource becomes one document holding one chunk."""
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "code": 0,
            "data": {
                "result_list": [
                    {
                        "doc_info": {
                            "doc_id": "doc123",
                            "doc_name": "Test Document",
                        },
                        "content": "Test content",
                        "score": 0.95,
                    }
                ]
            },
        }
        mock_request.return_value = mock_response
        resources = [MockResource("rag://dataset/123")]

        result = provider.query_relevant_documents("test query", resources)

        mock_request.assert_called_once()  # one request per resource
        assert len(result) == 1
        assert result[0].id == "doc123"
        assert result[0].title == "Test Document"
        assert len(result[0].chunks) == 1
        assert result[0].chunks[0].content == "Test content"
        assert result[0].chunks[0].similarity == 0.95

    @patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
    def test_query_relevant_documents_with_document_filter(
        self, mock_request, provider
    ):
        """A ``#doc_id`` fragment adds a ``must`` filter on ``doc_id`` to the request."""
        mock_response = MagicMock()
        mock_response.json.return_value = {"code": 0, "data": {"result_list": []}}
        mock_request.return_value = mock_response
        resources = [MockResource("rag://dataset/123#doc456")]

        provider.query_relevant_documents("test query", resources)

        # The request body is passed as the ``data`` keyword argument.
        request_data = mock_request.call_args.kwargs["data"]
        assert "query_param" in request_data
        assert "doc_filter" in request_data["query_param"]
        doc_filter = request_data["query_param"]["doc_filter"]
        assert doc_filter["op"] == "must"
        assert doc_filter["field"] == "doc_id"
        assert doc_filter["conds"] == ["doc456"]

    @patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
    def test_query_relevant_documents_api_error(self, mock_request, provider):
        """A non-zero ``code`` surfaces the API message in a ``ValueError``."""
        mock_response = MagicMock()
        mock_response.json.return_value = {"code": 1, "message": "API Error"}
        mock_request.return_value = mock_response
        resources = [MockResource("rag://dataset/123")]
        with pytest.raises(
            ValueError, match="Failed to query documents from resource: API Error"
        ):
            provider.query_relevant_documents("test query", resources)

    @patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
    def test_query_relevant_documents_json_decode_error(self, mock_request, provider):
        """An unparseable response body is reported as ``ValueError``."""
        mock_response = MagicMock()
        mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
        mock_request.return_value = mock_response
        resources = [MockResource("rag://dataset/123")]
        with pytest.raises(ValueError, match="Failed to parse JSON response"):
            provider.query_relevant_documents("test query", resources)

    @patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
    def test_query_relevant_documents_multiple_resources(self, mock_request, provider):
        """Hits from several resources are merged into one document per ``doc_id``."""
        # One mocked reply per resource, consumed in resource order.
        responses = [
            {
                "code": 0,
                "data": {
                    "result_list": [
                        {
                            "doc_info": {
                                "doc_id": "doc1",
                                "doc_name": "Document 1",
                            },
                            "content": "Content 1",
                            "score": 0.9,
                        }
                    ]
                },
            },
            {
                "code": 0,
                "data": {
                    "result_list": [
                        {
                            "doc_info": {
                                "doc_id": "doc1",
                                "doc_name": "Document 1",
                            },
                            "content": "Content 2",
                            "score": 0.8,
                        },
                        {
                            "doc_info": {
                                "doc_id": "doc2",
                                "doc_name": "Document 2",
                            },
                            "content": "Content 3",
                            "score": 0.7,
                        },
                    ]
                },
            },
        ]
        mock_responses = [MagicMock() for _ in responses]
        for mock_response, payload in zip(mock_responses, responses):
            mock_response.json.return_value = payload
        mock_request.side_effect = mock_responses
        resources = [
            MockResource("rag://dataset/123"),
            MockResource("rag://dataset/456"),
        ]

        result = provider.query_relevant_documents("test query", resources)

        # Each resource is queried exactly once.
        assert mock_request.call_count == 2
        # Should have 2 documents: doc1 (with 2 chunks) and doc2 (with 1 chunk)
        assert len(result) == 2
        doc1 = next(doc for doc in result if doc.id == "doc1")
        doc2 = next(doc for doc in result if doc.id == "doc2")
        assert doc1.title == "Document 1"
        assert doc2.title == "Document 2"
        assert len(doc1.chunks) == 2
        assert len(doc2.chunks) == 1
        # doc1's chunks come from both resources, in the order they were returned.
        assert [c.content for c in doc1.chunks] == ["Content 1", "Content 2"]
        assert [c.similarity for c in doc1.chunks] == [0.9, 0.8]
        assert doc2.chunks[0].content == "Content 3"
        assert doc2.chunks[0].similarity == 0.7
class TestVikingDBKnowledgeBaseProviderListResources:
    """Tests for ``list_resources`` with ``_make_signed_request`` mocked."""

    @pytest.fixture
    def provider(self, env_vars):
        return VikingDBKnowledgeBaseProvider()

    @patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
    def test_list_resources_success(self, mock_request, provider):
        """Each collection maps to a ``rag://dataset/<resource_id>`` resource."""
        reply = MagicMock()
        reply.json.return_value = {
            "code": 0,
            "data": {
                "collection_list": [
                    {
                        "resource_id": "123",
                        "collection_name": "Dataset 1",
                        "description": "Description 1",
                    },
                    {
                        "resource_id": "456",
                        "collection_name": "Dataset 2",
                        "description": "Description 2",
                    },
                ]
            },
        }
        mock_request.return_value = reply

        listed = provider.list_resources()

        summary = [(res.uri, res.title, res.description) for res in listed]
        assert summary == [
            ("rag://dataset/123", "Dataset 1", "Description 1"),
            ("rag://dataset/456", "Dataset 2", "Description 2"),
        ]

    @patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
    def test_list_resources_with_query_filter(self, mock_request, provider):
        """A query keeps only collections whose name contains it ("test" -> "Test Dataset")."""
        reply = MagicMock()
        reply.json.return_value = {
            "code": 0,
            "data": {
                "collection_list": [
                    {
                        "resource_id": "123",
                        "collection_name": "Test Dataset",
                        "description": "Description",
                    },
                    {
                        "resource_id": "456",
                        "collection_name": "Other Dataset",
                        "description": "Description",
                    },
                ]
            },
        }
        mock_request.return_value = reply

        listed = provider.list_resources("test")

        assert [res.title for res in listed] == ["Test Dataset"]

    @patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
    def test_list_resources_api_error(self, mock_request, provider):
        """A non-zero ``code`` raises with the API message included."""
        reply = MagicMock()
        reply.json.return_value = {"code": 1, "message": "API Error"}
        mock_request.return_value = reply
        # NOTE(review): this matches any Exception subclass; tighten it to the
        # concrete type once list_resources is confirmed to raise one.
        with pytest.raises(Exception, match="Failed to list resources: API Error"):
            provider.list_resources()

    @patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
    def test_list_resources_json_decode_error(self, mock_request, provider):
        """An unparseable response body is reported as ``ValueError``."""
        reply = MagicMock()
        reply.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
        mock_request.return_value = reply
        with pytest.raises(ValueError, match="Failed to parse JSON response"):
            provider.list_resources()

    @patch.object(VikingDBKnowledgeBaseProvider, "_make_signed_request")
    def test_list_resources_empty_response(self, mock_request, provider):
        """An empty ``collection_list`` yields an empty result."""
        reply = MagicMock()
        reply.json.return_value = {"code": 0, "data": {"collection_list": []}}
        mock_request.return_value = reply
        assert provider.list_resources() == []