test: add more unit tests of tools (#315)

* test: add more test on test_tts.py

* test: add unit test of search and retriever in tools

* test: remove the main code of search.py

* test: add the travily_search unit test

* reformate the codes

* test: add unit tests of tools

* Added the pytest-asyncio dependency

* added the license header of test_tavily_search_api_wrapper.py
This commit is contained in:
Willem Jiang
2025-06-12 20:43:32 +08:00
committed by GitHub
parent bb7dc6e98c
commit 4c2fe2e7f5
14 changed files with 1057 additions and 35 deletions

View File

@@ -1,8 +1,8 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from .retriever import Retriever, Document, Resource
from .retriever import Retriever, Document, Resource, Chunk
from .ragflow import RAGFlowProvider
from .builder import build_retriever
__all__ = [Retriever, Document, Resource, RAGFlowProvider, build_retriever]
__all__ = [Retriever, Document, Resource, RAGFlowProvider, Chunk, build_retriever]

View File

@@ -60,18 +60,3 @@ def get_retriever_tool(resources: List[Resource]) -> RetrieverTool | None:
if not retriever:
return None
return RetrieverTool(retriever=retriever, resources=resources)
if __name__ == "__main__":
resources = [
Resource(
uri="rag://dataset/1c7e2ea4362911f09a41c290d4b6a7f0",
title="西游记",
description="西游记是中国古代四大名著之一,讲述了唐僧师徒四人西天取经的故事。",
)
]
retriever_tool = get_retriever_tool(resources)
print(retriever_tool.name)
print(retriever_tool.description)
print(retriever_tool.args)
print(retriever_tool.invoke("三打白骨精"))

View File

@@ -36,7 +36,10 @@ def get_web_search_tool(max_search_results: int):
include_image_descriptions=True,
)
elif SELECTED_SEARCH_ENGINE == SearchEngine.DUCKDUCKGO.value:
return LoggedDuckDuckGoSearch(name="web_search", max_results=max_search_results)
return LoggedDuckDuckGoSearch(
name="web_search",
num_results=max_search_results,
)
elif SELECTED_SEARCH_ENGINE == SearchEngine.BRAVE_SEARCH.value:
return LoggedBraveSearch(
name="web_search",
@@ -56,14 +59,3 @@ def get_web_search_tool(max_search_results: int):
)
else:
raise ValueError(f"Unsupported search engine: {SELECTED_SEARCH_ENGINE}")
if __name__ == "__main__":
results = LoggedDuckDuckGoSearch(
name="web_search", max_results=3, output_format="list"
)
print(results.name)
print(results.description)
print(results.args)
# .invoke("cute panda")
# print(json.dumps(results, indent=2, ensure_ascii=False))

View File

@@ -1,3 +1,7 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import json
from typing import Dict, List, Optional
@@ -107,9 +111,3 @@ class EnhancedTavilySearchAPIWrapper(OriginalTavilySearchAPIWrapper):
}
clean_results.append(clean_result)
return clean_results
if __name__ == "__main__":
wrapper = EnhancedTavilySearchAPIWrapper()
results = wrapper.raw_results("cute panda", include_images=True)
print(json.dumps(results, indent=2, ensure_ascii=False))

View File

@@ -1,3 +1,6 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import json
from typing import Dict, List, Optional, Tuple, Union