test: add more unit tests of tools (#315)

* test: add more test on test_tts.py

* test: add unit test of search and retriever in tools

* test: remove the main code of search.py

* test: add the travily_search unit test

* reformate the codes

* test: add unit tests of tools

* Added the pytest-asyncio dependency

* added the license header of test_tavily_search_api_wrapper.py
This commit is contained in:
Willem Jiang
2025-06-12 20:43:32 +08:00
committed by GitHub
parent bb7dc6e98c
commit 4c2fe2e7f5
14 changed files with 1057 additions and 35 deletions

View File

@@ -42,6 +42,7 @@ dev = [
test = [
"pytest>=7.4.0",
"pytest-cov>=4.1.0",
"pytest-asyncio>=1.0.0",
]
[tool.pytest.ini_options]

View File

@@ -1,8 +1,8 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from .retriever import Retriever, Document, Resource
from .retriever import Retriever, Document, Resource, Chunk
from .ragflow import RAGFlowProvider
from .builder import build_retriever
__all__ = [Retriever, Document, Resource, RAGFlowProvider, build_retriever]
__all__ = [Retriever, Document, Resource, RAGFlowProvider, Chunk, build_retriever]

View File

@@ -60,18 +60,3 @@ def get_retriever_tool(resources: List[Resource]) -> RetrieverTool | None:
if not retriever:
return None
return RetrieverTool(retriever=retriever, resources=resources)
if __name__ == "__main__":
resources = [
Resource(
uri="rag://dataset/1c7e2ea4362911f09a41c290d4b6a7f0",
title="西游记",
description="西游记是中国古代四大名著之一,讲述了唐僧师徒四人西天取经的故事。",
)
]
retriever_tool = get_retriever_tool(resources)
print(retriever_tool.name)
print(retriever_tool.description)
print(retriever_tool.args)
print(retriever_tool.invoke("三打白骨精"))

View File

@@ -36,7 +36,10 @@ def get_web_search_tool(max_search_results: int):
include_image_descriptions=True,
)
elif SELECTED_SEARCH_ENGINE == SearchEngine.DUCKDUCKGO.value:
return LoggedDuckDuckGoSearch(name="web_search", max_results=max_search_results)
return LoggedDuckDuckGoSearch(
name="web_search",
num_results=max_search_results,
)
elif SELECTED_SEARCH_ENGINE == SearchEngine.BRAVE_SEARCH.value:
return LoggedBraveSearch(
name="web_search",
@@ -56,14 +59,3 @@ def get_web_search_tool(max_search_results: int):
)
else:
raise ValueError(f"Unsupported search engine: {SELECTED_SEARCH_ENGINE}")
if __name__ == "__main__":
results = LoggedDuckDuckGoSearch(
name="web_search", max_results=3, output_format="list"
)
print(results.name)
print(results.description)
print(results.args)
# .invoke("cute panda")
# print(json.dumps(results, indent=2, ensure_ascii=False))

View File

@@ -1,3 +1,7 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import json
from typing import Dict, List, Optional
@@ -107,9 +111,3 @@ class EnhancedTavilySearchAPIWrapper(OriginalTavilySearchAPIWrapper):
}
clean_results.append(clean_result)
return clean_results
if __name__ == "__main__":
wrapper = EnhancedTavilySearchAPIWrapper()
results = wrapper.raw_results("cute panda", include_images=True)
print(json.dumps(results, indent=2, ensure_ascii=False))

View File

@@ -1,3 +1,6 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import json
from typing import Dict, List, Optional, Tuple, Union

View File

@@ -229,3 +229,20 @@ class TestVolcengineTTS:
args, kwargs = mock_post.call_args
request_json = json.loads(args[1])
assert request_json["user"]["uid"] == str(mock_uuid_value)
@patch("src.tools.tts.requests.post")
def test_text_to_speech_request_exception(self, mock_post):
"""Test error handling when requests.post raises an exception."""
# Mock requests.post to raise an exception
mock_post.side_effect = Exception("Network error")
# Create TTS client
tts = VolcengineTTS(
appid="test_appid",
access_token="test_token",
)
# Call the method
result = tts.text_to_speech("Hello, world!")
# Verify the result
assert result["success"] is False
assert result["error"] == "Network error"
assert result["audio_data"] is None

View File

@@ -0,0 +1,110 @@
import pytest
from unittest.mock import Mock, patch
from src.tools.crawl import crawl_tool
class TestCrawlTool:
@patch("src.tools.crawl.Crawler")
def test_crawl_tool_success(self, mock_crawler_class):
# Arrange
mock_crawler = Mock()
mock_article = Mock()
mock_article.to_markdown.return_value = (
"# Test Article\nThis is test content." * 100
)
mock_crawler.crawl.return_value = mock_article
mock_crawler_class.return_value = mock_crawler
url = "https://example.com"
# Act
result = crawl_tool(url)
# Assert
assert isinstance(result, dict)
assert result["url"] == url
assert "crawled_content" in result
assert len(result["crawled_content"]) <= 1000
mock_crawler_class.assert_called_once()
mock_crawler.crawl.assert_called_once_with(url)
mock_article.to_markdown.assert_called_once()
@patch("src.tools.crawl.Crawler")
def test_crawl_tool_short_content(self, mock_crawler_class):
# Arrange
mock_crawler = Mock()
mock_article = Mock()
short_content = "Short content"
mock_article.to_markdown.return_value = short_content
mock_crawler.crawl.return_value = mock_article
mock_crawler_class.return_value = mock_crawler
url = "https://example.com"
# Act
result = crawl_tool(url)
# Assert
assert result["crawled_content"] == short_content
@patch("src.tools.crawl.Crawler")
@patch("src.tools.crawl.logger")
def test_crawl_tool_crawler_exception(self, mock_logger, mock_crawler_class):
# Arrange
mock_crawler = Mock()
mock_crawler.crawl.side_effect = Exception("Network error")
mock_crawler_class.return_value = mock_crawler
url = "https://example.com"
# Act
result = crawl_tool(url)
# Assert
assert isinstance(result, str)
assert "Failed to crawl" in result
assert "Network error" in result
mock_logger.error.assert_called_once()
@patch("src.tools.crawl.Crawler")
@patch("src.tools.crawl.logger")
def test_crawl_tool_crawler_instantiation_exception(
self, mock_logger, mock_crawler_class
):
# Arrange
mock_crawler_class.side_effect = Exception("Crawler init error")
url = "https://example.com"
# Act
result = crawl_tool(url)
# Assert
assert isinstance(result, str)
assert "Failed to crawl" in result
assert "Crawler init error" in result
mock_logger.error.assert_called_once()
@patch("src.tools.crawl.Crawler")
@patch("src.tools.crawl.logger")
def test_crawl_tool_markdown_conversion_exception(
self, mock_logger, mock_crawler_class
):
# Arrange
mock_crawler = Mock()
mock_article = Mock()
mock_article.to_markdown.side_effect = Exception("Markdown conversion error")
mock_crawler.crawl.return_value = mock_article
mock_crawler_class.return_value = mock_crawler
url = "https://example.com"
# Act
result = crawl_tool(url)
# Assert
assert isinstance(result, str)
assert "Failed to crawl" in result
assert "Markdown conversion error" in result
mock_logger.error.assert_called_once()

View File

@@ -0,0 +1,121 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import pytest
import logging
from unittest.mock import Mock, call, patch, MagicMock
from src.tools.decorators import LoggedToolMixin, create_logged_tool
class MockBaseTool:
"""Mock base tool class for testing."""
def _run(self, *args, **kwargs):
return "base_result"
class TestLoggedToolMixin:
def test_run_calls_log_operation(self):
"""Test that _run calls _log_operation with correct parameters."""
# Create a logged tool instance
LoggedTool = create_logged_tool(MockBaseTool)
tool = LoggedTool()
# Mock the _log_operation method
tool._log_operation = Mock()
# Call _run with test parameters
args = ("arg1", "arg2")
kwargs = {"key1": "value1", "key2": "value2"}
tool._run(*args, **kwargs)
# Verify _log_operation was called with correct parameters
tool._log_operation.assert_called_once_with("_run", *args, **kwargs)
def test_run_calls_super_run(self):
"""Test that _run calls the parent class _run method."""
# Create a logged tool instance
LoggedTool = create_logged_tool(MockBaseTool)
tool = LoggedTool()
# Mock the parent _run method
with patch.object(
MockBaseTool, "_run", return_value="mocked_result"
) as mock_super_run:
args = ("arg1", "arg2")
kwargs = {"key1": "value1"}
result = tool._run(*args, **kwargs)
# Verify super()._run was called with correct parameters
mock_super_run.assert_called_once_with(*args, **kwargs)
# Verify the result is returned
assert result == "mocked_result"
def test_run_logs_result(self):
"""Test that _run logs the result with debug level."""
LoggedTool = create_logged_tool(MockBaseTool)
tool = LoggedTool()
with patch("src.tools.decorators.logger.debug") as mock_debug:
result = tool._run("test_arg")
# Verify debug log was called with correct message
mock_debug.assert_has_calls(
[
call("Tool MockBaseTool._run called with parameters: test_arg"),
call("Tool MockBaseTool returned: base_result"),
]
)
def test_run_returns_super_result(self):
"""Test that _run returns the result from parent class."""
LoggedTool = create_logged_tool(MockBaseTool)
tool = LoggedTool()
result = tool._run()
assert result == "base_result"
def test_run_with_no_args(self):
"""Test _run method with no arguments."""
LoggedTool = create_logged_tool(MockBaseTool)
tool = LoggedTool()
with patch("src.tools.decorators.logger.debug") as mock_debug:
tool._log_operation = Mock()
result = tool._run()
# Verify _log_operation called with no args
tool._log_operation.assert_called_once_with("_run")
# Verify result logging
mock_debug.assert_called_once()
assert result == "base_result"
def test_run_with_mixed_args_kwargs(self):
"""Test _run method with both positional and keyword arguments."""
LoggedTool = create_logged_tool(MockBaseTool)
tool = LoggedTool()
tool._log_operation = Mock()
args = ("pos1", "pos2")
kwargs = {"kw1": "val1", "kw2": "val2"}
result = tool._run(*args, **kwargs)
# Verify all arguments passed correctly
tool._log_operation.assert_called_once_with("_run", *args, **kwargs)
assert result == "base_result"
def test_run_class_name_replacement(self):
"""Test that class name 'Logged' prefix is correctly removed in logging."""
LoggedTool = create_logged_tool(MockBaseTool)
tool = LoggedTool()
with patch("src.tools.decorators.logger.debug") as mock_debug:
tool._run()
# Verify the logged class name has 'Logged' prefix removed
call_args = mock_debug.call_args[0][0]
assert "Tool MockBaseTool returned:" in call_args
assert "LoggedMockBaseTool" not in call_args

View File

@@ -0,0 +1,147 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import pytest
from unittest.mock import Mock, patch, MagicMock
from src.tools.python_repl import python_repl_tool
class TestPythonReplTool:
@patch("src.tools.python_repl.repl")
@patch("src.tools.python_repl.logger")
def test_successful_code_execution(self, mock_logger, mock_repl):
# Arrange
code = "print('Hello, World!')"
expected_output = "Hello, World!\n"
mock_repl.run.return_value = expected_output
# Act
result = python_repl_tool(code)
# Assert
mock_repl.run.assert_called_once_with(code)
mock_logger.info.assert_called_with("Code execution successful")
assert "Successfully executed:" in result
assert code in result
assert expected_output in result
@patch("src.tools.python_repl.repl")
@patch("src.tools.python_repl.logger")
def test_invalid_input_type(self, mock_logger, mock_repl):
# Arrange
invalid_code = 123
# Act & Assert - expect ValidationError from LangChain
with pytest.raises(Exception) as exc_info:
python_repl_tool(invalid_code)
# Verify that it's a validation error
assert "ValidationError" in str(
type(exc_info.value)
) or "validation error" in str(exc_info.value)
# The REPL should not be called since validation fails first
mock_repl.run.assert_not_called()
@patch("src.tools.python_repl.repl")
@patch("src.tools.python_repl.logger")
def test_code_execution_with_error_in_result(self, mock_logger, mock_repl):
# Arrange
code = "invalid_function()"
error_result = "NameError: name 'invalid_function' is not defined"
mock_repl.run.return_value = error_result
# Act
result = python_repl_tool(code)
# Assert
mock_repl.run.assert_called_once_with(code)
mock_logger.error.assert_called_with(error_result)
assert "Error executing code:" in result
assert code in result
assert error_result in result
@patch("src.tools.python_repl.repl")
@patch("src.tools.python_repl.logger")
def test_code_execution_with_exception_in_result(self, mock_logger, mock_repl):
# Arrange
code = "1/0"
exception_result = "ZeroDivisionError: division by zero"
mock_repl.run.return_value = exception_result
# Act
result = python_repl_tool(code)
# Assert
mock_repl.run.assert_called_once_with(code)
mock_logger.error.assert_called_with(exception_result)
assert "Error executing code:" in result
assert code in result
assert exception_result in result
@patch("src.tools.python_repl.repl")
@patch("src.tools.python_repl.logger")
def test_code_execution_raises_exception(self, mock_logger, mock_repl):
# Arrange
code = "print('test')"
exception = RuntimeError("REPL failed")
mock_repl.run.side_effect = exception
# Act
result = python_repl_tool(code)
# Assert
mock_repl.run.assert_called_once_with(code)
mock_logger.error.assert_called_with(repr(exception))
assert "Error executing code:" in result
assert code in result
assert repr(exception) in result
@patch("src.tools.python_repl.repl")
@patch("src.tools.python_repl.logger")
def test_successful_execution_with_calculation(self, mock_logger, mock_repl):
# Arrange
code = "result = 2 + 3\nprint(result)"
expected_output = "5\n"
mock_repl.run.return_value = expected_output
# Act
result = python_repl_tool(code)
# Assert
mock_repl.run.assert_called_once_with(code)
mock_logger.info.assert_any_call("Executing Python code")
mock_logger.info.assert_any_call("Code execution successful")
assert "Successfully executed:" in result
assert code in result
assert expected_output in result
@patch("src.tools.python_repl.repl")
@patch("src.tools.python_repl.logger")
def test_empty_string_code(self, mock_logger, mock_repl):
# Arrange
code = ""
mock_repl.run.return_value = ""
# Act
result = python_repl_tool(code)
# Assert
mock_repl.run.assert_called_once_with(code)
mock_logger.info.assert_called_with("Code execution successful")
assert "Successfully executed:" in result
@patch("src.tools.python_repl.repl")
@patch("src.tools.python_repl.logger")
def test_logging_calls(self, mock_logger, mock_repl):
# Arrange
code = "x = 1"
mock_repl.run.return_value = ""
# Act
python_repl_tool(code)
# Assert
mock_logger.info.assert_any_call("Executing Python code")
mock_logger.info.assert_any_call("Code execution successful")

View File

@@ -0,0 +1,54 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import os
import pytest
from unittest.mock import patch, MagicMock
from src.tools.search import get_web_search_tool
from src.config import SearchEngine
class TestGetWebSearchTool:
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
def test_get_web_search_tool_tavily(self):
tool = get_web_search_tool(max_search_results=5)
assert tool.name == "web_search"
assert tool.max_results == 5
assert tool.include_raw_content is True
assert tool.include_images is True
assert tool.include_image_descriptions is True
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.DUCKDUCKGO.value)
def test_get_web_search_tool_duckduckgo(self):
tool = get_web_search_tool(max_search_results=3)
assert tool.name == "web_search"
assert tool.max_results == 3
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.BRAVE_SEARCH.value)
@patch.dict(os.environ, {"BRAVE_SEARCH_API_KEY": "test_api_key"})
def test_get_web_search_tool_brave(self):
tool = get_web_search_tool(max_search_results=4)
assert tool.name == "web_search"
assert tool.search_wrapper.api_key == "test_api_key"
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.ARXIV.value)
def test_get_web_search_tool_arxiv(self):
tool = get_web_search_tool(max_search_results=2)
assert tool.name == "web_search"
assert tool.api_wrapper.top_k_results == 2
assert tool.api_wrapper.load_max_docs == 2
assert tool.api_wrapper.load_all_available_meta is True
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", "unsupported_engine")
def test_get_web_search_tool_unsupported_engine(self):
with pytest.raises(
ValueError, match="Unsupported search engine: unsupported_engine"
):
get_web_search_tool(max_search_results=1)
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.BRAVE_SEARCH.value)
@patch.dict(os.environ, {}, clear=True)
def test_get_web_search_tool_brave_no_api_key(self):
tool = get_web_search_tool(max_search_results=1)
assert tool.search_wrapper.api_key == ""

View File

@@ -0,0 +1,207 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import json
import pytest
from unittest.mock import Mock, patch, AsyncMock, MagicMock
import aiohttp
import requests
from src.tools.tavily_search.tavily_search_api_wrapper import (
EnhancedTavilySearchAPIWrapper,
)
class TestEnhancedTavilySearchAPIWrapper:
@pytest.fixture
def wrapper(self):
with patch(
"src.tools.tavily_search.tavily_search_api_wrapper.OriginalTavilySearchAPIWrapper"
):
wrapper = EnhancedTavilySearchAPIWrapper(tavily_api_key="dummy-key")
# The parent class is mocked, so initialization won't fail
return wrapper
@pytest.fixture
def mock_response_data(self):
return {
"results": [
{
"title": "Test Title",
"url": "https://example.com",
"content": "Test content",
"score": 0.9,
"raw_content": "Raw test content",
}
],
"images": [
{
"url": "https://example.com/image.jpg",
"description": "Test image description",
}
],
}
@patch("src.tools.tavily_search.tavily_search_api_wrapper.requests.post")
def test_raw_results_success(self, mock_post, wrapper, mock_response_data):
mock_response = Mock()
mock_response.json.return_value = mock_response_data
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = wrapper.raw_results("test query", max_results=10)
assert result == mock_response_data
mock_post.assert_called_once()
call_args = mock_post.call_args
assert "json" in call_args.kwargs
assert call_args.kwargs["json"]["query"] == "test query"
assert call_args.kwargs["json"]["max_results"] == 10
@patch("src.tools.tavily_search.tavily_search_api_wrapper.requests.post")
def test_raw_results_with_all_parameters(
self, mock_post, wrapper, mock_response_data
):
mock_response = Mock()
mock_response.json.return_value = mock_response_data
mock_response.raise_for_status.return_value = None
mock_post.return_value = mock_response
result = wrapper.raw_results(
"test query",
max_results=3,
search_depth="basic",
include_domains=["example.com"],
exclude_domains=["spam.com"],
include_answer=True,
include_raw_content=True,
include_images=True,
include_image_descriptions=True,
)
assert result == mock_response_data
call_args = mock_post.call_args
params = call_args.kwargs["json"]
assert params["include_domains"] == ["example.com"]
assert params["exclude_domains"] == ["spam.com"]
assert params["include_answer"] is True
assert params["include_raw_content"] is True
@patch("src.tools.tavily_search.tavily_search_api_wrapper.requests.post")
def test_raw_results_http_error(self, mock_post, wrapper):
mock_response = Mock()
mock_response.raise_for_status.side_effect = requests.HTTPError("API Error")
mock_post.return_value = mock_response
with pytest.raises(requests.HTTPError):
wrapper.raw_results("test query")
@pytest.mark.asyncio
async def test_raw_results_async_success(self, wrapper, mock_response_data):
# Create a mock that acts as both the response and its context manager
mock_response_cm = AsyncMock()
mock_response_cm.__aenter__ = AsyncMock(return_value=mock_response_cm)
mock_response_cm.__aexit__ = AsyncMock(return_value=None)
mock_response_cm.status = 200
mock_response_cm.text = AsyncMock(return_value=json.dumps(mock_response_data))
# Create mock session that returns the context manager
mock_session = AsyncMock()
mock_session.post = MagicMock(
return_value=mock_response_cm
) # Use MagicMock, not AsyncMock
# Create mock session class
mock_session_cm = AsyncMock()
mock_session_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_cm.__aexit__ = AsyncMock(return_value=None)
with patch(
"src.tools.tavily_search.tavily_search_api_wrapper.aiohttp.ClientSession",
return_value=mock_session_cm,
):
result = await wrapper.raw_results_async("test query")
assert result == mock_response_data
@pytest.mark.asyncio
async def test_raw_results_async_error(self, wrapper):
# Create a mock that acts as both the response and its context manager
mock_response_cm = AsyncMock()
mock_response_cm.__aenter__ = AsyncMock(return_value=mock_response_cm)
mock_response_cm.__aexit__ = AsyncMock(return_value=None)
mock_response_cm.status = 400
mock_response_cm.reason = "Bad Request"
# Create mock session that returns the context manager
mock_session = AsyncMock()
mock_session.post = MagicMock(
return_value=mock_response_cm
) # Use MagicMock, not AsyncMock
# Create mock session class
mock_session_cm = AsyncMock()
mock_session_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_cm.__aexit__ = AsyncMock(return_value=None)
with patch(
"src.tools.tavily_search.tavily_search_api_wrapper.aiohttp.ClientSession",
return_value=mock_session_cm,
):
with pytest.raises(Exception, match="Error 400: Bad Request"):
await wrapper.raw_results_async("test query")
def test_clean_results_with_images(self, wrapper, mock_response_data):
result = wrapper.clean_results_with_images(mock_response_data)
assert len(result) == 2
# Test page result
page_result = result[0]
assert page_result["type"] == "page"
assert page_result["title"] == "Test Title"
assert page_result["url"] == "https://example.com"
assert page_result["content"] == "Test content"
assert page_result["score"] == 0.9
assert page_result["raw_content"] == "Raw test content"
# Test image result
image_result = result[1]
assert image_result["type"] == "image"
assert image_result["image_url"] == "https://example.com/image.jpg"
assert image_result["image_description"] == "Test image description"
def test_clean_results_without_raw_content(self, wrapper):
data = {
"results": [
{
"title": "Test Title",
"url": "https://example.com",
"content": "Test content",
"score": 0.9,
}
],
"images": [],
}
result = wrapper.clean_results_with_images(data)
assert len(result) == 1
assert "raw_content" not in result[0]
def test_clean_results_empty_images(self, wrapper):
data = {
"results": [
{
"title": "Test Title",
"url": "https://example.com",
"content": "Test content",
"score": 0.9,
}
],
"images": [],
}
result = wrapper.clean_results_with_images(data)
assert len(result) == 1
assert result[0]["type"] == "page"

View File

@@ -0,0 +1,265 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import json
import pytest
from unittest.mock import Mock, patch, AsyncMock
from typing import Dict, Any
from src.tools.tavily_search.tavily_search_results_with_images import (
TavilySearchResultsWithImages,
)
from src.tools.tavily_search.tavily_search_api_wrapper import (
EnhancedTavilySearchAPIWrapper,
)
class TestTavilySearchResultsWithImages:
@pytest.fixture
def mock_api_wrapper(self):
"""Create a mock API wrapper."""
wrapper = Mock(spec=EnhancedTavilySearchAPIWrapper)
return wrapper
@pytest.fixture
def search_tool(self, mock_api_wrapper):
"""Create a TavilySearchResultsWithImages instance with mocked dependencies."""
tool = TavilySearchResultsWithImages(
max_results=5,
include_answer=True,
include_raw_content=True,
include_images=True,
include_image_descriptions=True,
)
tool.api_wrapper = mock_api_wrapper
return tool
@pytest.fixture
def sample_raw_results(self):
"""Sample raw results from Tavily API."""
return {
"query": "test query",
"answer": "Test answer",
"images": ["https://example.com/image1.jpg"],
"results": [
{
"title": "Test Title",
"url": "https://example.com",
"content": "Test content",
"score": 0.95,
"raw_content": "Raw test content",
}
],
"response_time": 1.5,
}
@pytest.fixture
def sample_cleaned_results(self):
"""Sample cleaned results."""
return [
{
"title": "Test Title",
"url": "https://example.com",
"content": "Test content",
}
]
def test_init_default_values(self):
"""Test initialization with default values."""
tool = TavilySearchResultsWithImages()
assert tool.include_image_descriptions is False
assert isinstance(tool.api_wrapper, EnhancedTavilySearchAPIWrapper)
def test_init_custom_values(self):
"""Test initialization with custom values."""
tool = TavilySearchResultsWithImages(
max_results=10, include_image_descriptions=True
)
assert tool.max_results == 10
assert tool.include_image_descriptions is True
@patch("builtins.print")
def test_run_success(
self,
mock_print,
search_tool,
mock_api_wrapper,
sample_raw_results,
sample_cleaned_results,
):
"""Test successful synchronous run."""
mock_api_wrapper.raw_results.return_value = sample_raw_results
mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results
result, raw = search_tool._run("test query")
assert result == sample_cleaned_results
assert raw == sample_raw_results
mock_api_wrapper.raw_results.assert_called_once_with(
"test query",
search_tool.max_results,
search_tool.search_depth,
search_tool.include_domains,
search_tool.exclude_domains,
search_tool.include_answer,
search_tool.include_raw_content,
search_tool.include_images,
search_tool.include_image_descriptions,
)
mock_api_wrapper.clean_results_with_images.assert_called_once_with(
sample_raw_results
)
mock_print.assert_called_once()
@patch("builtins.print")
def test_run_exception(self, mock_print, search_tool, mock_api_wrapper):
"""Test synchronous run with exception."""
mock_api_wrapper.raw_results.side_effect = Exception("API Error")
result, raw = search_tool._run("test query")
assert "API Error" in result
assert raw == {}
mock_api_wrapper.clean_results_with_images.assert_not_called()
@pytest.mark.asyncio
@patch("builtins.print")
async def test_arun_success(
self,
mock_print,
search_tool,
mock_api_wrapper,
sample_raw_results,
sample_cleaned_results,
):
"""Test successful asynchronous run."""
mock_api_wrapper.raw_results_async = AsyncMock(return_value=sample_raw_results)
mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results
result, raw = await search_tool._arun("test query")
assert result == sample_cleaned_results
assert raw == sample_raw_results
mock_api_wrapper.raw_results_async.assert_called_once_with(
"test query",
search_tool.max_results,
search_tool.search_depth,
search_tool.include_domains,
search_tool.exclude_domains,
search_tool.include_answer,
search_tool.include_raw_content,
search_tool.include_images,
search_tool.include_image_descriptions,
)
mock_api_wrapper.clean_results_with_images.assert_called_once_with(
sample_raw_results
)
mock_print.assert_called_once()
@pytest.mark.asyncio
@patch("builtins.print")
async def test_arun_exception(self, mock_print, search_tool, mock_api_wrapper):
"""Test asynchronous run with exception."""
mock_api_wrapper.raw_results_async = AsyncMock(
side_effect=Exception("Async API Error")
)
result, raw = await search_tool._arun("test query")
assert "Async API Error" in result
assert raw == {}
mock_api_wrapper.clean_results_with_images.assert_not_called()
@patch("builtins.print")
def test_run_with_run_manager(
self,
mock_print,
search_tool,
mock_api_wrapper,
sample_raw_results,
sample_cleaned_results,
):
"""Test run with callback manager."""
mock_run_manager = Mock()
mock_api_wrapper.raw_results.return_value = sample_raw_results
mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results
result, raw = search_tool._run("test query", run_manager=mock_run_manager)
assert result == sample_cleaned_results
assert raw == sample_raw_results
@pytest.mark.asyncio
@patch("builtins.print")
async def test_arun_with_run_manager(
self,
mock_print,
search_tool,
mock_api_wrapper,
sample_raw_results,
sample_cleaned_results,
):
"""Test async run with callback manager."""
mock_run_manager = Mock()
mock_api_wrapper.raw_results_async = AsyncMock(return_value=sample_raw_results)
mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results
result, raw = await search_tool._arun(
"test query", run_manager=mock_run_manager
)
assert result == sample_cleaned_results
assert raw == sample_raw_results
@patch("builtins.print")
def test_print_output_format(
self,
mock_print,
search_tool,
mock_api_wrapper,
sample_raw_results,
sample_cleaned_results,
):
"""Test that print outputs correctly formatted JSON."""
mock_api_wrapper.raw_results.return_value = sample_raw_results
mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results
search_tool._run("test query")
# Verify print was called with expected format
call_args = mock_print.call_args[0]
assert call_args[0] == "sync"
assert isinstance(call_args[1], str) # Should be JSON string
# Verify it's valid JSON
json_data = json.loads(call_args[1])
assert json_data == sample_cleaned_results
@pytest.mark.asyncio
@patch("builtins.print")
async def test_async_print_output_format(
self,
mock_print,
search_tool,
mock_api_wrapper,
sample_raw_results,
sample_cleaned_results,
):
"""Test that async print outputs correctly formatted JSON."""
mock_api_wrapper.raw_results_async = AsyncMock(return_value=sample_raw_results)
mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results
await search_tool._arun("test query")
# Verify print was called with expected format
call_args = mock_print.call_args[0]
assert call_args[0] == "async"
assert isinstance(call_args[1], str) # Should be JSON string
# Verify it's valid JSON
json_data = json.loads(call_args[1])
assert json_data == sample_cleaned_results

View File

@@ -0,0 +1,122 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from unittest.mock import Mock, patch, MagicMock
from langchain_core.callbacks import (
CallbackManagerForToolRun,
AsyncCallbackManagerForToolRun,
)
import pytest
from src.tools.retriever import RetrieverInput, RetrieverTool, get_retriever_tool
from src.rag import Document, Retriever, Resource, Chunk
def test_retriever_input_model():
input_data = RetrieverInput(keywords="test keywords")
assert input_data.keywords == "test keywords"
def test_retriever_tool_init():
mock_retriever = Mock(spec=Retriever)
resources = [Resource(uri="test://uri", title="Test")]
tool = RetrieverTool(retriever=mock_retriever, resources=resources)
assert tool.name == "local_search_tool"
assert "retrieving information" in tool.description
assert tool.args_schema == RetrieverInput
assert tool.retriever == mock_retriever
assert tool.resources == resources
def test_retriever_tool_run_with_results():
mock_retriever = Mock(spec=Retriever)
chunk = Chunk(content="test content", similarity=0.9)
doc = Document(id="doc1", chunks=[chunk])
mock_retriever.query_relevant_documents.return_value = [doc]
resources = [Resource(uri="test://uri", title="Test")]
tool = RetrieverTool(retriever=mock_retriever, resources=resources)
result = tool._run("test keywords")
mock_retriever.query_relevant_documents.assert_called_once_with(
"test keywords", resources
)
assert isinstance(result, list)
assert len(result) == 1
assert result[0] == doc.to_dict()
def test_retriever_tool_run_no_results():
mock_retriever = Mock(spec=Retriever)
mock_retriever.query_relevant_documents.return_value = []
resources = [Resource(uri="test://uri", title="Test")]
tool = RetrieverTool(retriever=mock_retriever, resources=resources)
result = tool._run("test keywords")
assert result == "No results found from the local knowledge base."
@pytest.mark.asyncio
async def test_retriever_tool_arun():
mock_retriever = Mock(spec=Retriever)
chunk = Chunk(content="async content", similarity=0.8)
doc = Document(id="doc2", chunks=[chunk])
mock_retriever.query_relevant_documents.return_value = [doc]
resources = [Resource(uri="test://uri", title="Test")]
tool = RetrieverTool(retriever=mock_retriever, resources=resources)
mock_run_manager = Mock(spec=AsyncCallbackManagerForToolRun)
mock_sync_manager = Mock(spec=CallbackManagerForToolRun)
mock_run_manager.get_sync.return_value = mock_sync_manager
result = await tool._arun("async keywords", mock_run_manager)
mock_run_manager.get_sync.assert_called_once()
assert isinstance(result, list)
assert len(result) == 1
assert result[0] == doc.to_dict()
@patch("src.tools.retriever.build_retriever")
def test_get_retriever_tool_success(mock_build_retriever):
mock_retriever = Mock(spec=Retriever)
mock_build_retriever.return_value = mock_retriever
resources = [Resource(uri="test://uri", title="Test")]
tool = get_retriever_tool(resources)
assert isinstance(tool, RetrieverTool)
assert tool.retriever == mock_retriever
assert tool.resources == resources
def test_get_retriever_tool_empty_resources():
result = get_retriever_tool([])
assert result is None
@patch("src.tools.retriever.build_retriever")
def test_get_retriever_tool_no_retriever(mock_build_retriever):
mock_build_retriever.return_value = None
resources = [Resource(uri="test://uri", title="Test")]
result = get_retriever_tool(resources)
assert result is None
def test_retriever_tool_run_with_callback_manager():
mock_retriever = Mock(spec=Retriever)
mock_retriever.query_relevant_documents.return_value = []
resources = [Resource(uri="test://uri", title="Test")]
tool = RetrieverTool(retriever=mock_retriever, resources=resources)
mock_callback_manager = Mock(spec=CallbackManagerForToolRun)
result = tool._run("test keywords", mock_callback_manager)
assert result == "No results found from the local knowledge base."