mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-10 17:24:45 +08:00
test: add more unit tests of tools (#315)
* test: add more test on test_tts.py * test: add unit test of search and retriever in tools * test: remove the main code of search.py * test: add the travily_search unit test * reformate the codes * test: add unit tests of tools * Added the pytest-asyncio dependency * added the license header of test_tavily_search_api_wrapper.py
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from .retriever import Retriever, Document, Resource
|
||||
from .retriever import Retriever, Document, Resource, Chunk
|
||||
from .ragflow import RAGFlowProvider
|
||||
from .builder import build_retriever
|
||||
|
||||
__all__ = [Retriever, Document, Resource, RAGFlowProvider, build_retriever]
|
||||
__all__ = [Retriever, Document, Resource, RAGFlowProvider, Chunk, build_retriever]
|
||||
|
||||
@@ -60,18 +60,3 @@ def get_retriever_tool(resources: List[Resource]) -> RetrieverTool | None:
|
||||
if not retriever:
|
||||
return None
|
||||
return RetrieverTool(retriever=retriever, resources=resources)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
resources = [
|
||||
Resource(
|
||||
uri="rag://dataset/1c7e2ea4362911f09a41c290d4b6a7f0",
|
||||
title="西游记",
|
||||
description="西游记是中国古代四大名著之一,讲述了唐僧师徒四人西天取经的故事。",
|
||||
)
|
||||
]
|
||||
retriever_tool = get_retriever_tool(resources)
|
||||
print(retriever_tool.name)
|
||||
print(retriever_tool.description)
|
||||
print(retriever_tool.args)
|
||||
print(retriever_tool.invoke("三打白骨精"))
|
||||
|
||||
@@ -36,7 +36,10 @@ def get_web_search_tool(max_search_results: int):
|
||||
include_image_descriptions=True,
|
||||
)
|
||||
elif SELECTED_SEARCH_ENGINE == SearchEngine.DUCKDUCKGO.value:
|
||||
return LoggedDuckDuckGoSearch(name="web_search", max_results=max_search_results)
|
||||
return LoggedDuckDuckGoSearch(
|
||||
name="web_search",
|
||||
num_results=max_search_results,
|
||||
)
|
||||
elif SELECTED_SEARCH_ENGINE == SearchEngine.BRAVE_SEARCH.value:
|
||||
return LoggedBraveSearch(
|
||||
name="web_search",
|
||||
@@ -56,14 +59,3 @@ def get_web_search_tool(max_search_results: int):
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported search engine: {SELECTED_SEARCH_ENGINE}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
results = LoggedDuckDuckGoSearch(
|
||||
name="web_search", max_results=3, output_format="list"
|
||||
)
|
||||
print(results.name)
|
||||
print(results.description)
|
||||
print(results.args)
|
||||
# .invoke("cute panda")
|
||||
# print(json.dumps(results, indent=2, ensure_ascii=False))
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
@@ -107,9 +111,3 @@ class EnhancedTavilySearchAPIWrapper(OriginalTavilySearchAPIWrapper):
|
||||
}
|
||||
clean_results.append(clean_result)
|
||||
return clean_results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
wrapper = EnhancedTavilySearchAPIWrapper()
|
||||
results = wrapper.raw_results("cute panda", include_images=True)
|
||||
print(json.dumps(results, indent=2, ensure_ascii=False))
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import json
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
|
||||
Reference in New Issue
Block a user