From 4c2fe2e7f54f6823288e624ef5f739fbc3d71f7c Mon Sep 17 00:00:00 2001 From: Willem Jiang Date: Thu, 12 Jun 2025 20:43:32 +0800 Subject: [PATCH] test: add more unit tests of tools (#315) * test: add more test on test_tts.py * test: add unit test of search and retriever in tools * test: remove the main code of search.py * test: add the tavily_search unit test * reformat the code * test: add unit tests of tools * Added the pytest-asyncio dependency * added the license header of test_tavily_search_api_wrapper.py --- pyproject.toml | 1 + src/rag/__init__.py | 4 +- src/tools/retriever.py | 15 - src/tools/search.py | 16 +- .../tavily_search_api_wrapper.py | 10 +- .../tavily_search_results_with_images.py | 3 + tests/integration/test_tts.py | 17 ++ tests/unit/tools/test_crawl.py | 110 ++++++++ tests/unit/tools/test_decorators.py | 121 ++++++++ tests/unit/tools/test_python_repl.py | 147 ++++++++++ tests/unit/tools/test_search.py | 54 ++++ .../tools/test_tavily_search_api_wrapper.py | 207 ++++++++++++++ .../test_tavily_search_results_with_images.py | 265 ++++++++++++++++++ tests/unit/tools/test_tools_retriever.py | 122 ++++++++ 14 files changed, 1057 insertions(+), 35 deletions(-) create mode 100644 tests/unit/tools/test_crawl.py create mode 100644 tests/unit/tools/test_decorators.py create mode 100644 tests/unit/tools/test_python_repl.py create mode 100644 tests/unit/tools/test_search.py create mode 100644 tests/unit/tools/test_tavily_search_api_wrapper.py create mode 100644 tests/unit/tools/test_tavily_search_results_with_images.py create mode 100644 tests/unit/tools/test_tools_retriever.py diff --git a/pyproject.toml b/pyproject.toml index 2bb2322..7328215 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dev = [ test = [ "pytest>=7.4.0", "pytest-cov>=4.1.0", + "pytest-asyncio>=1.0.0", ] [tool.pytest.ini_options] diff --git a/src/rag/__init__.py b/src/rag/__init__.py index c325016..e271cfc 100644 --- a/src/rag/__init__.py +++ b/src/rag/__init__.py @@ 
-1,8 +1,8 @@ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # SPDX-License-Identifier: MIT -from .retriever import Retriever, Document, Resource +from .retriever import Retriever, Document, Resource, Chunk from .ragflow import RAGFlowProvider from .builder import build_retriever -__all__ = [Retriever, Document, Resource, RAGFlowProvider, build_retriever] +__all__ = [Retriever, Document, Resource, RAGFlowProvider, Chunk, build_retriever] diff --git a/src/tools/retriever.py b/src/tools/retriever.py index fc0bf93..12dfd49 100644 --- a/src/tools/retriever.py +++ b/src/tools/retriever.py @@ -60,18 +60,3 @@ def get_retriever_tool(resources: List[Resource]) -> RetrieverTool | None: if not retriever: return None return RetrieverTool(retriever=retriever, resources=resources) - - -if __name__ == "__main__": - resources = [ - Resource( - uri="rag://dataset/1c7e2ea4362911f09a41c290d4b6a7f0", - title="西游记", - description="西游记是中国古代四大名著之一,讲述了唐僧师徒四人西天取经的故事。", - ) - ] - retriever_tool = get_retriever_tool(resources) - print(retriever_tool.name) - print(retriever_tool.description) - print(retriever_tool.args) - print(retriever_tool.invoke("三打白骨精")) diff --git a/src/tools/search.py b/src/tools/search.py index fa8445d..bbe4fa8 100644 --- a/src/tools/search.py +++ b/src/tools/search.py @@ -36,7 +36,10 @@ def get_web_search_tool(max_search_results: int): include_image_descriptions=True, ) elif SELECTED_SEARCH_ENGINE == SearchEngine.DUCKDUCKGO.value: - return LoggedDuckDuckGoSearch(name="web_search", max_results=max_search_results) + return LoggedDuckDuckGoSearch( + name="web_search", + num_results=max_search_results, + ) elif SELECTED_SEARCH_ENGINE == SearchEngine.BRAVE_SEARCH.value: return LoggedBraveSearch( name="web_search", @@ -56,14 +59,3 @@ def get_web_search_tool(max_search_results: int): ) else: raise ValueError(f"Unsupported search engine: {SELECTED_SEARCH_ENGINE}") - - -if __name__ == "__main__": - results = LoggedDuckDuckGoSearch( - name="web_search", 
max_results=3, output_format="list" - ) - print(results.name) - print(results.description) - print(results.args) - # .invoke("cute panda") - # print(json.dumps(results, indent=2, ensure_ascii=False)) diff --git a/src/tools/tavily_search/tavily_search_api_wrapper.py b/src/tools/tavily_search/tavily_search_api_wrapper.py index ef19728..191d3c2 100644 --- a/src/tools/tavily_search/tavily_search_api_wrapper.py +++ b/src/tools/tavily_search/tavily_search_api_wrapper.py @@ -1,3 +1,7 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + import json from typing import Dict, List, Optional @@ -107,9 +111,3 @@ class EnhancedTavilySearchAPIWrapper(OriginalTavilySearchAPIWrapper): } clean_results.append(clean_result) return clean_results - - -if __name__ == "__main__": - wrapper = EnhancedTavilySearchAPIWrapper() - results = wrapper.raw_results("cute panda", include_images=True) - print(json.dumps(results, indent=2, ensure_ascii=False)) diff --git a/src/tools/tavily_search/tavily_search_results_with_images.py b/src/tools/tavily_search/tavily_search_results_with_images.py index 915538a..e18680d 100644 --- a/src/tools/tavily_search/tavily_search_results_with_images.py +++ b/src/tools/tavily_search/tavily_search_results_with_images.py @@ -1,3 +1,6 @@ +# Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates +# SPDX-License-Identifier: MIT + import json from typing import Dict, List, Optional, Tuple, Union diff --git a/tests/integration/test_tts.py b/tests/integration/test_tts.py index a22405d..1066c95 100644 --- a/tests/integration/test_tts.py +++ b/tests/integration/test_tts.py @@ -229,3 +229,20 @@ class TestVolcengineTTS: args, kwargs = mock_post.call_args request_json = json.loads(args[1]) assert request_json["user"]["uid"] == str(mock_uuid_value) + + @patch("src.tools.tts.requests.post") + def test_text_to_speech_request_exception(self, mock_post): + """Test error handling when requests.post raises an exception.""" + # Mock requests.post to raise an exception + mock_post.side_effect = Exception("Network error") + # Create TTS client + tts = VolcengineTTS( + appid="test_appid", + access_token="test_token", + ) + # Call the method + result = tts.text_to_speech("Hello, world!") + # Verify the result + assert result["success"] is False + assert result["error"] == "Network error" + assert result["audio_data"] is None diff --git a/tests/unit/tools/test_crawl.py b/tests/unit/tools/test_crawl.py new file mode 100644 index 0000000..405c44f --- /dev/null +++ b/tests/unit/tools/test_crawl.py @@ -0,0 +1,110 @@ +import pytest +from unittest.mock import Mock, patch +from src.tools.crawl import crawl_tool + + +class TestCrawlTool: + + @patch("src.tools.crawl.Crawler") + def test_crawl_tool_success(self, mock_crawler_class): + # Arrange + mock_crawler = Mock() + mock_article = Mock() + mock_article.to_markdown.return_value = ( + "# Test Article\nThis is test content." 
* 100 + ) + mock_crawler.crawl.return_value = mock_article + mock_crawler_class.return_value = mock_crawler + + url = "https://example.com" + + # Act + result = crawl_tool(url) + + # Assert + assert isinstance(result, dict) + assert result["url"] == url + assert "crawled_content" in result + assert len(result["crawled_content"]) <= 1000 + mock_crawler_class.assert_called_once() + mock_crawler.crawl.assert_called_once_with(url) + mock_article.to_markdown.assert_called_once() + + @patch("src.tools.crawl.Crawler") + def test_crawl_tool_short_content(self, mock_crawler_class): + # Arrange + mock_crawler = Mock() + mock_article = Mock() + short_content = "Short content" + mock_article.to_markdown.return_value = short_content + mock_crawler.crawl.return_value = mock_article + mock_crawler_class.return_value = mock_crawler + + url = "https://example.com" + + # Act + result = crawl_tool(url) + + # Assert + assert result["crawled_content"] == short_content + + @patch("src.tools.crawl.Crawler") + @patch("src.tools.crawl.logger") + def test_crawl_tool_crawler_exception(self, mock_logger, mock_crawler_class): + # Arrange + mock_crawler = Mock() + mock_crawler.crawl.side_effect = Exception("Network error") + mock_crawler_class.return_value = mock_crawler + + url = "https://example.com" + + # Act + result = crawl_tool(url) + + # Assert + assert isinstance(result, str) + assert "Failed to crawl" in result + assert "Network error" in result + mock_logger.error.assert_called_once() + + @patch("src.tools.crawl.Crawler") + @patch("src.tools.crawl.logger") + def test_crawl_tool_crawler_instantiation_exception( + self, mock_logger, mock_crawler_class + ): + # Arrange + mock_crawler_class.side_effect = Exception("Crawler init error") + + url = "https://example.com" + + # Act + result = crawl_tool(url) + + # Assert + assert isinstance(result, str) + assert "Failed to crawl" in result + assert "Crawler init error" in result + mock_logger.error.assert_called_once() + + 
@patch("src.tools.crawl.Crawler") + @patch("src.tools.crawl.logger") + def test_crawl_tool_markdown_conversion_exception( + self, mock_logger, mock_crawler_class + ): + # Arrange + mock_crawler = Mock() + mock_article = Mock() + mock_article.to_markdown.side_effect = Exception("Markdown conversion error") + mock_crawler.crawl.return_value = mock_article + mock_crawler_class.return_value = mock_crawler + + url = "https://example.com" + + # Act + result = crawl_tool(url) + + # Assert + assert isinstance(result, str) + assert "Failed to crawl" in result + assert "Markdown conversion error" in result + mock_logger.error.assert_called_once() diff --git a/tests/unit/tools/test_decorators.py b/tests/unit/tools/test_decorators.py new file mode 100644 index 0000000..4dbc2b5 --- /dev/null +++ b/tests/unit/tools/test_decorators.py @@ -0,0 +1,121 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +import pytest +import logging +from unittest.mock import Mock, call, patch, MagicMock +from src.tools.decorators import LoggedToolMixin, create_logged_tool + + +class MockBaseTool: + """Mock base tool class for testing.""" + + def _run(self, *args, **kwargs): + return "base_result" + + +class TestLoggedToolMixin: + + def test_run_calls_log_operation(self): + """Test that _run calls _log_operation with correct parameters.""" + # Create a logged tool instance + LoggedTool = create_logged_tool(MockBaseTool) + tool = LoggedTool() + + # Mock the _log_operation method + tool._log_operation = Mock() + + # Call _run with test parameters + args = ("arg1", "arg2") + kwargs = {"key1": "value1", "key2": "value2"} + tool._run(*args, **kwargs) + + # Verify _log_operation was called with correct parameters + tool._log_operation.assert_called_once_with("_run", *args, **kwargs) + + def test_run_calls_super_run(self): + """Test that _run calls the parent class _run method.""" + # Create a logged tool instance + LoggedTool = create_logged_tool(MockBaseTool) 
+ tool = LoggedTool() + + # Mock the parent _run method + with patch.object( + MockBaseTool, "_run", return_value="mocked_result" + ) as mock_super_run: + args = ("arg1", "arg2") + kwargs = {"key1": "value1"} + result = tool._run(*args, **kwargs) + + # Verify super()._run was called with correct parameters + mock_super_run.assert_called_once_with(*args, **kwargs) + # Verify the result is returned + assert result == "mocked_result" + + def test_run_logs_result(self): + """Test that _run logs the result with debug level.""" + LoggedTool = create_logged_tool(MockBaseTool) + tool = LoggedTool() + + with patch("src.tools.decorators.logger.debug") as mock_debug: + result = tool._run("test_arg") + + # Verify debug log was called with correct message + mock_debug.assert_has_calls( + [ + call("Tool MockBaseTool._run called with parameters: test_arg"), + call("Tool MockBaseTool returned: base_result"), + ] + ) + + def test_run_returns_super_result(self): + """Test that _run returns the result from parent class.""" + LoggedTool = create_logged_tool(MockBaseTool) + tool = LoggedTool() + + result = tool._run() + assert result == "base_result" + + def test_run_with_no_args(self): + """Test _run method with no arguments.""" + LoggedTool = create_logged_tool(MockBaseTool) + tool = LoggedTool() + + with patch("src.tools.decorators.logger.debug") as mock_debug: + tool._log_operation = Mock() + + result = tool._run() + + # Verify _log_operation called with no args + tool._log_operation.assert_called_once_with("_run") + # Verify result logging + mock_debug.assert_called_once() + assert result == "base_result" + + def test_run_with_mixed_args_kwargs(self): + """Test _run method with both positional and keyword arguments.""" + LoggedTool = create_logged_tool(MockBaseTool) + tool = LoggedTool() + + tool._log_operation = Mock() + + args = ("pos1", "pos2") + kwargs = {"kw1": "val1", "kw2": "val2"} + result = tool._run(*args, **kwargs) + + # Verify all arguments passed correctly + 
tool._log_operation.assert_called_once_with("_run", *args, **kwargs) + assert result == "base_result" + + def test_run_class_name_replacement(self): + """Test that class name 'Logged' prefix is correctly removed in logging.""" + LoggedTool = create_logged_tool(MockBaseTool) + tool = LoggedTool() + + with patch("src.tools.decorators.logger.debug") as mock_debug: + tool._run() + + # Verify the logged class name has 'Logged' prefix removed + call_args = mock_debug.call_args[0][0] + assert "Tool MockBaseTool returned:" in call_args + assert "LoggedMockBaseTool" not in call_args diff --git a/tests/unit/tools/test_python_repl.py b/tests/unit/tools/test_python_repl.py new file mode 100644 index 0000000..963744e --- /dev/null +++ b/tests/unit/tools/test_python_repl.py @@ -0,0 +1,147 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +import pytest +from unittest.mock import Mock, patch, MagicMock +from src.tools.python_repl import python_repl_tool + + +class TestPythonReplTool: + + @patch("src.tools.python_repl.repl") + @patch("src.tools.python_repl.logger") + def test_successful_code_execution(self, mock_logger, mock_repl): + # Arrange + code = "print('Hello, World!')" + expected_output = "Hello, World!\n" + mock_repl.run.return_value = expected_output + + # Act + result = python_repl_tool(code) + + # Assert + mock_repl.run.assert_called_once_with(code) + mock_logger.info.assert_called_with("Code execution successful") + assert "Successfully executed:" in result + assert code in result + assert expected_output in result + + @patch("src.tools.python_repl.repl") + @patch("src.tools.python_repl.logger") + def test_invalid_input_type(self, mock_logger, mock_repl): + # Arrange + invalid_code = 123 + + # Act & Assert - expect ValidationError from LangChain + with pytest.raises(Exception) as exc_info: + python_repl_tool(invalid_code) + + # Verify that it's a validation error + assert "ValidationError" in str( + type(exc_info.value) + 
) or "validation error" in str(exc_info.value) + + # The REPL should not be called since validation fails first + mock_repl.run.assert_not_called() + + @patch("src.tools.python_repl.repl") + @patch("src.tools.python_repl.logger") + def test_code_execution_with_error_in_result(self, mock_logger, mock_repl): + # Arrange + code = "invalid_function()" + error_result = "NameError: name 'invalid_function' is not defined" + mock_repl.run.return_value = error_result + + # Act + result = python_repl_tool(code) + + # Assert + mock_repl.run.assert_called_once_with(code) + mock_logger.error.assert_called_with(error_result) + assert "Error executing code:" in result + assert code in result + assert error_result in result + + @patch("src.tools.python_repl.repl") + @patch("src.tools.python_repl.logger") + def test_code_execution_with_exception_in_result(self, mock_logger, mock_repl): + # Arrange + code = "1/0" + exception_result = "ZeroDivisionError: division by zero" + mock_repl.run.return_value = exception_result + + # Act + result = python_repl_tool(code) + + # Assert + mock_repl.run.assert_called_once_with(code) + mock_logger.error.assert_called_with(exception_result) + assert "Error executing code:" in result + assert code in result + assert exception_result in result + + @patch("src.tools.python_repl.repl") + @patch("src.tools.python_repl.logger") + def test_code_execution_raises_exception(self, mock_logger, mock_repl): + # Arrange + code = "print('test')" + exception = RuntimeError("REPL failed") + mock_repl.run.side_effect = exception + + # Act + result = python_repl_tool(code) + + # Assert + mock_repl.run.assert_called_once_with(code) + mock_logger.error.assert_called_with(repr(exception)) + assert "Error executing code:" in result + assert code in result + assert repr(exception) in result + + @patch("src.tools.python_repl.repl") + @patch("src.tools.python_repl.logger") + def test_successful_execution_with_calculation(self, mock_logger, mock_repl): + # Arrange + code = 
"result = 2 + 3\nprint(result)" + expected_output = "5\n" + mock_repl.run.return_value = expected_output + + # Act + result = python_repl_tool(code) + + # Assert + mock_repl.run.assert_called_once_with(code) + mock_logger.info.assert_any_call("Executing Python code") + mock_logger.info.assert_any_call("Code execution successful") + assert "Successfully executed:" in result + assert code in result + assert expected_output in result + + @patch("src.tools.python_repl.repl") + @patch("src.tools.python_repl.logger") + def test_empty_string_code(self, mock_logger, mock_repl): + # Arrange + code = "" + mock_repl.run.return_value = "" + + # Act + result = python_repl_tool(code) + + # Assert + mock_repl.run.assert_called_once_with(code) + mock_logger.info.assert_called_with("Code execution successful") + assert "Successfully executed:" in result + + @patch("src.tools.python_repl.repl") + @patch("src.tools.python_repl.logger") + def test_logging_calls(self, mock_logger, mock_repl): + # Arrange + code = "x = 1" + mock_repl.run.return_value = "" + + # Act + python_repl_tool(code) + + # Assert + mock_logger.info.assert_any_call("Executing Python code") + mock_logger.info.assert_any_call("Code execution successful") diff --git a/tests/unit/tools/test_search.py b/tests/unit/tools/test_search.py new file mode 100644 index 0000000..cc914a2 --- /dev/null +++ b/tests/unit/tools/test_search.py @@ -0,0 +1,54 @@ +# Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates +# SPDX-License-Identifier: MIT + +import os +import pytest +from unittest.mock import patch, MagicMock +from src.tools.search import get_web_search_tool +from src.config import SearchEngine + + +class TestGetWebSearchTool: + + @patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value) + def test_get_web_search_tool_tavily(self): + tool = get_web_search_tool(max_search_results=5) + assert tool.name == "web_search" + assert tool.max_results == 5 + assert tool.include_raw_content is True + assert tool.include_images is True + assert tool.include_image_descriptions is True + + @patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.DUCKDUCKGO.value) + def test_get_web_search_tool_duckduckgo(self): + tool = get_web_search_tool(max_search_results=3) + assert tool.name == "web_search" + assert tool.max_results == 3 + + @patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.BRAVE_SEARCH.value) + @patch.dict(os.environ, {"BRAVE_SEARCH_API_KEY": "test_api_key"}) + def test_get_web_search_tool_brave(self): + tool = get_web_search_tool(max_search_results=4) + assert tool.name == "web_search" + assert tool.search_wrapper.api_key == "test_api_key" + + @patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.ARXIV.value) + def test_get_web_search_tool_arxiv(self): + tool = get_web_search_tool(max_search_results=2) + assert tool.name == "web_search" + assert tool.api_wrapper.top_k_results == 2 + assert tool.api_wrapper.load_max_docs == 2 + assert tool.api_wrapper.load_all_available_meta is True + + @patch("src.tools.search.SELECTED_SEARCH_ENGINE", "unsupported_engine") + def test_get_web_search_tool_unsupported_engine(self): + with pytest.raises( + ValueError, match="Unsupported search engine: unsupported_engine" + ): + get_web_search_tool(max_search_results=1) + + @patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.BRAVE_SEARCH.value) + @patch.dict(os.environ, {}, clear=True) + def 
test_get_web_search_tool_brave_no_api_key(self): + tool = get_web_search_tool(max_search_results=1) + assert tool.search_wrapper.api_key == "" diff --git a/tests/unit/tools/test_tavily_search_api_wrapper.py b/tests/unit/tools/test_tavily_search_api_wrapper.py new file mode 100644 index 0000000..37d6242 --- /dev/null +++ b/tests/unit/tools/test_tavily_search_api_wrapper.py @@ -0,0 +1,207 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT +import json +import pytest +from unittest.mock import Mock, patch, AsyncMock, MagicMock +import aiohttp +import requests +from src.tools.tavily_search.tavily_search_api_wrapper import ( + EnhancedTavilySearchAPIWrapper, +) + + +class TestEnhancedTavilySearchAPIWrapper: + + @pytest.fixture + def wrapper(self): + with patch( + "src.tools.tavily_search.tavily_search_api_wrapper.OriginalTavilySearchAPIWrapper" + ): + wrapper = EnhancedTavilySearchAPIWrapper(tavily_api_key="dummy-key") + # The parent class is mocked, so initialization won't fail + return wrapper + + @pytest.fixture + def mock_response_data(self): + return { + "results": [ + { + "title": "Test Title", + "url": "https://example.com", + "content": "Test content", + "score": 0.9, + "raw_content": "Raw test content", + } + ], + "images": [ + { + "url": "https://example.com/image.jpg", + "description": "Test image description", + } + ], + } + + @patch("src.tools.tavily_search.tavily_search_api_wrapper.requests.post") + def test_raw_results_success(self, mock_post, wrapper, mock_response_data): + mock_response = Mock() + mock_response.json.return_value = mock_response_data + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + result = wrapper.raw_results("test query", max_results=10) + + assert result == mock_response_data + mock_post.assert_called_once() + call_args = mock_post.call_args + assert "json" in call_args.kwargs + assert call_args.kwargs["json"]["query"] == "test query" + assert 
call_args.kwargs["json"]["max_results"] == 10 + + @patch("src.tools.tavily_search.tavily_search_api_wrapper.requests.post") + def test_raw_results_with_all_parameters( + self, mock_post, wrapper, mock_response_data + ): + mock_response = Mock() + mock_response.json.return_value = mock_response_data + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + result = wrapper.raw_results( + "test query", + max_results=3, + search_depth="basic", + include_domains=["example.com"], + exclude_domains=["spam.com"], + include_answer=True, + include_raw_content=True, + include_images=True, + include_image_descriptions=True, + ) + + assert result == mock_response_data + call_args = mock_post.call_args + params = call_args.kwargs["json"] + assert params["include_domains"] == ["example.com"] + assert params["exclude_domains"] == ["spam.com"] + assert params["include_answer"] is True + assert params["include_raw_content"] is True + + @patch("src.tools.tavily_search.tavily_search_api_wrapper.requests.post") + def test_raw_results_http_error(self, mock_post, wrapper): + mock_response = Mock() + mock_response.raise_for_status.side_effect = requests.HTTPError("API Error") + mock_post.return_value = mock_response + + with pytest.raises(requests.HTTPError): + wrapper.raw_results("test query") + + @pytest.mark.asyncio + async def test_raw_results_async_success(self, wrapper, mock_response_data): + # Create a mock that acts as both the response and its context manager + mock_response_cm = AsyncMock() + mock_response_cm.__aenter__ = AsyncMock(return_value=mock_response_cm) + mock_response_cm.__aexit__ = AsyncMock(return_value=None) + mock_response_cm.status = 200 + mock_response_cm.text = AsyncMock(return_value=json.dumps(mock_response_data)) + + # Create mock session that returns the context manager + mock_session = AsyncMock() + mock_session.post = MagicMock( + return_value=mock_response_cm + ) # Use MagicMock, not AsyncMock + + # Create mock 
session class + mock_session_cm = AsyncMock() + mock_session_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_session_cm.__aexit__ = AsyncMock(return_value=None) + + with patch( + "src.tools.tavily_search.tavily_search_api_wrapper.aiohttp.ClientSession", + return_value=mock_session_cm, + ): + result = await wrapper.raw_results_async("test query") + + assert result == mock_response_data + + @pytest.mark.asyncio + async def test_raw_results_async_error(self, wrapper): + # Create a mock that acts as both the response and its context manager + mock_response_cm = AsyncMock() + mock_response_cm.__aenter__ = AsyncMock(return_value=mock_response_cm) + mock_response_cm.__aexit__ = AsyncMock(return_value=None) + mock_response_cm.status = 400 + mock_response_cm.reason = "Bad Request" + + # Create mock session that returns the context manager + mock_session = AsyncMock() + mock_session.post = MagicMock( + return_value=mock_response_cm + ) # Use MagicMock, not AsyncMock + + # Create mock session class + mock_session_cm = AsyncMock() + mock_session_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_session_cm.__aexit__ = AsyncMock(return_value=None) + + with patch( + "src.tools.tavily_search.tavily_search_api_wrapper.aiohttp.ClientSession", + return_value=mock_session_cm, + ): + with pytest.raises(Exception, match="Error 400: Bad Request"): + await wrapper.raw_results_async("test query") + + def test_clean_results_with_images(self, wrapper, mock_response_data): + result = wrapper.clean_results_with_images(mock_response_data) + + assert len(result) == 2 + + # Test page result + page_result = result[0] + assert page_result["type"] == "page" + assert page_result["title"] == "Test Title" + assert page_result["url"] == "https://example.com" + assert page_result["content"] == "Test content" + assert page_result["score"] == 0.9 + assert page_result["raw_content"] == "Raw test content" + + # Test image result + image_result = result[1] + assert image_result["type"] 
== "image" + assert image_result["image_url"] == "https://example.com/image.jpg" + assert image_result["image_description"] == "Test image description" + + def test_clean_results_without_raw_content(self, wrapper): + data = { + "results": [ + { + "title": "Test Title", + "url": "https://example.com", + "content": "Test content", + "score": 0.9, + } + ], + "images": [], + } + + result = wrapper.clean_results_with_images(data) + + assert len(result) == 1 + assert "raw_content" not in result[0] + + def test_clean_results_empty_images(self, wrapper): + data = { + "results": [ + { + "title": "Test Title", + "url": "https://example.com", + "content": "Test content", + "score": 0.9, + } + ], + "images": [], + } + + result = wrapper.clean_results_with_images(data) + + assert len(result) == 1 + assert result[0]["type"] == "page" diff --git a/tests/unit/tools/test_tavily_search_results_with_images.py b/tests/unit/tools/test_tavily_search_results_with_images.py new file mode 100644 index 0000000..963dbf1 --- /dev/null +++ b/tests/unit/tools/test_tavily_search_results_with_images.py @@ -0,0 +1,265 @@ +# Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates +# SPDX-License-Identifier: MIT + +import json +import pytest +from unittest.mock import Mock, patch, AsyncMock +from typing import Dict, Any +from src.tools.tavily_search.tavily_search_results_with_images import ( + TavilySearchResultsWithImages, +) +from src.tools.tavily_search.tavily_search_api_wrapper import ( + EnhancedTavilySearchAPIWrapper, +) + + +class TestTavilySearchResultsWithImages: + + @pytest.fixture + def mock_api_wrapper(self): + """Create a mock API wrapper.""" + wrapper = Mock(spec=EnhancedTavilySearchAPIWrapper) + return wrapper + + @pytest.fixture + def search_tool(self, mock_api_wrapper): + """Create a TavilySearchResultsWithImages instance with mocked dependencies.""" + tool = TavilySearchResultsWithImages( + max_results=5, + include_answer=True, + include_raw_content=True, + include_images=True, + include_image_descriptions=True, + ) + tool.api_wrapper = mock_api_wrapper + return tool + + @pytest.fixture + def sample_raw_results(self): + """Sample raw results from Tavily API.""" + return { + "query": "test query", + "answer": "Test answer", + "images": ["https://example.com/image1.jpg"], + "results": [ + { + "title": "Test Title", + "url": "https://example.com", + "content": "Test content", + "score": 0.95, + "raw_content": "Raw test content", + } + ], + "response_time": 1.5, + } + + @pytest.fixture + def sample_cleaned_results(self): + """Sample cleaned results.""" + return [ + { + "title": "Test Title", + "url": "https://example.com", + "content": "Test content", + } + ] + + def test_init_default_values(self): + """Test initialization with default values.""" + tool = TavilySearchResultsWithImages() + assert tool.include_image_descriptions is False + assert isinstance(tool.api_wrapper, EnhancedTavilySearchAPIWrapper) + + def test_init_custom_values(self): + """Test initialization with custom values.""" + tool = TavilySearchResultsWithImages( + max_results=10, include_image_descriptions=True + ) + assert 
tool.max_results == 10 + assert tool.include_image_descriptions is True + + @patch("builtins.print") + def test_run_success( + self, + mock_print, + search_tool, + mock_api_wrapper, + sample_raw_results, + sample_cleaned_results, + ): + """Test successful synchronous run.""" + mock_api_wrapper.raw_results.return_value = sample_raw_results + mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results + + result, raw = search_tool._run("test query") + + assert result == sample_cleaned_results + assert raw == sample_raw_results + + mock_api_wrapper.raw_results.assert_called_once_with( + "test query", + search_tool.max_results, + search_tool.search_depth, + search_tool.include_domains, + search_tool.exclude_domains, + search_tool.include_answer, + search_tool.include_raw_content, + search_tool.include_images, + search_tool.include_image_descriptions, + ) + + mock_api_wrapper.clean_results_with_images.assert_called_once_with( + sample_raw_results + ) + mock_print.assert_called_once() + + @patch("builtins.print") + def test_run_exception(self, mock_print, search_tool, mock_api_wrapper): + """Test synchronous run with exception.""" + mock_api_wrapper.raw_results.side_effect = Exception("API Error") + + result, raw = search_tool._run("test query") + + assert "API Error" in result + assert raw == {} + mock_api_wrapper.clean_results_with_images.assert_not_called() + + @pytest.mark.asyncio + @patch("builtins.print") + async def test_arun_success( + self, + mock_print, + search_tool, + mock_api_wrapper, + sample_raw_results, + sample_cleaned_results, + ): + """Test successful asynchronous run.""" + mock_api_wrapper.raw_results_async = AsyncMock(return_value=sample_raw_results) + mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results + + result, raw = await search_tool._arun("test query") + + assert result == sample_cleaned_results + assert raw == sample_raw_results + + mock_api_wrapper.raw_results_async.assert_called_once_with( + 
"test query", + search_tool.max_results, + search_tool.search_depth, + search_tool.include_domains, + search_tool.exclude_domains, + search_tool.include_answer, + search_tool.include_raw_content, + search_tool.include_images, + search_tool.include_image_descriptions, + ) + + mock_api_wrapper.clean_results_with_images.assert_called_once_with( + sample_raw_results + ) + mock_print.assert_called_once() + + @pytest.mark.asyncio + @patch("builtins.print") + async def test_arun_exception(self, mock_print, search_tool, mock_api_wrapper): + """Test asynchronous run with exception.""" + mock_api_wrapper.raw_results_async = AsyncMock( + side_effect=Exception("Async API Error") + ) + + result, raw = await search_tool._arun("test query") + + assert "Async API Error" in result + assert raw == {} + mock_api_wrapper.clean_results_with_images.assert_not_called() + + @patch("builtins.print") + def test_run_with_run_manager( + self, + mock_print, + search_tool, + mock_api_wrapper, + sample_raw_results, + sample_cleaned_results, + ): + """Test run with callback manager.""" + mock_run_manager = Mock() + mock_api_wrapper.raw_results.return_value = sample_raw_results + mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results + + result, raw = search_tool._run("test query", run_manager=mock_run_manager) + + assert result == sample_cleaned_results + assert raw == sample_raw_results + + @pytest.mark.asyncio + @patch("builtins.print") + async def test_arun_with_run_manager( + self, + mock_print, + search_tool, + mock_api_wrapper, + sample_raw_results, + sample_cleaned_results, + ): + """Test async run with callback manager.""" + mock_run_manager = Mock() + mock_api_wrapper.raw_results_async = AsyncMock(return_value=sample_raw_results) + mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results + + result, raw = await search_tool._arun( + "test query", run_manager=mock_run_manager + ) + + assert result == sample_cleaned_results + assert raw == 
sample_raw_results + + @patch("builtins.print") + def test_print_output_format( + self, + mock_print, + search_tool, + mock_api_wrapper, + sample_raw_results, + sample_cleaned_results, + ): + """Test that print outputs correctly formatted JSON.""" + mock_api_wrapper.raw_results.return_value = sample_raw_results + mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results + + search_tool._run("test query") + + # Verify print was called with expected format + call_args = mock_print.call_args[0] + assert call_args[0] == "sync" + assert isinstance(call_args[1], str) # Should be JSON string + + # Verify it's valid JSON + json_data = json.loads(call_args[1]) + assert json_data == sample_cleaned_results + + @pytest.mark.asyncio + @patch("builtins.print") + async def test_async_print_output_format( + self, + mock_print, + search_tool, + mock_api_wrapper, + sample_raw_results, + sample_cleaned_results, + ): + """Test that async print outputs correctly formatted JSON.""" + mock_api_wrapper.raw_results_async = AsyncMock(return_value=sample_raw_results) + mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results + + await search_tool._arun("test query") + + # Verify print was called with expected format + call_args = mock_print.call_args[0] + assert call_args[0] == "async" + assert isinstance(call_args[1], str) # Should be JSON string + + # Verify it's valid JSON + json_data = json.loads(call_args[1]) + assert json_data == sample_cleaned_results diff --git a/tests/unit/tools/test_tools_retriever.py b/tests/unit/tools/test_tools_retriever.py new file mode 100644 index 0000000..fa73b68 --- /dev/null +++ b/tests/unit/tools/test_tools_retriever.py @@ -0,0 +1,122 @@ +# Copyright (c) 2025 Bytedance Ltd. 
# and/or its affiliates
# SPDX-License-Identifier: MIT

from unittest.mock import Mock, patch, MagicMock
from langchain_core.callbacks import (
    CallbackManagerForToolRun,
    AsyncCallbackManagerForToolRun,
)
import pytest
from src.tools.retriever import RetrieverInput, RetrieverTool, get_retriever_tool
from src.rag import Document, Retriever, Resource, Chunk


def test_retriever_input_model():
    """RetrieverInput exposes the keywords string it was constructed with."""
    model = RetrieverInput(keywords="test keywords")
    stored_keywords = model.keywords
    assert stored_keywords == "test keywords"


def test_retriever_tool_init():
    """A new RetrieverTool reports its name, description, schema and collaborators."""
    resource_list = [Resource(uri="test://uri", title="Test")]
    retriever_stub = Mock(spec=Retriever)

    retriever_tool = RetrieverTool(retriever=retriever_stub, resources=resource_list)

    # Static tool metadata.
    assert retriever_tool.name == "local_search_tool"
    assert "retrieving information" in retriever_tool.description
    assert retriever_tool.args_schema == RetrieverInput
    # Constructor arguments are exposed unchanged.
    assert retriever_tool.retriever == retriever_stub
    assert retriever_tool.resources == resource_list


def test_retriever_tool_run_with_results():
    """_run queries the retriever once and yields the document's to_dict() form."""
    expected_doc = Document(
        id="doc1",
        chunks=[Chunk(content="test content", similarity=0.9)],
    )
    resource_list = [Resource(uri="test://uri", title="Test")]
    retriever_stub = Mock(spec=Retriever)
    retriever_stub.query_relevant_documents.return_value = [expected_doc]
    retriever_tool = RetrieverTool(retriever=retriever_stub, resources=resource_list)

    output = retriever_tool._run("test keywords")

    # The keywords and the tool's own resources are forwarded to the retriever.
    retriever_stub.query_relevant_documents.assert_called_once_with(
        "test keywords", resource_list
    )
    assert isinstance(output, list)
    assert len(output) == 1
    assert output[0] == expected_doc.to_dict()


def test_retriever_tool_run_no_results():
    """_run returns the fixed 'no results' message when nothing is retrieved."""
    resource_list = [Resource(uri="test://uri", title="Test")]
    retriever_stub = Mock(spec=Retriever)
    retriever_stub.query_relevant_documents.return_value = []
    retriever_tool = RetrieverTool(retriever=retriever_stub, resources=resource_list)

    output = retriever_tool._run("test keywords")

    assert output == "No results found from the local knowledge base."
@pytest.mark.asyncio
async def test_retriever_tool_arun():
    """_arun calls get_sync() on its run manager once and yields to_dict() results."""
    expected_doc = Document(
        id="doc2",
        chunks=[Chunk(content="async content", similarity=0.8)],
    )
    retriever_stub = Mock(spec=Retriever)
    retriever_stub.query_relevant_documents.return_value = [expected_doc]
    resource_list = [Resource(uri="test://uri", title="Test")]
    retriever_tool = RetrieverTool(retriever=retriever_stub, resources=resource_list)

    # The async manager hands out a sync manager through get_sync().
    sync_manager = Mock(spec=CallbackManagerForToolRun)
    async_manager = Mock(spec=AsyncCallbackManagerForToolRun)
    async_manager.get_sync.return_value = sync_manager

    output = await retriever_tool._arun("async keywords", async_manager)

    async_manager.get_sync.assert_called_once()
    assert isinstance(output, list)
    assert len(output) == 1
    assert output[0] == expected_doc.to_dict()


@patch("src.tools.retriever.build_retriever")
def test_get_retriever_tool_success(mock_build_retriever):
    """get_retriever_tool wraps the built retriever and the given resources."""
    retriever_stub = Mock(spec=Retriever)
    mock_build_retriever.return_value = retriever_stub
    resource_list = [Resource(uri="test://uri", title="Test")]

    built_tool = get_retriever_tool(resource_list)

    assert isinstance(built_tool, RetrieverTool)
    assert built_tool.retriever == retriever_stub
    assert built_tool.resources == resource_list


def test_get_retriever_tool_empty_resources():
    """get_retriever_tool yields None for an empty resource list."""
    outcome = get_retriever_tool([])
    assert outcome is None


@patch("src.tools.retriever.build_retriever")
def test_get_retriever_tool_no_retriever(mock_build_retriever):
    """get_retriever_tool yields None when build_retriever returns None."""
    mock_build_retriever.return_value = None
    resource_list = [Resource(uri="test://uri", title="Test")]

    outcome = get_retriever_tool(resource_list)

    assert outcome is None


def test_retriever_tool_run_with_callback_manager():
    """_run accepts a callback manager as its second positional argument."""
    resource_list = [Resource(uri="test://uri", title="Test")]
    retriever_stub = Mock(spec=Retriever)
    retriever_stub.query_relevant_documents.return_value = []
    retriever_tool = RetrieverTool(retriever=retriever_stub, resources=resource_list)
    callback_manager = Mock(spec=CallbackManagerForToolRun)

    output = retriever_tool._run("test keywords", callback_manager)

    assert output == "No results found from the local knowledge base."