From ee1af787675917f4130cd1826b4d61210e90c6b1 Mon Sep 17 00:00:00 2001 From: Willem Jiang Date: Wed, 11 Jun 2025 19:46:08 +0800 Subject: [PATCH] test: added unit tests for rag (#298) * test: added unit tests for rag * reformate the code --- src/rag/ragflow.py | 9 -- tests/unit/rag/test_ragflow.py | 181 +++++++++++++++++++++++++++++++ tests/unit/rag/test_retriever.py | 72 ++++++++++++ 3 files changed, 253 insertions(+), 9 deletions(-) create mode 100644 tests/unit/rag/test_ragflow.py create mode 100644 tests/unit/rag/test_retriever.py diff --git a/src/rag/ragflow.py b/src/rag/ragflow.py index b95ba3b..529ff2c 100644 --- a/src/rag/ragflow.py +++ b/src/rag/ragflow.py @@ -122,12 +122,3 @@ def parse_uri(uri: str) -> tuple[str, str]: if parsed.scheme != "rag": raise ValueError(f"Invalid URI: {uri}") return parsed.path.split("/")[1], parsed.fragment - - -if __name__ == "__main__": - uri = "rag://dataset/123#abc" - parsed = urlparse(uri) - print(parsed.scheme) - print(parsed.netloc) - print(parsed.path) - print(parsed.fragment) diff --git a/tests/unit/rag/test_ragflow.py b/tests/unit/rag/test_ragflow.py new file mode 100644 index 0000000..202cb5b --- /dev/null +++ b/tests/unit/rag/test_ragflow.py @@ -0,0 +1,181 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +import os +import pytest +import requests +from unittest.mock import patch, MagicMock +from src.rag.ragflow import RAGFlowProvider, parse_uri + + +# Dummy classes to mock dependencies +class DummyResource: + def __init__(self, uri, title="", description=""): + self.uri = uri + self.title = title + self.description = description + + +class DummyChunk: + def __init__(self, content, similarity): + self.content = content + self.similarity = similarity + + +class DummyDocument: + def __init__(self, id, title, chunks=None): + self.id = id + self.title = title + self.chunks = chunks or [] + + +# Patch imports in ragflow.py to use dummy classes +@pytest.fixture(autouse=True) +def patch_imports(monkeypatch): + import src.rag.ragflow as ragflow + + ragflow.Resource = DummyResource + ragflow.Chunk = DummyChunk + ragflow.Document = DummyDocument + yield + + +def test_parse_uri_valid(): + uri = "rag://dataset/123#abc" + dataset_id, document_id = parse_uri(uri) + assert dataset_id == "123" + assert document_id == "abc" + + +def test_parse_uri_invalid(): + with pytest.raises(ValueError): + parse_uri("http://dataset/123#abc") + + +def test_init_env_vars(monkeypatch): + monkeypatch.setenv("RAGFLOW_API_URL", "http://api") + monkeypatch.setenv("RAGFLOW_API_KEY", "key") + monkeypatch.delenv("RAGFLOW_PAGE_SIZE", raising=False) + provider = RAGFlowProvider() + assert provider.api_url == "http://api" + assert provider.api_key == "key" + assert provider.page_size == 10 + + +def test_init_page_size(monkeypatch): + monkeypatch.setenv("RAGFLOW_API_URL", "http://api") + monkeypatch.setenv("RAGFLOW_API_KEY", "key") + monkeypatch.setenv("RAGFLOW_PAGE_SIZE", "5") + provider = RAGFlowProvider() + assert provider.page_size == 5 + + +def test_init_missing_env(monkeypatch): + monkeypatch.delenv("RAGFLOW_API_URL", raising=False) + monkeypatch.setenv("RAGFLOW_API_KEY", "key") + with pytest.raises(ValueError): + RAGFlowProvider() + monkeypatch.setenv("RAGFLOW_API_URL", "http://api") + monkeypatch.delenv("RAGFLOW_API_KEY", raising=False) + with pytest.raises(ValueError): + RAGFlowProvider() + + +@patch("src.rag.ragflow.requests.post") +def test_query_relevant_documents_success(mock_post, monkeypatch): + monkeypatch.setenv("RAGFLOW_API_URL", "http://api") + monkeypatch.setenv("RAGFLOW_API_KEY", "key") + provider = RAGFlowProvider() + resource = DummyResource("rag://dataset/123#doc456") + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": { + "doc_aggs": [{"doc_id": "doc456", "doc_name": "Doc Title"}], + "chunks": [ + {"document_id": "doc456", "content": "chunk text", "similarity": 0.9} + ], + } + } + mock_post.return_value = mock_response + docs = provider.query_relevant_documents("query", [resource]) + assert len(docs) == 1 + assert docs[0].id == "doc456" + assert docs[0].title == "Doc Title" + assert len(docs[0].chunks) == 1 + assert docs[0].chunks[0].content == "chunk text" + assert docs[0].chunks[0].similarity == 0.9 + + +@patch("src.rag.ragflow.requests.post") +def test_query_relevant_documents_error(mock_post, monkeypatch): + monkeypatch.setenv("RAGFLOW_API_URL", "http://api") + monkeypatch.setenv("RAGFLOW_API_KEY", "key") + provider = RAGFlowProvider() + mock_response = MagicMock() + mock_response.status_code = 400 + mock_response.text = "error" + mock_post.return_value = mock_response + with pytest.raises(Exception): + provider.query_relevant_documents("query", []) + + +@patch("src.rag.ragflow.requests.get") +def test_list_resources_success(mock_get, monkeypatch): + monkeypatch.setenv("RAGFLOW_API_URL", "http://api") + monkeypatch.setenv("RAGFLOW_API_KEY", "key") + provider = RAGFlowProvider() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + {"id": "123", "name": "Dataset1", "description": "desc1"}, + {"id": "456", "name": "Dataset2", "description": "desc2"}, + ] + } + mock_get.return_value = mock_response + resources = provider.list_resources() + assert len(resources) == 2 + assert resources[0].uri == "rag://dataset/123" + assert resources[0].title == "Dataset1" + assert resources[0].description == "desc1" + assert resources[1].uri == "rag://dataset/456" + assert resources[1].title == "Dataset2" + assert resources[1].description == "desc2" + + +@patch("src.rag.ragflow.requests.get") +def test_list_resources_success(mock_get, monkeypatch): + monkeypatch.setenv("RAGFLOW_API_URL", "http://api") + monkeypatch.setenv("RAGFLOW_API_KEY", "key") + provider = RAGFlowProvider() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + {"id": "123", "name": "Dataset1", "description": "desc1"}, + {"id": "456", "name": "Dataset2", "description": "desc2"}, + ] + } + mock_get.return_value = mock_response + resources = provider.list_resources() + assert len(resources) == 2 + assert resources[0].uri == "rag://dataset/123" + assert resources[0].title == "Dataset1" + assert resources[0].description == "desc1" + assert resources[1].uri == "rag://dataset/456" + assert resources[1].title == "Dataset2" + assert resources[1].description == "desc2" + + +@patch("src.rag.ragflow.requests.get") +def test_list_resources_error(mock_get, monkeypatch): + monkeypatch.setenv("RAGFLOW_API_URL", "http://api") + monkeypatch.setenv("RAGFLOW_API_KEY", "key") + provider = RAGFlowProvider() + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.text = "fail" + mock_get.return_value = mock_response + with pytest.raises(Exception): + provider.list_resources() diff --git a/tests/unit/rag/test_retriever.py b/tests/unit/rag/test_retriever.py new file mode 100644 index 0000000..4c4964d --- /dev/null +++ b/tests/unit/rag/test_retriever.py @@ -0,0 +1,72 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +import pytest +from src.rag.retriever import Chunk, Document, Resource, Retriever + + +def test_chunk_init(): + chunk = Chunk(content="test content", similarity=0.9) + assert chunk.content == "test content" + assert chunk.similarity == 0.9 + + +def test_document_init_and_to_dict(): + chunk1 = Chunk(content="chunk1", similarity=0.8) + chunk2 = Chunk(content="chunk2", similarity=0.7) + doc = Document( + id="doc1", url="http://example.com", title="Title", chunks=[chunk1, chunk2] + ) + assert doc.id == "doc1" + assert doc.url == "http://example.com" + assert doc.title == "Title" + assert doc.chunks == [chunk1, chunk2] + d = doc.to_dict() + assert d["id"] == "doc1" + assert d["content"] == "chunk1\n\nchunk2" + assert d["url"] == "http://example.com" + assert d["title"] == "Title" + + +def test_document_to_dict_optional_fields(): + chunk = Chunk(content="only chunk", similarity=1.0) + doc = Document(id="doc2", chunks=[chunk]) + d = doc.to_dict() + assert d["id"] == "doc2" + assert d["content"] == "only chunk" + assert "url" not in d + assert "title" not in d + + +def test_resource_model(): + resource = Resource(uri="uri1", title="Resource Title") + assert resource.uri == "uri1" + assert resource.title == "Resource Title" + assert resource.description == "" + + +def test_resource_model_with_description(): + resource = Resource(uri="uri2", title="Resource2", description="desc") + assert resource.description == "desc" + + +def test_retriever_abstract_methods(): + class DummyRetriever(Retriever): + def list_resources(self, query=None): + return [Resource(uri="uri", title="title")] + + def query_relevant_documents(self, query, resources=[]): + return [Document(id="id", chunks=[])] + + retriever = DummyRetriever() + resources = retriever.list_resources() + assert isinstance(resources, list) + assert isinstance(resources[0], Resource) + docs = retriever.query_relevant_documents("query", resources) + assert isinstance(docs, list) + assert isinstance(docs[0], Document) + + +def test_retriever_cannot_instantiate(): + with pytest.raises(TypeError): + Retriever()