Files
deer-flow/src/tools/search.py

81 lines
2.9 KiB
Python
Raw Normal View History

2025-04-17 11:34:42 +08:00
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import logging
2025-04-11 15:37:55 +08:00
import os
from typing import List, Optional
from langchain_community.tools import BraveSearch, DuckDuckGoSearchResults
2025-04-11 15:37:55 +08:00
from langchain_community.tools.arxiv import ArxivQueryRun
from langchain_community.utilities import ArxivAPIWrapper, BraveSearchWrapper
from src.config import SearchEngine, SELECTED_SEARCH_ENGINE
from src.config import load_yaml_config
from src.tools.tavily_search.tavily_search_results_with_images import (
TavilySearchResultsWithImages,
)
from src.tools.decorators import create_logged_tool
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Create logged versions of the search tools.
# NOTE(review): create_logged_tool comes from src.tools.decorators and is not
# visible here; presumably it returns a variant of the given tool class that
# logs its invocations — confirm against that module.
LoggedTavilySearch = create_logged_tool(TavilySearchResultsWithImages)
2025-04-10 11:45:04 +08:00
# DuckDuckGo results tool passed through create_logged_tool.
# NOTE(review): presumably yields a logging-enabled class — confirm in
# src.tools.decorators.
LoggedDuckDuckGoSearch = create_logged_tool(DuckDuckGoSearchResults)
2025-04-11 15:37:55 +08:00
# Brave and Arxiv tools passed through create_logged_tool.
# NOTE(review): presumably yields logging-enabled classes — confirm in
# src.tools.decorators.
LoggedBraveSearch = create_logged_tool(BraveSearch)
LoggedArxivSearch = create_logged_tool(ArxivQueryRun)
def get_search_config():
    """Return the ``SEARCH_ENGINE`` section of ``conf.yaml``.

    The section is looked up in the mapping produced by
    ``load_yaml_config("conf.yaml")``.

    Returns:
        The value stored under ``SEARCH_ENGINE``. An empty dict is returned
        when the loaded config is empty/``None``, when the key is missing,
        or when the key is present but null (e.g. ``SEARCH_ENGINE:`` with
        every child commented out), so callers can always use ``.get``.
    """
    # `dict.get(key, {})` only covers a *missing* key; a key that is present
    # with a null value would still yield None, hence the explicit `or {}`.
    config = load_yaml_config("conf.yaml") or {}
    return config.get("SEARCH_ENGINE") or {}
# Get the selected search tool
def get_web_search_tool(max_search_results: int):
    """Build the ``web_search`` tool for the configured search engine.

    The engine is selected by comparing the module-level
    ``SELECTED_SEARCH_ENGINE`` against the ``SearchEngine`` enum values.
    Every returned tool is named ``"web_search"``.

    Args:
        max_search_results: Result limit forwarded to the engine-specific
            setting (``max_results`` for Tavily, ``num_results`` for
            DuckDuckGo, ``count`` for Brave, ``top_k_results`` and
            ``load_max_docs`` for Arxiv).

    Returns:
        An instance of the logged tool class matching the selected engine.

    Raises:
        ValueError: If ``SELECTED_SEARCH_ENGINE`` matches no supported engine.
    """
    if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
        # Include/exclude domain filters apply to Tavily only, so the YAML
        # config is read on this branch alone; the other engines no longer
        # depend on conf.yaml being loadable.
        # `or {}` / `or []` normalize a null section or null keys so the
        # lookups below never see None.
        search_config = get_search_config() or {}
        include_domains: Optional[List[str]] = (
            search_config.get("include_domains") or []
        )
        exclude_domains: Optional[List[str]] = (
            search_config.get("exclude_domains") or []
        )
        # Lazy %-style arguments: the message is only formatted when INFO
        # logging is enabled.
        logger.info(
            "Tavily search configuration loaded: include_domains=%s, exclude_domains=%s",
            include_domains,
            exclude_domains,
        )
        return LoggedTavilySearch(
            name="web_search",
            max_results=max_search_results,
            include_raw_content=True,
            include_images=True,
            include_image_descriptions=True,
            include_domains=include_domains,
            exclude_domains=exclude_domains,
        )

    if SELECTED_SEARCH_ENGINE == SearchEngine.DUCKDUCKGO.value:
        return LoggedDuckDuckGoSearch(
            name="web_search",
            num_results=max_search_results,
        )

    if SELECTED_SEARCH_ENGINE == SearchEngine.BRAVE_SEARCH.value:
        # The API key is read from the environment; an unset variable falls
        # back to an empty string, exactly as before.
        return LoggedBraveSearch(
            name="web_search",
            search_wrapper=BraveSearchWrapper(
                api_key=os.getenv("BRAVE_SEARCH_API_KEY", ""),
                search_kwargs={"count": max_search_results},
            ),
        )

    if SELECTED_SEARCH_ENGINE == SearchEngine.ARXIV.value:
        return LoggedArxivSearch(
            name="web_search",
            api_wrapper=ArxivAPIWrapper(
                top_k_results=max_search_results,
                load_max_docs=max_search_results,
                load_all_available_meta=True,
            ),
        )

    raise ValueError(f"Unsupported search engine: {SELECTED_SEARCH_ENGINE}")