# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT

import logging
import os
from typing import List, Optional

from langchain_community.tools import (
    BraveSearch,
    DuckDuckGoSearchResults,
    SearxSearchRun,
    WikipediaQueryRun,
)
from langchain_community.tools.arxiv import ArxivQueryRun
from langchain_community.utilities import (
    ArxivAPIWrapper,
    BraveSearchWrapper,
    SearxSearchWrapper,
    WikipediaAPIWrapper,
)

from src.config import SELECTED_SEARCH_ENGINE, SearchEngine, load_yaml_config
from src.tools.decorators import create_logged_tool
from src.tools.tavily_search.tavily_search_results_with_images import (
    TavilySearchWithImages,
)

logger = logging.getLogger(__name__)

# Create logged versions of the search tools
LoggedTavilySearch = create_logged_tool(TavilySearchWithImages)
LoggedDuckDuckGoSearch = create_logged_tool(DuckDuckGoSearchResults)
LoggedBraveSearch = create_logged_tool(BraveSearch)
LoggedArxivSearch = create_logged_tool(ArxivQueryRun)
LoggedSearxSearch = create_logged_tool(SearxSearchRun)
LoggedWikipediaSearch = create_logged_tool(WikipediaQueryRun)


def get_search_config() -> dict:
    """Return the ``SEARCH_ENGINE`` section of ``conf.yaml``.

    Always returns a dict so callers can safely use ``.get`` on the result.
    A missing section, a section left empty in YAML (``SEARCH_ENGINE:`` with
    no children loads as ``None``), or a non-mapping value all yield ``{}``,
    so every option falls back to its default.
    """
    # `or {}` guards against the loader yielding an empty/None document.
    config = load_yaml_config("conf.yaml") or {}
    search_config = config.get("SEARCH_ENGINE")
    if search_config is None:
        return {}
    if not isinstance(search_config, dict):
        logger.warning(
            "Ignoring SEARCH_ENGINE config: expected a mapping, got %s",
            type(search_config).__name__,
        )
        return {}
    return search_config


# Get the selected search tool
def get_web_search_tool(max_search_results: int):
    """Build the web-search tool for the engine named by ``SELECTED_SEARCH_ENGINE``.

    Every tool is created with ``name="web_search"`` so the rest of the system
    can call it the same way regardless of which backend is configured.

    Args:
        max_search_results: Number of results to request; forwarded to each
            engine's own result-count option (``max_results``, ``num_results``,
            ``count``, ``top_k_results``/``load_max_docs`` or ``k``).

    Returns:
        An instance of one of the ``Logged*`` tool classes defined in this
        module.

    Raises:
        ValueError: If ``SELECTED_SEARCH_ENGINE`` matches none of the
            supported ``SearchEngine`` values.
    """
    search_config = get_search_config()

    if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
        # Only get and apply include/exclude domains for Tavily
        include_domains: Optional[List[str]] = search_config.get("include_domains", [])
        exclude_domains: Optional[List[str]] = search_config.get("exclude_domains", [])
        include_raw_content = search_config.get("include_raw_content", True)
        include_images: Optional[bool] = search_config.get("include_images", True)
        # Image descriptions are only requested when images themselves are
        # enabled; a falsy include_images short-circuits to that value.
        include_image_descriptions: Optional[bool] = (
            include_images and search_config.get("include_image_descriptions", True)
        )

        # Lazy %-style arguments: formatting is skipped if INFO is disabled.
        logger.info(
            "Tavily search configuration loaded: include_domains=%s, exclude_domains=%s",
            include_domains,
            exclude_domains,
        )

        return LoggedTavilySearch(
            name="web_search",
            max_results=max_search_results,
            include_raw_content=include_raw_content,
            include_images=include_images,
            include_image_descriptions=include_image_descriptions,
            include_domains=include_domains,
            exclude_domains=exclude_domains,
        )
    elif SELECTED_SEARCH_ENGINE == SearchEngine.DUCKDUCKGO.value:
        return LoggedDuckDuckGoSearch(
            name="web_search",
            num_results=max_search_results,
        )
    elif SELECTED_SEARCH_ENGINE == SearchEngine.BRAVE_SEARCH.value:
        return LoggedBraveSearch(
            name="web_search",
            search_wrapper=BraveSearchWrapper(
                # API key comes from the environment; an empty string is
                # passed when BRAVE_SEARCH_API_KEY is unset.
                api_key=os.getenv("BRAVE_SEARCH_API_KEY", ""),
                search_kwargs={"count": max_search_results},
            ),
        )
    elif SELECTED_SEARCH_ENGINE == SearchEngine.ARXIV.value:
        return LoggedArxivSearch(
            name="web_search",
            api_wrapper=ArxivAPIWrapper(
                top_k_results=max_search_results,
                load_max_docs=max_search_results,
                load_all_available_meta=True,
            ),
        )
    elif SELECTED_SEARCH_ENGINE == SearchEngine.SEARX.value:
        # NOTE(review): no host is passed here; presumably SearxSearchWrapper
        # resolves it from the environment (e.g. SEARX_HOST) — confirm.
        return LoggedSearxSearch(
            name="web_search",
            wrapper=SearxSearchWrapper(
                k=max_search_results,
            ),
        )
    elif SELECTED_SEARCH_ENGINE == SearchEngine.WIKIPEDIA.value:
        # Wikipedia language and per-document character cap are configurable
        # via conf.yaml, defaulting to "en" and 4000 characters.
        wiki_lang = search_config.get("wikipedia_lang", "en")
        wiki_doc_content_chars_max = search_config.get(
            "wikipedia_doc_content_chars_max", 4000
        )
        return LoggedWikipediaSearch(
            name="web_search",
            api_wrapper=WikipediaAPIWrapper(
                lang=wiki_lang,
                top_k_results=max_search_results,
                load_all_available_meta=True,
                doc_content_chars_max=wiki_doc_content_chars_max,
            ),
        )
    else:
        raise ValueError(f"Unsupported search engine: {SELECTED_SEARCH_ENGINE}")