2025-04-17 11:34:42 +08:00
|
|
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
|
|
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
|
|
2025-04-19 09:57:02 +08:00
|
|
|
import json
|
2025-04-07 16:25:55 +08:00
|
|
|
import logging
|
2025-04-11 15:37:55 +08:00
|
|
|
import os
|
2025-04-19 09:57:02 +08:00
|
|
|
|
|
|
|
|
from langchain_community.tools import BraveSearch, DuckDuckGoSearchResults
|
2025-04-11 15:37:55 +08:00
|
|
|
from langchain_community.tools.arxiv import ArxivQueryRun
|
|
|
|
|
from langchain_community.utilities import ArxivAPIWrapper, BraveSearchWrapper
|
2025-04-19 09:57:02 +08:00
|
|
|
|
2025-05-17 22:23:52 -07:00
|
|
|
from src.config import SearchEngine, SELECTED_SEARCH_ENGINE
|
2025-04-19 09:57:02 +08:00
|
|
|
from src.tools.tavily_search.tavily_search_results_with_images import (
|
|
|
|
|
TavilySearchResultsWithImages,
|
|
|
|
|
)
|
|
|
|
|
|
2025-05-12 20:15:47 +08:00
|
|
|
from src.tools.decorators import create_logged_tool
|
2025-04-07 16:25:55 +08:00
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# Create logged versions of the search tools.
# Each supported LangChain search tool class is passed through
# create_logged_tool (from src.tools.decorators); presumably this adds logging
# around tool calls — confirm in that module. The four wrapped classes are
# bound to module-level names used by get_web_search_tool(). The right-hand
# tuple is evaluated left to right, so create_logged_tool is called in the
# order Tavily, DuckDuckGo, Brave, arXiv.
(
    LoggedTavilySearch,
    LoggedDuckDuckGoSearch,
    LoggedBraveSearch,
    LoggedArxivSearch,
) = (
    create_logged_tool(TavilySearchResultsWithImages),
    create_logged_tool(DuckDuckGoSearchResults),
    create_logged_tool(BraveSearch),
    create_logged_tool(ArxivQueryRun),
)
|
2025-05-17 22:23:52 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Get the selected search tool
def get_web_search_tool(max_search_results: int):
    """Build the web search tool for the engine named by ``SELECTED_SEARCH_ENGINE``.

    Every branch returns an instance of one of this module's ``Logged*`` tool
    classes, always registered under the tool name ``"web_search"``.

    Args:
        max_search_results: Number of results to request. Passed as
            ``max_results`` for Tavily and DuckDuckGo, as the ``count``
            search parameter for Brave, and as both ``top_k_results`` and
            ``load_max_docs`` for arXiv.

    Returns:
        A configured ``LoggedTavilySearch``, ``LoggedDuckDuckGoSearch``,
        ``LoggedBraveSearch`` or ``LoggedArxivSearch`` instance.

    Raises:
        ValueError: If ``SELECTED_SEARCH_ENGINE`` equals none of the supported
            ``SearchEngine`` values.
    """
    engine = SELECTED_SEARCH_ENGINE
    tool_name = "web_search"

    if engine == SearchEngine.TAVILY.value:
        # Request raw page content and images (with descriptions) as well.
        return LoggedTavilySearch(
            name=tool_name,
            max_results=max_search_results,
            include_raw_content=True,
            include_images=True,
            include_image_descriptions=True,
        )

    if engine == SearchEngine.DUCKDUCKGO.value:
        return LoggedDuckDuckGoSearch(name=tool_name, max_results=max_search_results)

    if engine == SearchEngine.BRAVE_SEARCH.value:
        # NOTE(review): a missing BRAVE_SEARCH_API_KEY falls back to "" instead
        # of failing here; presumably the error only shows up when the tool is
        # invoked — confirm whether an early check is wanted.
        brave_wrapper = BraveSearchWrapper(
            api_key=os.getenv("BRAVE_SEARCH_API_KEY", ""),
            search_kwargs={"count": max_search_results},
        )
        return LoggedBraveSearch(name=tool_name, search_wrapper=brave_wrapper)

    if engine == SearchEngine.ARXIV.value:
        arxiv_wrapper = ArxivAPIWrapper(
            top_k_results=max_search_results,
            load_max_docs=max_search_results,
            load_all_available_meta=True,
        )
        return LoggedArxivSearch(name=tool_name, api_wrapper=arxiv_wrapper)

    raise ValueError(f"Unsupported search engine: {engine}")
|
|
|
|
|
|
2025-04-19 09:57:02 +08:00
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: run one DuckDuckGo query through the logged tool and
    # pretty-print the results as JSON (non-ASCII characters kept as-is).
    demo_tool = LoggedDuckDuckGoSearch(
        name="web_search",
        max_results=3,
        output_format="list",
    )
    demo_results = demo_tool.invoke("cute panda")
    print(json.dumps(demo_results, indent=2, ensure_ascii=False))
|