mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
test: added unit tests for rag (#298)
* test: added unit tests for rag * reformate the code
This commit is contained in:
@@ -122,12 +122,3 @@ def parse_uri(uri: str) -> tuple[str, str]:
|
||||
if parsed.scheme != "rag":
|
||||
raise ValueError(f"Invalid URI: {uri}")
|
||||
return parsed.path.split("/")[1], parsed.fragment
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uri = "rag://dataset/123#abc"
|
||||
parsed = urlparse(uri)
|
||||
print(parsed.scheme)
|
||||
print(parsed.netloc)
|
||||
print(parsed.path)
|
||||
print(parsed.fragment)
|
||||
|
||||
181
tests/unit/rag/test_ragflow.py
Normal file
181
tests/unit/rag/test_ragflow.py
Normal file
@@ -0,0 +1,181 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import requests
|
||||
from unittest.mock import patch, MagicMock
|
||||
from src.rag.ragflow import RAGFlowProvider, parse_uri
|
||||
|
||||
|
||||
# Dummy classes to mock dependencies
|
||||
class DummyResource:
|
||||
def __init__(self, uri, title="", description=""):
|
||||
self.uri = uri
|
||||
self.title = title
|
||||
self.description = description
|
||||
|
||||
|
||||
class DummyChunk:
|
||||
def __init__(self, content, similarity):
|
||||
self.content = content
|
||||
self.similarity = similarity
|
||||
|
||||
|
||||
class DummyDocument:
|
||||
def __init__(self, id, title, chunks=None):
|
||||
self.id = id
|
||||
self.title = title
|
||||
self.chunks = chunks or []
|
||||
|
||||
|
||||
# Patch imports in ragflow.py to use dummy classes
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_imports(monkeypatch):
|
||||
import src.rag.ragflow as ragflow
|
||||
|
||||
ragflow.Resource = DummyResource
|
||||
ragflow.Chunk = DummyChunk
|
||||
ragflow.Document = DummyDocument
|
||||
yield
|
||||
|
||||
|
||||
def test_parse_uri_valid():
|
||||
uri = "rag://dataset/123#abc"
|
||||
dataset_id, document_id = parse_uri(uri)
|
||||
assert dataset_id == "123"
|
||||
assert document_id == "abc"
|
||||
|
||||
|
||||
def test_parse_uri_invalid():
|
||||
with pytest.raises(ValueError):
|
||||
parse_uri("http://dataset/123#abc")
|
||||
|
||||
|
||||
def test_init_env_vars(monkeypatch):
|
||||
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
|
||||
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
|
||||
monkeypatch.delenv("RAGFLOW_PAGE_SIZE", raising=False)
|
||||
provider = RAGFlowProvider()
|
||||
assert provider.api_url == "http://api"
|
||||
assert provider.api_key == "key"
|
||||
assert provider.page_size == 10
|
||||
|
||||
|
||||
def test_init_page_size(monkeypatch):
|
||||
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
|
||||
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
|
||||
monkeypatch.setenv("RAGFLOW_PAGE_SIZE", "5")
|
||||
provider = RAGFlowProvider()
|
||||
assert provider.page_size == 5
|
||||
|
||||
|
||||
def test_init_missing_env(monkeypatch):
|
||||
monkeypatch.delenv("RAGFLOW_API_URL", raising=False)
|
||||
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
|
||||
with pytest.raises(ValueError):
|
||||
RAGFlowProvider()
|
||||
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
|
||||
monkeypatch.delenv("RAGFLOW_API_KEY", raising=False)
|
||||
with pytest.raises(ValueError):
|
||||
RAGFlowProvider()
|
||||
|
||||
|
||||
@patch("src.rag.ragflow.requests.post")
|
||||
def test_query_relevant_documents_success(mock_post, monkeypatch):
|
||||
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
|
||||
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
|
||||
provider = RAGFlowProvider()
|
||||
resource = DummyResource("rag://dataset/123#doc456")
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"data": {
|
||||
"doc_aggs": [{"doc_id": "doc456", "doc_name": "Doc Title"}],
|
||||
"chunks": [
|
||||
{"document_id": "doc456", "content": "chunk text", "similarity": 0.9}
|
||||
],
|
||||
}
|
||||
}
|
||||
mock_post.return_value = mock_response
|
||||
docs = provider.query_relevant_documents("query", [resource])
|
||||
assert len(docs) == 1
|
||||
assert docs[0].id == "doc456"
|
||||
assert docs[0].title == "Doc Title"
|
||||
assert len(docs[0].chunks) == 1
|
||||
assert docs[0].chunks[0].content == "chunk text"
|
||||
assert docs[0].chunks[0].similarity == 0.9
|
||||
|
||||
|
||||
@patch("src.rag.ragflow.requests.post")
|
||||
def test_query_relevant_documents_error(mock_post, monkeypatch):
|
||||
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
|
||||
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
|
||||
provider = RAGFlowProvider()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 400
|
||||
mock_response.text = "error"
|
||||
mock_post.return_value = mock_response
|
||||
with pytest.raises(Exception):
|
||||
provider.query_relevant_documents("query", [])
|
||||
|
||||
|
||||
@patch("src.rag.ragflow.requests.get")
|
||||
def test_list_resources_success(mock_get, monkeypatch):
|
||||
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
|
||||
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
|
||||
provider = RAGFlowProvider()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"data": [
|
||||
{"id": "123", "name": "Dataset1", "description": "desc1"},
|
||||
{"id": "456", "name": "Dataset2", "description": "desc2"},
|
||||
]
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
resources = provider.list_resources()
|
||||
assert len(resources) == 2
|
||||
assert resources[0].uri == "rag://dataset/123"
|
||||
assert resources[0].title == "Dataset1"
|
||||
assert resources[0].description == "desc1"
|
||||
assert resources[1].uri == "rag://dataset/456"
|
||||
assert resources[1].title == "Dataset2"
|
||||
assert resources[1].description == "desc2"
|
||||
|
||||
|
||||
@patch("src.rag.ragflow.requests.get")
|
||||
def test_list_resources_success(mock_get, monkeypatch):
|
||||
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
|
||||
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
|
||||
provider = RAGFlowProvider()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"data": [
|
||||
{"id": "123", "name": "Dataset1", "description": "desc1"},
|
||||
{"id": "456", "name": "Dataset2", "description": "desc2"},
|
||||
]
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
resources = provider.list_resources()
|
||||
assert len(resources) == 2
|
||||
assert resources[0].uri == "rag://dataset/123"
|
||||
assert resources[0].title == "Dataset1"
|
||||
assert resources[0].description == "desc1"
|
||||
assert resources[1].uri == "rag://dataset/456"
|
||||
assert resources[1].title == "Dataset2"
|
||||
assert resources[1].description == "desc2"
|
||||
|
||||
|
||||
@patch("src.rag.ragflow.requests.get")
|
||||
def test_list_resources_error(mock_get, monkeypatch):
|
||||
monkeypatch.setenv("RAGFLOW_API_URL", "http://api")
|
||||
monkeypatch.setenv("RAGFLOW_API_KEY", "key")
|
||||
provider = RAGFlowProvider()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.text = "fail"
|
||||
mock_get.return_value = mock_response
|
||||
with pytest.raises(Exception):
|
||||
provider.list_resources()
|
||||
72
tests/unit/rag/test_retriever.py
Normal file
72
tests/unit/rag/test_retriever.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
from src.rag.retriever import Chunk, Document, Resource, Retriever
|
||||
|
||||
|
||||
def test_chunk_init():
|
||||
chunk = Chunk(content="test content", similarity=0.9)
|
||||
assert chunk.content == "test content"
|
||||
assert chunk.similarity == 0.9
|
||||
|
||||
|
||||
def test_document_init_and_to_dict():
|
||||
chunk1 = Chunk(content="chunk1", similarity=0.8)
|
||||
chunk2 = Chunk(content="chunk2", similarity=0.7)
|
||||
doc = Document(
|
||||
id="doc1", url="http://example.com", title="Title", chunks=[chunk1, chunk2]
|
||||
)
|
||||
assert doc.id == "doc1"
|
||||
assert doc.url == "http://example.com"
|
||||
assert doc.title == "Title"
|
||||
assert doc.chunks == [chunk1, chunk2]
|
||||
d = doc.to_dict()
|
||||
assert d["id"] == "doc1"
|
||||
assert d["content"] == "chunk1\n\nchunk2"
|
||||
assert d["url"] == "http://example.com"
|
||||
assert d["title"] == "Title"
|
||||
|
||||
|
||||
def test_document_to_dict_optional_fields():
|
||||
chunk = Chunk(content="only chunk", similarity=1.0)
|
||||
doc = Document(id="doc2", chunks=[chunk])
|
||||
d = doc.to_dict()
|
||||
assert d["id"] == "doc2"
|
||||
assert d["content"] == "only chunk"
|
||||
assert "url" not in d
|
||||
assert "title" not in d
|
||||
|
||||
|
||||
def test_resource_model():
|
||||
resource = Resource(uri="uri1", title="Resource Title")
|
||||
assert resource.uri == "uri1"
|
||||
assert resource.title == "Resource Title"
|
||||
assert resource.description == ""
|
||||
|
||||
|
||||
def test_resource_model_with_description():
|
||||
resource = Resource(uri="uri2", title="Resource2", description="desc")
|
||||
assert resource.description == "desc"
|
||||
|
||||
|
||||
def test_retriever_abstract_methods():
|
||||
class DummyRetriever(Retriever):
|
||||
def list_resources(self, query=None):
|
||||
return [Resource(uri="uri", title="title")]
|
||||
|
||||
def query_relevant_documents(self, query, resources=[]):
|
||||
return [Document(id="id", chunks=[])]
|
||||
|
||||
retriever = DummyRetriever()
|
||||
resources = retriever.list_resources()
|
||||
assert isinstance(resources, list)
|
||||
assert isinstance(resources[0], Resource)
|
||||
docs = retriever.query_relevant_documents("query", resources)
|
||||
assert isinstance(docs, list)
|
||||
assert isinstance(docs[0], Document)
|
||||
|
||||
|
||||
def test_retriever_cannot_instantiate():
|
||||
with pytest.raises(TypeError):
|
||||
Retriever()
|
||||
Reference in New Issue
Block a user