mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
test: add more unit tests of tools (#315)
* test: add more test on test_tts.py * test: add unit test of search and retriever in tools * test: remove the main code of search.py * test: add the travily_search unit test * reformate the codes * test: add unit tests of tools * Added the pytest-asyncio dependency * added the license header of test_tavily_search_api_wrapper.py
This commit is contained in:
@@ -42,6 +42,7 @@ dev = [
|
||||
test = [
|
||||
"pytest>=7.4.0",
|
||||
"pytest-cov>=4.1.0",
|
||||
"pytest-asyncio>=1.0.0",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from .retriever import Retriever, Document, Resource
|
||||
from .retriever import Retriever, Document, Resource, Chunk
|
||||
from .ragflow import RAGFlowProvider
|
||||
from .builder import build_retriever
|
||||
|
||||
__all__ = [Retriever, Document, Resource, RAGFlowProvider, build_retriever]
|
||||
__all__ = [Retriever, Document, Resource, RAGFlowProvider, Chunk, build_retriever]
|
||||
|
||||
@@ -60,18 +60,3 @@ def get_retriever_tool(resources: List[Resource]) -> RetrieverTool | None:
|
||||
if not retriever:
|
||||
return None
|
||||
return RetrieverTool(retriever=retriever, resources=resources)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
resources = [
|
||||
Resource(
|
||||
uri="rag://dataset/1c7e2ea4362911f09a41c290d4b6a7f0",
|
||||
title="西游记",
|
||||
description="西游记是中国古代四大名著之一,讲述了唐僧师徒四人西天取经的故事。",
|
||||
)
|
||||
]
|
||||
retriever_tool = get_retriever_tool(resources)
|
||||
print(retriever_tool.name)
|
||||
print(retriever_tool.description)
|
||||
print(retriever_tool.args)
|
||||
print(retriever_tool.invoke("三打白骨精"))
|
||||
|
||||
@@ -36,7 +36,10 @@ def get_web_search_tool(max_search_results: int):
|
||||
include_image_descriptions=True,
|
||||
)
|
||||
elif SELECTED_SEARCH_ENGINE == SearchEngine.DUCKDUCKGO.value:
|
||||
return LoggedDuckDuckGoSearch(name="web_search", max_results=max_search_results)
|
||||
return LoggedDuckDuckGoSearch(
|
||||
name="web_search",
|
||||
num_results=max_search_results,
|
||||
)
|
||||
elif SELECTED_SEARCH_ENGINE == SearchEngine.BRAVE_SEARCH.value:
|
||||
return LoggedBraveSearch(
|
||||
name="web_search",
|
||||
@@ -56,14 +59,3 @@ def get_web_search_tool(max_search_results: int):
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported search engine: {SELECTED_SEARCH_ENGINE}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
results = LoggedDuckDuckGoSearch(
|
||||
name="web_search", max_results=3, output_format="list"
|
||||
)
|
||||
print(results.name)
|
||||
print(results.description)
|
||||
print(results.args)
|
||||
# .invoke("cute panda")
|
||||
# print(json.dumps(results, indent=2, ensure_ascii=False))
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
@@ -107,9 +111,3 @@ class EnhancedTavilySearchAPIWrapper(OriginalTavilySearchAPIWrapper):
|
||||
}
|
||||
clean_results.append(clean_result)
|
||||
return clean_results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
wrapper = EnhancedTavilySearchAPIWrapper()
|
||||
results = wrapper.raw_results("cute panda", include_images=True)
|
||||
print(json.dumps(results, indent=2, ensure_ascii=False))
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import json
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
|
||||
@@ -229,3 +229,20 @@ class TestVolcengineTTS:
|
||||
args, kwargs = mock_post.call_args
|
||||
request_json = json.loads(args[1])
|
||||
assert request_json["user"]["uid"] == str(mock_uuid_value)
|
||||
|
||||
@patch("src.tools.tts.requests.post")
|
||||
def test_text_to_speech_request_exception(self, mock_post):
|
||||
"""Test error handling when requests.post raises an exception."""
|
||||
# Mock requests.post to raise an exception
|
||||
mock_post.side_effect = Exception("Network error")
|
||||
# Create TTS client
|
||||
tts = VolcengineTTS(
|
||||
appid="test_appid",
|
||||
access_token="test_token",
|
||||
)
|
||||
# Call the method
|
||||
result = tts.text_to_speech("Hello, world!")
|
||||
# Verify the result
|
||||
assert result["success"] is False
|
||||
assert result["error"] == "Network error"
|
||||
assert result["audio_data"] is None
|
||||
|
||||
110
tests/unit/tools/test_crawl.py
Normal file
110
tests/unit/tools/test_crawl.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from src.tools.crawl import crawl_tool
|
||||
|
||||
|
||||
class TestCrawlTool:
|
||||
|
||||
@patch("src.tools.crawl.Crawler")
|
||||
def test_crawl_tool_success(self, mock_crawler_class):
|
||||
# Arrange
|
||||
mock_crawler = Mock()
|
||||
mock_article = Mock()
|
||||
mock_article.to_markdown.return_value = (
|
||||
"# Test Article\nThis is test content." * 100
|
||||
)
|
||||
mock_crawler.crawl.return_value = mock_article
|
||||
mock_crawler_class.return_value = mock_crawler
|
||||
|
||||
url = "https://example.com"
|
||||
|
||||
# Act
|
||||
result = crawl_tool(url)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, dict)
|
||||
assert result["url"] == url
|
||||
assert "crawled_content" in result
|
||||
assert len(result["crawled_content"]) <= 1000
|
||||
mock_crawler_class.assert_called_once()
|
||||
mock_crawler.crawl.assert_called_once_with(url)
|
||||
mock_article.to_markdown.assert_called_once()
|
||||
|
||||
@patch("src.tools.crawl.Crawler")
|
||||
def test_crawl_tool_short_content(self, mock_crawler_class):
|
||||
# Arrange
|
||||
mock_crawler = Mock()
|
||||
mock_article = Mock()
|
||||
short_content = "Short content"
|
||||
mock_article.to_markdown.return_value = short_content
|
||||
mock_crawler.crawl.return_value = mock_article
|
||||
mock_crawler_class.return_value = mock_crawler
|
||||
|
||||
url = "https://example.com"
|
||||
|
||||
# Act
|
||||
result = crawl_tool(url)
|
||||
|
||||
# Assert
|
||||
assert result["crawled_content"] == short_content
|
||||
|
||||
@patch("src.tools.crawl.Crawler")
|
||||
@patch("src.tools.crawl.logger")
|
||||
def test_crawl_tool_crawler_exception(self, mock_logger, mock_crawler_class):
|
||||
# Arrange
|
||||
mock_crawler = Mock()
|
||||
mock_crawler.crawl.side_effect = Exception("Network error")
|
||||
mock_crawler_class.return_value = mock_crawler
|
||||
|
||||
url = "https://example.com"
|
||||
|
||||
# Act
|
||||
result = crawl_tool(url)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, str)
|
||||
assert "Failed to crawl" in result
|
||||
assert "Network error" in result
|
||||
mock_logger.error.assert_called_once()
|
||||
|
||||
@patch("src.tools.crawl.Crawler")
|
||||
@patch("src.tools.crawl.logger")
|
||||
def test_crawl_tool_crawler_instantiation_exception(
|
||||
self, mock_logger, mock_crawler_class
|
||||
):
|
||||
# Arrange
|
||||
mock_crawler_class.side_effect = Exception("Crawler init error")
|
||||
|
||||
url = "https://example.com"
|
||||
|
||||
# Act
|
||||
result = crawl_tool(url)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, str)
|
||||
assert "Failed to crawl" in result
|
||||
assert "Crawler init error" in result
|
||||
mock_logger.error.assert_called_once()
|
||||
|
||||
@patch("src.tools.crawl.Crawler")
|
||||
@patch("src.tools.crawl.logger")
|
||||
def test_crawl_tool_markdown_conversion_exception(
|
||||
self, mock_logger, mock_crawler_class
|
||||
):
|
||||
# Arrange
|
||||
mock_crawler = Mock()
|
||||
mock_article = Mock()
|
||||
mock_article.to_markdown.side_effect = Exception("Markdown conversion error")
|
||||
mock_crawler.crawl.return_value = mock_article
|
||||
mock_crawler_class.return_value = mock_crawler
|
||||
|
||||
url = "https://example.com"
|
||||
|
||||
# Act
|
||||
result = crawl_tool(url)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, str)
|
||||
assert "Failed to crawl" in result
|
||||
assert "Markdown conversion error" in result
|
||||
mock_logger.error.assert_called_once()
|
||||
121
tests/unit/tools/test_decorators.py
Normal file
121
tests/unit/tools/test_decorators.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
import logging
|
||||
from unittest.mock import Mock, call, patch, MagicMock
|
||||
from src.tools.decorators import LoggedToolMixin, create_logged_tool
|
||||
|
||||
|
||||
class MockBaseTool:
|
||||
"""Mock base tool class for testing."""
|
||||
|
||||
def _run(self, *args, **kwargs):
|
||||
return "base_result"
|
||||
|
||||
|
||||
class TestLoggedToolMixin:
|
||||
|
||||
def test_run_calls_log_operation(self):
|
||||
"""Test that _run calls _log_operation with correct parameters."""
|
||||
# Create a logged tool instance
|
||||
LoggedTool = create_logged_tool(MockBaseTool)
|
||||
tool = LoggedTool()
|
||||
|
||||
# Mock the _log_operation method
|
||||
tool._log_operation = Mock()
|
||||
|
||||
# Call _run with test parameters
|
||||
args = ("arg1", "arg2")
|
||||
kwargs = {"key1": "value1", "key2": "value2"}
|
||||
tool._run(*args, **kwargs)
|
||||
|
||||
# Verify _log_operation was called with correct parameters
|
||||
tool._log_operation.assert_called_once_with("_run", *args, **kwargs)
|
||||
|
||||
def test_run_calls_super_run(self):
|
||||
"""Test that _run calls the parent class _run method."""
|
||||
# Create a logged tool instance
|
||||
LoggedTool = create_logged_tool(MockBaseTool)
|
||||
tool = LoggedTool()
|
||||
|
||||
# Mock the parent _run method
|
||||
with patch.object(
|
||||
MockBaseTool, "_run", return_value="mocked_result"
|
||||
) as mock_super_run:
|
||||
args = ("arg1", "arg2")
|
||||
kwargs = {"key1": "value1"}
|
||||
result = tool._run(*args, **kwargs)
|
||||
|
||||
# Verify super()._run was called with correct parameters
|
||||
mock_super_run.assert_called_once_with(*args, **kwargs)
|
||||
# Verify the result is returned
|
||||
assert result == "mocked_result"
|
||||
|
||||
def test_run_logs_result(self):
|
||||
"""Test that _run logs the result with debug level."""
|
||||
LoggedTool = create_logged_tool(MockBaseTool)
|
||||
tool = LoggedTool()
|
||||
|
||||
with patch("src.tools.decorators.logger.debug") as mock_debug:
|
||||
result = tool._run("test_arg")
|
||||
|
||||
# Verify debug log was called with correct message
|
||||
mock_debug.assert_has_calls(
|
||||
[
|
||||
call("Tool MockBaseTool._run called with parameters: test_arg"),
|
||||
call("Tool MockBaseTool returned: base_result"),
|
||||
]
|
||||
)
|
||||
|
||||
def test_run_returns_super_result(self):
|
||||
"""Test that _run returns the result from parent class."""
|
||||
LoggedTool = create_logged_tool(MockBaseTool)
|
||||
tool = LoggedTool()
|
||||
|
||||
result = tool._run()
|
||||
assert result == "base_result"
|
||||
|
||||
def test_run_with_no_args(self):
|
||||
"""Test _run method with no arguments."""
|
||||
LoggedTool = create_logged_tool(MockBaseTool)
|
||||
tool = LoggedTool()
|
||||
|
||||
with patch("src.tools.decorators.logger.debug") as mock_debug:
|
||||
tool._log_operation = Mock()
|
||||
|
||||
result = tool._run()
|
||||
|
||||
# Verify _log_operation called with no args
|
||||
tool._log_operation.assert_called_once_with("_run")
|
||||
# Verify result logging
|
||||
mock_debug.assert_called_once()
|
||||
assert result == "base_result"
|
||||
|
||||
def test_run_with_mixed_args_kwargs(self):
|
||||
"""Test _run method with both positional and keyword arguments."""
|
||||
LoggedTool = create_logged_tool(MockBaseTool)
|
||||
tool = LoggedTool()
|
||||
|
||||
tool._log_operation = Mock()
|
||||
|
||||
args = ("pos1", "pos2")
|
||||
kwargs = {"kw1": "val1", "kw2": "val2"}
|
||||
result = tool._run(*args, **kwargs)
|
||||
|
||||
# Verify all arguments passed correctly
|
||||
tool._log_operation.assert_called_once_with("_run", *args, **kwargs)
|
||||
assert result == "base_result"
|
||||
|
||||
def test_run_class_name_replacement(self):
|
||||
"""Test that class name 'Logged' prefix is correctly removed in logging."""
|
||||
LoggedTool = create_logged_tool(MockBaseTool)
|
||||
tool = LoggedTool()
|
||||
|
||||
with patch("src.tools.decorators.logger.debug") as mock_debug:
|
||||
tool._run()
|
||||
|
||||
# Verify the logged class name has 'Logged' prefix removed
|
||||
call_args = mock_debug.call_args[0][0]
|
||||
assert "Tool MockBaseTool returned:" in call_args
|
||||
assert "LoggedMockBaseTool" not in call_args
|
||||
147
tests/unit/tools/test_python_repl.py
Normal file
147
tests/unit/tools/test_python_repl.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from src.tools.python_repl import python_repl_tool
|
||||
|
||||
|
||||
class TestPythonReplTool:
|
||||
|
||||
@patch("src.tools.python_repl.repl")
|
||||
@patch("src.tools.python_repl.logger")
|
||||
def test_successful_code_execution(self, mock_logger, mock_repl):
|
||||
# Arrange
|
||||
code = "print('Hello, World!')"
|
||||
expected_output = "Hello, World!\n"
|
||||
mock_repl.run.return_value = expected_output
|
||||
|
||||
# Act
|
||||
result = python_repl_tool(code)
|
||||
|
||||
# Assert
|
||||
mock_repl.run.assert_called_once_with(code)
|
||||
mock_logger.info.assert_called_with("Code execution successful")
|
||||
assert "Successfully executed:" in result
|
||||
assert code in result
|
||||
assert expected_output in result
|
||||
|
||||
@patch("src.tools.python_repl.repl")
|
||||
@patch("src.tools.python_repl.logger")
|
||||
def test_invalid_input_type(self, mock_logger, mock_repl):
|
||||
# Arrange
|
||||
invalid_code = 123
|
||||
|
||||
# Act & Assert - expect ValidationError from LangChain
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
python_repl_tool(invalid_code)
|
||||
|
||||
# Verify that it's a validation error
|
||||
assert "ValidationError" in str(
|
||||
type(exc_info.value)
|
||||
) or "validation error" in str(exc_info.value)
|
||||
|
||||
# The REPL should not be called since validation fails first
|
||||
mock_repl.run.assert_not_called()
|
||||
|
||||
@patch("src.tools.python_repl.repl")
|
||||
@patch("src.tools.python_repl.logger")
|
||||
def test_code_execution_with_error_in_result(self, mock_logger, mock_repl):
|
||||
# Arrange
|
||||
code = "invalid_function()"
|
||||
error_result = "NameError: name 'invalid_function' is not defined"
|
||||
mock_repl.run.return_value = error_result
|
||||
|
||||
# Act
|
||||
result = python_repl_tool(code)
|
||||
|
||||
# Assert
|
||||
mock_repl.run.assert_called_once_with(code)
|
||||
mock_logger.error.assert_called_with(error_result)
|
||||
assert "Error executing code:" in result
|
||||
assert code in result
|
||||
assert error_result in result
|
||||
|
||||
@patch("src.tools.python_repl.repl")
|
||||
@patch("src.tools.python_repl.logger")
|
||||
def test_code_execution_with_exception_in_result(self, mock_logger, mock_repl):
|
||||
# Arrange
|
||||
code = "1/0"
|
||||
exception_result = "ZeroDivisionError: division by zero"
|
||||
mock_repl.run.return_value = exception_result
|
||||
|
||||
# Act
|
||||
result = python_repl_tool(code)
|
||||
|
||||
# Assert
|
||||
mock_repl.run.assert_called_once_with(code)
|
||||
mock_logger.error.assert_called_with(exception_result)
|
||||
assert "Error executing code:" in result
|
||||
assert code in result
|
||||
assert exception_result in result
|
||||
|
||||
@patch("src.tools.python_repl.repl")
|
||||
@patch("src.tools.python_repl.logger")
|
||||
def test_code_execution_raises_exception(self, mock_logger, mock_repl):
|
||||
# Arrange
|
||||
code = "print('test')"
|
||||
exception = RuntimeError("REPL failed")
|
||||
mock_repl.run.side_effect = exception
|
||||
|
||||
# Act
|
||||
result = python_repl_tool(code)
|
||||
|
||||
# Assert
|
||||
mock_repl.run.assert_called_once_with(code)
|
||||
mock_logger.error.assert_called_with(repr(exception))
|
||||
assert "Error executing code:" in result
|
||||
assert code in result
|
||||
assert repr(exception) in result
|
||||
|
||||
@patch("src.tools.python_repl.repl")
|
||||
@patch("src.tools.python_repl.logger")
|
||||
def test_successful_execution_with_calculation(self, mock_logger, mock_repl):
|
||||
# Arrange
|
||||
code = "result = 2 + 3\nprint(result)"
|
||||
expected_output = "5\n"
|
||||
mock_repl.run.return_value = expected_output
|
||||
|
||||
# Act
|
||||
result = python_repl_tool(code)
|
||||
|
||||
# Assert
|
||||
mock_repl.run.assert_called_once_with(code)
|
||||
mock_logger.info.assert_any_call("Executing Python code")
|
||||
mock_logger.info.assert_any_call("Code execution successful")
|
||||
assert "Successfully executed:" in result
|
||||
assert code in result
|
||||
assert expected_output in result
|
||||
|
||||
@patch("src.tools.python_repl.repl")
|
||||
@patch("src.tools.python_repl.logger")
|
||||
def test_empty_string_code(self, mock_logger, mock_repl):
|
||||
# Arrange
|
||||
code = ""
|
||||
mock_repl.run.return_value = ""
|
||||
|
||||
# Act
|
||||
result = python_repl_tool(code)
|
||||
|
||||
# Assert
|
||||
mock_repl.run.assert_called_once_with(code)
|
||||
mock_logger.info.assert_called_with("Code execution successful")
|
||||
assert "Successfully executed:" in result
|
||||
|
||||
@patch("src.tools.python_repl.repl")
|
||||
@patch("src.tools.python_repl.logger")
|
||||
def test_logging_calls(self, mock_logger, mock_repl):
|
||||
# Arrange
|
||||
code = "x = 1"
|
||||
mock_repl.run.return_value = ""
|
||||
|
||||
# Act
|
||||
python_repl_tool(code)
|
||||
|
||||
# Assert
|
||||
mock_logger.info.assert_any_call("Executing Python code")
|
||||
mock_logger.info.assert_any_call("Code execution successful")
|
||||
54
tests/unit/tools/test_search.py
Normal file
54
tests/unit/tools/test_search.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from src.tools.search import get_web_search_tool
|
||||
from src.config import SearchEngine
|
||||
|
||||
|
||||
class TestGetWebSearchTool:
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
|
||||
def test_get_web_search_tool_tavily(self):
|
||||
tool = get_web_search_tool(max_search_results=5)
|
||||
assert tool.name == "web_search"
|
||||
assert tool.max_results == 5
|
||||
assert tool.include_raw_content is True
|
||||
assert tool.include_images is True
|
||||
assert tool.include_image_descriptions is True
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.DUCKDUCKGO.value)
|
||||
def test_get_web_search_tool_duckduckgo(self):
|
||||
tool = get_web_search_tool(max_search_results=3)
|
||||
assert tool.name == "web_search"
|
||||
assert tool.max_results == 3
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.BRAVE_SEARCH.value)
|
||||
@patch.dict(os.environ, {"BRAVE_SEARCH_API_KEY": "test_api_key"})
|
||||
def test_get_web_search_tool_brave(self):
|
||||
tool = get_web_search_tool(max_search_results=4)
|
||||
assert tool.name == "web_search"
|
||||
assert tool.search_wrapper.api_key == "test_api_key"
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.ARXIV.value)
|
||||
def test_get_web_search_tool_arxiv(self):
|
||||
tool = get_web_search_tool(max_search_results=2)
|
||||
assert tool.name == "web_search"
|
||||
assert tool.api_wrapper.top_k_results == 2
|
||||
assert tool.api_wrapper.load_max_docs == 2
|
||||
assert tool.api_wrapper.load_all_available_meta is True
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", "unsupported_engine")
|
||||
def test_get_web_search_tool_unsupported_engine(self):
|
||||
with pytest.raises(
|
||||
ValueError, match="Unsupported search engine: unsupported_engine"
|
||||
):
|
||||
get_web_search_tool(max_search_results=1)
|
||||
|
||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.BRAVE_SEARCH.value)
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_get_web_search_tool_brave_no_api_key(self):
|
||||
tool = get_web_search_tool(max_search_results=1)
|
||||
assert tool.search_wrapper.api_key == ""
|
||||
207
tests/unit/tools/test_tavily_search_api_wrapper.py
Normal file
207
tests/unit/tools/test_tavily_search_api_wrapper.py
Normal file
@@ -0,0 +1,207 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||||
import aiohttp
|
||||
import requests
|
||||
from src.tools.tavily_search.tavily_search_api_wrapper import (
|
||||
EnhancedTavilySearchAPIWrapper,
|
||||
)
|
||||
|
||||
|
||||
class TestEnhancedTavilySearchAPIWrapper:
|
||||
|
||||
@pytest.fixture
|
||||
def wrapper(self):
|
||||
with patch(
|
||||
"src.tools.tavily_search.tavily_search_api_wrapper.OriginalTavilySearchAPIWrapper"
|
||||
):
|
||||
wrapper = EnhancedTavilySearchAPIWrapper(tavily_api_key="dummy-key")
|
||||
# The parent class is mocked, so initialization won't fail
|
||||
return wrapper
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response_data(self):
|
||||
return {
|
||||
"results": [
|
||||
{
|
||||
"title": "Test Title",
|
||||
"url": "https://example.com",
|
||||
"content": "Test content",
|
||||
"score": 0.9,
|
||||
"raw_content": "Raw test content",
|
||||
}
|
||||
],
|
||||
"images": [
|
||||
{
|
||||
"url": "https://example.com/image.jpg",
|
||||
"description": "Test image description",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
@patch("src.tools.tavily_search.tavily_search_api_wrapper.requests.post")
|
||||
def test_raw_results_success(self, mock_post, wrapper, mock_response_data):
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = mock_response_data
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = wrapper.raw_results("test query", max_results=10)
|
||||
|
||||
assert result == mock_response_data
|
||||
mock_post.assert_called_once()
|
||||
call_args = mock_post.call_args
|
||||
assert "json" in call_args.kwargs
|
||||
assert call_args.kwargs["json"]["query"] == "test query"
|
||||
assert call_args.kwargs["json"]["max_results"] == 10
|
||||
|
||||
@patch("src.tools.tavily_search.tavily_search_api_wrapper.requests.post")
|
||||
def test_raw_results_with_all_parameters(
|
||||
self, mock_post, wrapper, mock_response_data
|
||||
):
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = mock_response_data
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
result = wrapper.raw_results(
|
||||
"test query",
|
||||
max_results=3,
|
||||
search_depth="basic",
|
||||
include_domains=["example.com"],
|
||||
exclude_domains=["spam.com"],
|
||||
include_answer=True,
|
||||
include_raw_content=True,
|
||||
include_images=True,
|
||||
include_image_descriptions=True,
|
||||
)
|
||||
|
||||
assert result == mock_response_data
|
||||
call_args = mock_post.call_args
|
||||
params = call_args.kwargs["json"]
|
||||
assert params["include_domains"] == ["example.com"]
|
||||
assert params["exclude_domains"] == ["spam.com"]
|
||||
assert params["include_answer"] is True
|
||||
assert params["include_raw_content"] is True
|
||||
|
||||
@patch("src.tools.tavily_search.tavily_search_api_wrapper.requests.post")
|
||||
def test_raw_results_http_error(self, mock_post, wrapper):
|
||||
mock_response = Mock()
|
||||
mock_response.raise_for_status.side_effect = requests.HTTPError("API Error")
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
with pytest.raises(requests.HTTPError):
|
||||
wrapper.raw_results("test query")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raw_results_async_success(self, wrapper, mock_response_data):
|
||||
# Create a mock that acts as both the response and its context manager
|
||||
mock_response_cm = AsyncMock()
|
||||
mock_response_cm.__aenter__ = AsyncMock(return_value=mock_response_cm)
|
||||
mock_response_cm.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_response_cm.status = 200
|
||||
mock_response_cm.text = AsyncMock(return_value=json.dumps(mock_response_data))
|
||||
|
||||
# Create mock session that returns the context manager
|
||||
mock_session = AsyncMock()
|
||||
mock_session.post = MagicMock(
|
||||
return_value=mock_response_cm
|
||||
) # Use MagicMock, not AsyncMock
|
||||
|
||||
# Create mock session class
|
||||
mock_session_cm = AsyncMock()
|
||||
mock_session_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session_cm.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch(
|
||||
"src.tools.tavily_search.tavily_search_api_wrapper.aiohttp.ClientSession",
|
||||
return_value=mock_session_cm,
|
||||
):
|
||||
result = await wrapper.raw_results_async("test query")
|
||||
|
||||
assert result == mock_response_data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raw_results_async_error(self, wrapper):
|
||||
# Create a mock that acts as both the response and its context manager
|
||||
mock_response_cm = AsyncMock()
|
||||
mock_response_cm.__aenter__ = AsyncMock(return_value=mock_response_cm)
|
||||
mock_response_cm.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_response_cm.status = 400
|
||||
mock_response_cm.reason = "Bad Request"
|
||||
|
||||
# Create mock session that returns the context manager
|
||||
mock_session = AsyncMock()
|
||||
mock_session.post = MagicMock(
|
||||
return_value=mock_response_cm
|
||||
) # Use MagicMock, not AsyncMock
|
||||
|
||||
# Create mock session class
|
||||
mock_session_cm = AsyncMock()
|
||||
mock_session_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session_cm.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch(
|
||||
"src.tools.tavily_search.tavily_search_api_wrapper.aiohttp.ClientSession",
|
||||
return_value=mock_session_cm,
|
||||
):
|
||||
with pytest.raises(Exception, match="Error 400: Bad Request"):
|
||||
await wrapper.raw_results_async("test query")
|
||||
|
||||
def test_clean_results_with_images(self, wrapper, mock_response_data):
|
||||
result = wrapper.clean_results_with_images(mock_response_data)
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
# Test page result
|
||||
page_result = result[0]
|
||||
assert page_result["type"] == "page"
|
||||
assert page_result["title"] == "Test Title"
|
||||
assert page_result["url"] == "https://example.com"
|
||||
assert page_result["content"] == "Test content"
|
||||
assert page_result["score"] == 0.9
|
||||
assert page_result["raw_content"] == "Raw test content"
|
||||
|
||||
# Test image result
|
||||
image_result = result[1]
|
||||
assert image_result["type"] == "image"
|
||||
assert image_result["image_url"] == "https://example.com/image.jpg"
|
||||
assert image_result["image_description"] == "Test image description"
|
||||
|
||||
def test_clean_results_without_raw_content(self, wrapper):
|
||||
data = {
|
||||
"results": [
|
||||
{
|
||||
"title": "Test Title",
|
||||
"url": "https://example.com",
|
||||
"content": "Test content",
|
||||
"score": 0.9,
|
||||
}
|
||||
],
|
||||
"images": [],
|
||||
}
|
||||
|
||||
result = wrapper.clean_results_with_images(data)
|
||||
|
||||
assert len(result) == 1
|
||||
assert "raw_content" not in result[0]
|
||||
|
||||
def test_clean_results_empty_images(self, wrapper):
|
||||
data = {
|
||||
"results": [
|
||||
{
|
||||
"title": "Test Title",
|
||||
"url": "https://example.com",
|
||||
"content": "Test content",
|
||||
"score": 0.9,
|
||||
}
|
||||
],
|
||||
"images": [],
|
||||
}
|
||||
|
||||
result = wrapper.clean_results_with_images(data)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["type"] == "page"
|
||||
265
tests/unit/tools/test_tavily_search_results_with_images.py
Normal file
265
tests/unit/tools/test_tavily_search_results_with_images.py
Normal file
@@ -0,0 +1,265 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
from typing import Dict, Any
|
||||
from src.tools.tavily_search.tavily_search_results_with_images import (
|
||||
TavilySearchResultsWithImages,
|
||||
)
|
||||
from src.tools.tavily_search.tavily_search_api_wrapper import (
|
||||
EnhancedTavilySearchAPIWrapper,
|
||||
)
|
||||
|
||||
|
||||
class TestTavilySearchResultsWithImages:
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_wrapper(self):
|
||||
"""Create a mock API wrapper."""
|
||||
wrapper = Mock(spec=EnhancedTavilySearchAPIWrapper)
|
||||
return wrapper
|
||||
|
||||
@pytest.fixture
|
||||
def search_tool(self, mock_api_wrapper):
|
||||
"""Create a TavilySearchResultsWithImages instance with mocked dependencies."""
|
||||
tool = TavilySearchResultsWithImages(
|
||||
max_results=5,
|
||||
include_answer=True,
|
||||
include_raw_content=True,
|
||||
include_images=True,
|
||||
include_image_descriptions=True,
|
||||
)
|
||||
tool.api_wrapper = mock_api_wrapper
|
||||
return tool
|
||||
|
||||
@pytest.fixture
|
||||
def sample_raw_results(self):
|
||||
"""Sample raw results from Tavily API."""
|
||||
return {
|
||||
"query": "test query",
|
||||
"answer": "Test answer",
|
||||
"images": ["https://example.com/image1.jpg"],
|
||||
"results": [
|
||||
{
|
||||
"title": "Test Title",
|
||||
"url": "https://example.com",
|
||||
"content": "Test content",
|
||||
"score": 0.95,
|
||||
"raw_content": "Raw test content",
|
||||
}
|
||||
],
|
||||
"response_time": 1.5,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def sample_cleaned_results(self):
|
||||
"""Sample cleaned results."""
|
||||
return [
|
||||
{
|
||||
"title": "Test Title",
|
||||
"url": "https://example.com",
|
||||
"content": "Test content",
|
||||
}
|
||||
]
|
||||
|
||||
def test_init_default_values(self):
|
||||
"""Test initialization with default values."""
|
||||
tool = TavilySearchResultsWithImages()
|
||||
assert tool.include_image_descriptions is False
|
||||
assert isinstance(tool.api_wrapper, EnhancedTavilySearchAPIWrapper)
|
||||
|
||||
def test_init_custom_values(self):
|
||||
"""Test initialization with custom values."""
|
||||
tool = TavilySearchResultsWithImages(
|
||||
max_results=10, include_image_descriptions=True
|
||||
)
|
||||
assert tool.max_results == 10
|
||||
assert tool.include_image_descriptions is True
|
||||
|
||||
@patch("builtins.print")
|
||||
def test_run_success(
|
||||
self,
|
||||
mock_print,
|
||||
search_tool,
|
||||
mock_api_wrapper,
|
||||
sample_raw_results,
|
||||
sample_cleaned_results,
|
||||
):
|
||||
"""Test successful synchronous run."""
|
||||
mock_api_wrapper.raw_results.return_value = sample_raw_results
|
||||
mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results
|
||||
|
||||
result, raw = search_tool._run("test query")
|
||||
|
||||
assert result == sample_cleaned_results
|
||||
assert raw == sample_raw_results
|
||||
|
||||
mock_api_wrapper.raw_results.assert_called_once_with(
|
||||
"test query",
|
||||
search_tool.max_results,
|
||||
search_tool.search_depth,
|
||||
search_tool.include_domains,
|
||||
search_tool.exclude_domains,
|
||||
search_tool.include_answer,
|
||||
search_tool.include_raw_content,
|
||||
search_tool.include_images,
|
||||
search_tool.include_image_descriptions,
|
||||
)
|
||||
|
||||
mock_api_wrapper.clean_results_with_images.assert_called_once_with(
|
||||
sample_raw_results
|
||||
)
|
||||
mock_print.assert_called_once()
|
||||
|
||||
@patch("builtins.print")
|
||||
def test_run_exception(self, mock_print, search_tool, mock_api_wrapper):
|
||||
"""Test synchronous run with exception."""
|
||||
mock_api_wrapper.raw_results.side_effect = Exception("API Error")
|
||||
|
||||
result, raw = search_tool._run("test query")
|
||||
|
||||
assert "API Error" in result
|
||||
assert raw == {}
|
||||
mock_api_wrapper.clean_results_with_images.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("builtins.print")
|
||||
async def test_arun_success(
|
||||
self,
|
||||
mock_print,
|
||||
search_tool,
|
||||
mock_api_wrapper,
|
||||
sample_raw_results,
|
||||
sample_cleaned_results,
|
||||
):
|
||||
"""Test successful asynchronous run."""
|
||||
mock_api_wrapper.raw_results_async = AsyncMock(return_value=sample_raw_results)
|
||||
mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results
|
||||
|
||||
result, raw = await search_tool._arun("test query")
|
||||
|
||||
assert result == sample_cleaned_results
|
||||
assert raw == sample_raw_results
|
||||
|
||||
mock_api_wrapper.raw_results_async.assert_called_once_with(
|
||||
"test query",
|
||||
search_tool.max_results,
|
||||
search_tool.search_depth,
|
||||
search_tool.include_domains,
|
||||
search_tool.exclude_domains,
|
||||
search_tool.include_answer,
|
||||
search_tool.include_raw_content,
|
||||
search_tool.include_images,
|
||||
search_tool.include_image_descriptions,
|
||||
)
|
||||
|
||||
mock_api_wrapper.clean_results_with_images.assert_called_once_with(
|
||||
sample_raw_results
|
||||
)
|
||||
mock_print.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("builtins.print")
|
||||
async def test_arun_exception(self, mock_print, search_tool, mock_api_wrapper):
|
||||
"""Test asynchronous run with exception."""
|
||||
mock_api_wrapper.raw_results_async = AsyncMock(
|
||||
side_effect=Exception("Async API Error")
|
||||
)
|
||||
|
||||
result, raw = await search_tool._arun("test query")
|
||||
|
||||
assert "Async API Error" in result
|
||||
assert raw == {}
|
||||
mock_api_wrapper.clean_results_with_images.assert_not_called()
|
||||
|
||||
@patch("builtins.print")
|
||||
def test_run_with_run_manager(
|
||||
self,
|
||||
mock_print,
|
||||
search_tool,
|
||||
mock_api_wrapper,
|
||||
sample_raw_results,
|
||||
sample_cleaned_results,
|
||||
):
|
||||
"""Test run with callback manager."""
|
||||
mock_run_manager = Mock()
|
||||
mock_api_wrapper.raw_results.return_value = sample_raw_results
|
||||
mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results
|
||||
|
||||
result, raw = search_tool._run("test query", run_manager=mock_run_manager)
|
||||
|
||||
assert result == sample_cleaned_results
|
||||
assert raw == sample_raw_results
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("builtins.print")
|
||||
async def test_arun_with_run_manager(
|
||||
self,
|
||||
mock_print,
|
||||
search_tool,
|
||||
mock_api_wrapper,
|
||||
sample_raw_results,
|
||||
sample_cleaned_results,
|
||||
):
|
||||
"""Test async run with callback manager."""
|
||||
mock_run_manager = Mock()
|
||||
mock_api_wrapper.raw_results_async = AsyncMock(return_value=sample_raw_results)
|
||||
mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results
|
||||
|
||||
result, raw = await search_tool._arun(
|
||||
"test query", run_manager=mock_run_manager
|
||||
)
|
||||
|
||||
assert result == sample_cleaned_results
|
||||
assert raw == sample_raw_results
|
||||
|
||||
@patch("builtins.print")
|
||||
def test_print_output_format(
|
||||
self,
|
||||
mock_print,
|
||||
search_tool,
|
||||
mock_api_wrapper,
|
||||
sample_raw_results,
|
||||
sample_cleaned_results,
|
||||
):
|
||||
"""Test that print outputs correctly formatted JSON."""
|
||||
mock_api_wrapper.raw_results.return_value = sample_raw_results
|
||||
mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results
|
||||
|
||||
search_tool._run("test query")
|
||||
|
||||
# Verify print was called with expected format
|
||||
call_args = mock_print.call_args[0]
|
||||
assert call_args[0] == "sync"
|
||||
assert isinstance(call_args[1], str) # Should be JSON string
|
||||
|
||||
# Verify it's valid JSON
|
||||
json_data = json.loads(call_args[1])
|
||||
assert json_data == sample_cleaned_results
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("builtins.print")
|
||||
async def test_async_print_output_format(
|
||||
self,
|
||||
mock_print,
|
||||
search_tool,
|
||||
mock_api_wrapper,
|
||||
sample_raw_results,
|
||||
sample_cleaned_results,
|
||||
):
|
||||
"""Test that async print outputs correctly formatted JSON."""
|
||||
mock_api_wrapper.raw_results_async = AsyncMock(return_value=sample_raw_results)
|
||||
mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results
|
||||
|
||||
await search_tool._arun("test query")
|
||||
|
||||
# Verify print was called with expected format
|
||||
call_args = mock_print.call_args[0]
|
||||
assert call_args[0] == "async"
|
||||
assert isinstance(call_args[1], str) # Should be JSON string
|
||||
|
||||
# Verify it's valid JSON
|
||||
json_data = json.loads(call_args[1])
|
||||
assert json_data == sample_cleaned_results
|
||||
122
tests/unit/tools/test_tools_retriever.py
Normal file
122
tests/unit/tools/test_tools_retriever.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from langchain_core.callbacks import (
|
||||
CallbackManagerForToolRun,
|
||||
AsyncCallbackManagerForToolRun,
|
||||
)
|
||||
import pytest
|
||||
from src.tools.retriever import RetrieverInput, RetrieverTool, get_retriever_tool
|
||||
from src.rag import Document, Retriever, Resource, Chunk
|
||||
|
||||
|
||||
def test_retriever_input_model():
    """RetrieverInput stores the keywords it is constructed with."""
    payload = RetrieverInput(keywords="test keywords")
    assert payload.keywords == "test keywords"
|
||||
|
||||
|
||||
def test_retriever_tool_init():
    """The tool exposes the expected metadata and keeps its collaborators."""
    retriever_stub = Mock(spec=Retriever)
    resource_list = [Resource(uri="test://uri", title="Test")]

    tool = RetrieverTool(retriever=retriever_stub, resources=resource_list)

    assert tool.name == "local_search_tool"
    assert "retrieving information" in tool.description
    assert tool.args_schema == RetrieverInput
    assert tool.retriever == retriever_stub
    assert tool.resources == resource_list
|
||||
|
||||
|
||||
def test_retriever_tool_run_with_results():
    """_run queries the retriever with the resources and returns document dicts."""
    retriever_stub = Mock(spec=Retriever)
    matched_doc = Document(
        id="doc1", chunks=[Chunk(content="test content", similarity=0.9)]
    )
    retriever_stub.query_relevant_documents.return_value = [matched_doc]

    resource_list = [Resource(uri="test://uri", title="Test")]
    tool = RetrieverTool(retriever=retriever_stub, resources=resource_list)

    result = tool._run("test keywords")

    # Keywords and the configured resources must both reach the retriever.
    retriever_stub.query_relevant_documents.assert_called_once_with(
        "test keywords", resource_list
    )
    assert isinstance(result, list)
    assert result == [matched_doc.to_dict()]
|
||||
|
||||
|
||||
def test_retriever_tool_run_no_results():
    """An empty retrieval result produces the canonical not-found message."""
    retriever_stub = Mock(spec=Retriever)
    retriever_stub.query_relevant_documents.return_value = []

    tool = RetrieverTool(
        retriever=retriever_stub,
        resources=[Resource(uri="test://uri", title="Test")],
    )

    assert tool._run("test keywords") == (
        "No results found from the local knowledge base."
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_retriever_tool_arun():
    """_arun bridges to the sync callback manager and returns document dicts."""
    retriever_stub = Mock(spec=Retriever)
    matched_doc = Document(
        id="doc2", chunks=[Chunk(content="async content", similarity=0.8)]
    )
    retriever_stub.query_relevant_documents.return_value = [matched_doc]

    tool = RetrieverTool(
        retriever=retriever_stub,
        resources=[Resource(uri="test://uri", title="Test")],
    )

    async_manager = Mock(spec=AsyncCallbackManagerForToolRun)
    async_manager.get_sync.return_value = Mock(spec=CallbackManagerForToolRun)

    result = await tool._arun("async keywords", async_manager)

    # The async manager must be converted to its synchronous counterpart.
    async_manager.get_sync.assert_called_once()
    assert isinstance(result, list)
    assert result == [matched_doc.to_dict()]
|
||||
|
||||
|
||||
@patch("src.tools.retriever.build_retriever")
def test_get_retriever_tool_success(mock_build_retriever):
    """When a retriever can be built, a fully wired RetrieverTool is returned.

    Also verifies the patched builder is actually consulted — without this
    assertion the test would still pass if the factory bypassed the builder.
    """
    mock_retriever = Mock(spec=Retriever)
    mock_build_retriever.return_value = mock_retriever

    resources = [Resource(uri="test://uri", title="Test")]
    tool = get_retriever_tool(resources)

    # The factory must delegate retriever construction exactly once.
    mock_build_retriever.assert_called_once()
    assert isinstance(tool, RetrieverTool)
    assert tool.retriever == mock_retriever
    assert tool.resources == resources
|
||||
|
||||
|
||||
def test_get_retriever_tool_empty_resources():
    """An empty resource list short-circuits the factory to None."""
    assert get_retriever_tool([]) is None
|
||||
|
||||
|
||||
@patch("src.tools.retriever.build_retriever")
def test_get_retriever_tool_no_retriever(mock_build_retriever):
    """If no retriever backend is configured, the factory returns None.

    Also verifies the builder was actually consulted, distinguishing the
    "builder returned None" path from the "builder never called" path.
    """
    mock_build_retriever.return_value = None

    resources = [Resource(uri="test://uri", title="Test")]
    result = get_retriever_tool(resources)

    # The builder ran; its None result must propagate to the caller.
    mock_build_retriever.assert_called_once()
    assert result is None
|
||||
|
||||
|
||||
def test_retriever_tool_run_with_callback_manager():
    """_run accepts an explicit callback manager without changing its result."""
    retriever_stub = Mock(spec=Retriever)
    retriever_stub.query_relevant_documents.return_value = []

    tool = RetrieverTool(
        retriever=retriever_stub,
        resources=[Resource(uri="test://uri", title="Test")],
    )

    callback_manager = Mock(spec=CallbackManagerForToolRun)
    outcome = tool._run("test keywords", callback_manager)

    assert outcome == "No results found from the local knowledge base."
|
||||
Reference in New Issue
Block a user