mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-12 10:04:45 +08:00
feat: support images in the search results
This commit is contained in:
@@ -1,25 +1,31 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from langchain_community.tools.tavily_search import TavilySearchResults
|
||||
from langchain_community.tools import DuckDuckGoSearchResults
|
||||
from langchain_community.tools import BraveSearch
|
||||
|
||||
from langchain_community.tools import BraveSearch, DuckDuckGoSearchResults
|
||||
from langchain_community.tools.arxiv import ArxivQueryRun
|
||||
from langchain_community.utilities import ArxivAPIWrapper, BraveSearchWrapper
|
||||
|
||||
from src.config import SEARCH_MAX_RESULTS
|
||||
from src.tools.tavily_search.tavily_search_results_with_images import (
|
||||
TavilySearchResultsWithImages,
|
||||
)
|
||||
|
||||
from .decorators import create_logged_tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
LoggedTavilySearch = create_logged_tool(TavilySearchResults)
|
||||
LoggedTavilySearch = create_logged_tool(TavilySearchResultsWithImages)
|
||||
tavily_search_tool = LoggedTavilySearch(
|
||||
name="web_search",
|
||||
max_results=SEARCH_MAX_RESULTS,
|
||||
include_raw_content=True,
|
||||
include_images=True,
|
||||
include_image_descriptions=True,
|
||||
)
|
||||
|
||||
LoggedDuckDuckGoSearch = create_logged_tool(DuckDuckGoSearchResults)
|
||||
@@ -45,3 +51,7 @@ arxiv_search_tool = LoggedArxivSearch(
|
||||
load_all_available_meta=True,
|
||||
),
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
results = tavily_search_tool.invoke("cute panda")
|
||||
print(json.dumps(results, indent=2, ensure_ascii=False))
|
||||
|
||||
4
src/tools/tavily_search/__init__.py
Normal file
4
src/tools/tavily_search/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .tavily_search_api_wrapper import EnhancedTavilySearchAPIWrapper
|
||||
from .tavily_search_results_with_images import TavilySearchResultsWithImages
|
||||
|
||||
__all__ = ["EnhancedTavilySearchAPIWrapper", "TavilySearchResultsWithImages"]
|
||||
115
src/tools/tavily_search/tavily_search_api_wrapper.py
Normal file
115
src/tools/tavily_search/tavily_search_api_wrapper.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from langchain_community.utilities.tavily_search import TAVILY_API_URL
|
||||
from langchain_community.utilities.tavily_search import (
|
||||
TavilySearchAPIWrapper as OriginalTavilySearchAPIWrapper,
|
||||
)
|
||||
|
||||
|
||||
class EnhancedTavilySearchAPIWrapper(OriginalTavilySearchAPIWrapper):
|
||||
def raw_results(
|
||||
self,
|
||||
query: str,
|
||||
max_results: Optional[int] = 5,
|
||||
search_depth: Optional[str] = "advanced",
|
||||
include_domains: Optional[List[str]] = [],
|
||||
exclude_domains: Optional[List[str]] = [],
|
||||
include_answer: Optional[bool] = False,
|
||||
include_raw_content: Optional[bool] = False,
|
||||
include_images: Optional[bool] = False,
|
||||
include_image_descriptions: Optional[bool] = False,
|
||||
) -> Dict:
|
||||
params = {
|
||||
"api_key": self.tavily_api_key.get_secret_value(),
|
||||
"query": query,
|
||||
"max_results": max_results,
|
||||
"search_depth": search_depth,
|
||||
"include_domains": include_domains,
|
||||
"exclude_domains": exclude_domains,
|
||||
"include_answer": include_answer,
|
||||
"include_raw_content": include_raw_content,
|
||||
"include_images": include_images,
|
||||
"include_image_descriptions": include_image_descriptions,
|
||||
}
|
||||
response = requests.post(
|
||||
# type: ignore
|
||||
f"{TAVILY_API_URL}/search",
|
||||
json=params,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def raw_results_async(
|
||||
self,
|
||||
query: str,
|
||||
max_results: Optional[int] = 5,
|
||||
search_depth: Optional[str] = "advanced",
|
||||
include_domains: Optional[List[str]] = [],
|
||||
exclude_domains: Optional[List[str]] = [],
|
||||
include_answer: Optional[bool] = False,
|
||||
include_raw_content: Optional[bool] = False,
|
||||
include_images: Optional[bool] = False,
|
||||
include_image_descriptions: Optional[bool] = False,
|
||||
) -> Dict:
|
||||
"""Get results from the Tavily Search API asynchronously."""
|
||||
|
||||
# Function to perform the API call
|
||||
async def fetch() -> str:
|
||||
params = {
|
||||
"api_key": self.tavily_api_key.get_secret_value(),
|
||||
"query": query,
|
||||
"max_results": max_results,
|
||||
"search_depth": search_depth,
|
||||
"include_domains": include_domains,
|
||||
"exclude_domains": exclude_domains,
|
||||
"include_answer": include_answer,
|
||||
"include_raw_content": include_raw_content,
|
||||
"include_images": include_images,
|
||||
"include_image_descriptions": include_image_descriptions,
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(f"{TAVILY_API_URL}/search", json=params) as res:
|
||||
if res.status == 200:
|
||||
data = await res.text()
|
||||
return data
|
||||
else:
|
||||
raise Exception(f"Error {res.status}: {res.reason}")
|
||||
|
||||
results_json_str = await fetch()
|
||||
return json.loads(results_json_str)
|
||||
|
||||
def clean_results_with_images(
|
||||
self, raw_results: Dict[str, List[Dict]]
|
||||
) -> List[Dict]:
|
||||
results = raw_results["results"]
|
||||
"""Clean results from Tavily Search API."""
|
||||
clean_results = []
|
||||
for result in results:
|
||||
clean_result = {
|
||||
"type": "page",
|
||||
"title": result["title"],
|
||||
"url": result["url"],
|
||||
"content": result["content"],
|
||||
"score": result["score"],
|
||||
}
|
||||
if raw_content := result.get("raw_content"):
|
||||
clean_result["raw_content"] = raw_content
|
||||
clean_results.append(clean_result)
|
||||
images = raw_results["images"]
|
||||
for image in images:
|
||||
clean_result = {
|
||||
"type": "image",
|
||||
"image_url": image["url"],
|
||||
"image_description": image["description"],
|
||||
}
|
||||
clean_results.append(clean_result)
|
||||
return clean_results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
wrapper = EnhancedTavilySearchAPIWrapper()
|
||||
results = wrapper.raw_results("cute panda", include_images=True)
|
||||
print(json.dumps(results, indent=2, ensure_ascii=False))
|
||||
148
src/tools/tavily_search/tavily_search_results_with_images.py
Normal file
148
src/tools/tavily_search/tavily_search_results_with_images.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import json
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_community.tools.tavily_search.tool import TavilySearchResults
|
||||
from pydantic import Field
|
||||
|
||||
from src.tools.tavily_search.tavily_search_api_wrapper import (
|
||||
EnhancedTavilySearchAPIWrapper,
|
||||
)
|
||||
|
||||
|
||||
class TavilySearchResultsWithImages(TavilySearchResults): # type: ignore[override, override]
|
||||
"""Tool that queries the Tavily Search API and gets back json.
|
||||
|
||||
Setup:
|
||||
Install ``langchain-openai`` and ``tavily-python``, and set environment variable ``TAVILY_API_KEY``.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U langchain-community tavily-python
|
||||
export TAVILY_API_KEY="your-api-key"
|
||||
|
||||
Instantiate:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.tools import TavilySearchResults
|
||||
|
||||
tool = TavilySearchResults(
|
||||
max_results=5,
|
||||
include_answer=True,
|
||||
include_raw_content=True,
|
||||
include_images=True,
|
||||
include_image_descriptions=True,
|
||||
# search_depth="advanced",
|
||||
# include_domains = []
|
||||
# exclude_domains = []
|
||||
)
|
||||
|
||||
Invoke directly with args:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
tool.invoke({'query': 'who won the last french open'})
|
||||
|
||||
.. code-block:: json
|
||||
|
||||
{
|
||||
"url": "https://www.nytimes.com...",
|
||||
"content": "Novak Djokovic won the last French Open by beating Casper Ruud ..."
|
||||
}
|
||||
|
||||
Invoke with tool call:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
tool.invoke({"args": {'query': 'who won the last french open'}, "type": "tool_call", "id": "foo", "name": "tavily"})
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
ToolMessage(
|
||||
content='{ "url": "https://www.nytimes.com...", "content": "Novak Djokovic won the last French Open by beating Casper Ruud ..." }',
|
||||
artifact={
|
||||
'query': 'who won the last french open',
|
||||
'follow_up_questions': None,
|
||||
'answer': 'Novak ...',
|
||||
'images': [
|
||||
'https://www.amny.com/wp-content/uploads/2023/06/AP23162622181176-1200x800.jpg',
|
||||
...
|
||||
],
|
||||
'results': [
|
||||
{
|
||||
'title': 'Djokovic ...',
|
||||
'url': 'https://www.nytimes.com...',
|
||||
'content': "Novak...",
|
||||
'score': 0.99505633,
|
||||
'raw_content': 'Tennis\nNovak ...'
|
||||
},
|
||||
...
|
||||
],
|
||||
'response_time': 2.92
|
||||
},
|
||||
tool_call_id='1',
|
||||
name='tavily_search_results_json',
|
||||
)
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
include_image_descriptions: bool = False
|
||||
"""Include a image descriptions in the response.
|
||||
|
||||
Default is False.
|
||||
"""
|
||||
|
||||
api_wrapper: EnhancedTavilySearchAPIWrapper = Field(default_factory=EnhancedTavilySearchAPIWrapper) # type: ignore[arg-type]
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> Tuple[Union[List[Dict[str, str]], str], Dict]:
|
||||
"""Use the tool."""
|
||||
# TODO: remove try/except, should be handled by BaseTool
|
||||
try:
|
||||
raw_results = self.api_wrapper.raw_results(
|
||||
query,
|
||||
self.max_results,
|
||||
self.search_depth,
|
||||
self.include_domains,
|
||||
self.exclude_domains,
|
||||
self.include_answer,
|
||||
self.include_raw_content,
|
||||
self.include_images,
|
||||
self.include_image_descriptions,
|
||||
)
|
||||
except Exception as e:
|
||||
return repr(e), {}
|
||||
cleaned_results = self.api_wrapper.clean_results_with_images(raw_results)
|
||||
print("sync", json.dumps(cleaned_results, indent=2, ensure_ascii=False))
|
||||
return cleaned_results, raw_results
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> Tuple[Union[List[Dict[str, str]], str], Dict]:
|
||||
"""Use the tool asynchronously."""
|
||||
try:
|
||||
raw_results = await self.api_wrapper.raw_results_async(
|
||||
query,
|
||||
self.max_results,
|
||||
self.search_depth,
|
||||
self.include_domains,
|
||||
self.exclude_domains,
|
||||
self.include_answer,
|
||||
self.include_raw_content,
|
||||
self.include_images,
|
||||
self.include_image_descriptions,
|
||||
)
|
||||
except Exception as e:
|
||||
return repr(e), {}
|
||||
cleaned_results = self.api_wrapper.clean_results_with_images(raw_results)
|
||||
print("async", json.dumps(cleaned_results, indent=2, ensure_ascii=False))
|
||||
return cleaned_results, raw_results
|
||||
Reference in New Issue
Block a user