mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-22 21:54:45 +08:00
refine the research prompt (#460)
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import logging
|
||||||
import json
|
import json
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -15,6 +16,8 @@ from src.tools.tavily_search.tavily_search_api_wrapper import (
|
|||||||
EnhancedTavilySearchAPIWrapper,
|
EnhancedTavilySearchAPIWrapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TavilySearchResultsWithImages(TavilySearchResults): # type: ignore[override, override]
|
class TavilySearchResultsWithImages(TavilySearchResults): # type: ignore[override, override]
|
||||||
"""Tool that queries the Tavily Search API and gets back json.
|
"""Tool that queries the Tavily Search API and gets back json.
|
||||||
@@ -123,7 +126,9 @@ class TavilySearchResultsWithImages(TavilySearchResults): # type: ignore[overri
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return repr(e), {}
|
return repr(e), {}
|
||||||
cleaned_results = self.api_wrapper.clean_results_with_images(raw_results)
|
cleaned_results = self.api_wrapper.clean_results_with_images(raw_results)
|
||||||
print("sync", json.dumps(cleaned_results, indent=2, ensure_ascii=False))
|
logger.debug(
|
||||||
|
"sync: %s", json.dumps(cleaned_results, indent=2, ensure_ascii=False)
|
||||||
|
)
|
||||||
return cleaned_results, raw_results
|
return cleaned_results, raw_results
|
||||||
|
|
||||||
async def _arun(
|
async def _arun(
|
||||||
@@ -147,5 +152,7 @@ class TavilySearchResultsWithImages(TavilySearchResults): # type: ignore[overri
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return repr(e), {}
|
return repr(e), {}
|
||||||
cleaned_results = self.api_wrapper.clean_results_with_images(raw_results)
|
cleaned_results = self.api_wrapper.clean_results_with_images(raw_results)
|
||||||
print("async", json.dumps(cleaned_results, indent=2, ensure_ascii=False))
|
logger.debug(
|
||||||
|
"async: %s", json.dumps(cleaned_results, indent=2, ensure_ascii=False)
|
||||||
|
)
|
||||||
return cleaned_results, raw_results
|
return cleaned_results, raw_results
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
import json
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import Mock, patch, AsyncMock
|
from unittest.mock import Mock, AsyncMock
|
||||||
from src.tools.tavily_search.tavily_search_results_with_images import (
|
from src.tools.tavily_search.tavily_search_results_with_images import (
|
||||||
TavilySearchResultsWithImages,
|
TavilySearchResultsWithImages,
|
||||||
)
|
)
|
||||||
@@ -77,10 +76,8 @@ class TestTavilySearchResultsWithImages:
|
|||||||
assert tool.max_results == 10
|
assert tool.max_results == 10
|
||||||
assert tool.include_image_descriptions is True
|
assert tool.include_image_descriptions is True
|
||||||
|
|
||||||
@patch("builtins.print")
|
|
||||||
def test_run_success(
|
def test_run_success(
|
||||||
self,
|
self,
|
||||||
mock_print,
|
|
||||||
search_tool,
|
search_tool,
|
||||||
mock_api_wrapper,
|
mock_api_wrapper,
|
||||||
sample_raw_results,
|
sample_raw_results,
|
||||||
@@ -110,10 +107,8 @@ class TestTavilySearchResultsWithImages:
|
|||||||
mock_api_wrapper.clean_results_with_images.assert_called_once_with(
|
mock_api_wrapper.clean_results_with_images.assert_called_once_with(
|
||||||
sample_raw_results
|
sample_raw_results
|
||||||
)
|
)
|
||||||
mock_print.assert_called_once()
|
|
||||||
|
|
||||||
@patch("builtins.print")
|
def test_run_exception(self, search_tool, mock_api_wrapper):
|
||||||
def test_run_exception(self, mock_print, search_tool, mock_api_wrapper):
|
|
||||||
"""Test synchronous run with exception."""
|
"""Test synchronous run with exception."""
|
||||||
mock_api_wrapper.raw_results.side_effect = Exception("API Error")
|
mock_api_wrapper.raw_results.side_effect = Exception("API Error")
|
||||||
|
|
||||||
@@ -124,10 +119,8 @@ class TestTavilySearchResultsWithImages:
|
|||||||
mock_api_wrapper.clean_results_with_images.assert_not_called()
|
mock_api_wrapper.clean_results_with_images.assert_not_called()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("builtins.print")
|
|
||||||
async def test_arun_success(
|
async def test_arun_success(
|
||||||
self,
|
self,
|
||||||
mock_print,
|
|
||||||
search_tool,
|
search_tool,
|
||||||
mock_api_wrapper,
|
mock_api_wrapper,
|
||||||
sample_raw_results,
|
sample_raw_results,
|
||||||
@@ -157,11 +150,9 @@ class TestTavilySearchResultsWithImages:
|
|||||||
mock_api_wrapper.clean_results_with_images.assert_called_once_with(
|
mock_api_wrapper.clean_results_with_images.assert_called_once_with(
|
||||||
sample_raw_results
|
sample_raw_results
|
||||||
)
|
)
|
||||||
mock_print.assert_called_once()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("builtins.print")
|
async def test_arun_exception(self, search_tool, mock_api_wrapper):
|
||||||
async def test_arun_exception(self, mock_print, search_tool, mock_api_wrapper):
|
|
||||||
"""Test asynchronous run with exception."""
|
"""Test asynchronous run with exception."""
|
||||||
mock_api_wrapper.raw_results_async = AsyncMock(
|
mock_api_wrapper.raw_results_async = AsyncMock(
|
||||||
side_effect=Exception("Async API Error")
|
side_effect=Exception("Async API Error")
|
||||||
@@ -173,10 +164,8 @@ class TestTavilySearchResultsWithImages:
|
|||||||
assert raw == {}
|
assert raw == {}
|
||||||
mock_api_wrapper.clean_results_with_images.assert_not_called()
|
mock_api_wrapper.clean_results_with_images.assert_not_called()
|
||||||
|
|
||||||
@patch("builtins.print")
|
|
||||||
def test_run_with_run_manager(
|
def test_run_with_run_manager(
|
||||||
self,
|
self,
|
||||||
mock_print,
|
|
||||||
search_tool,
|
search_tool,
|
||||||
mock_api_wrapper,
|
mock_api_wrapper,
|
||||||
sample_raw_results,
|
sample_raw_results,
|
||||||
@@ -193,10 +182,8 @@ class TestTavilySearchResultsWithImages:
|
|||||||
assert raw == sample_raw_results
|
assert raw == sample_raw_results
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@patch("builtins.print")
|
|
||||||
async def test_arun_with_run_manager(
|
async def test_arun_with_run_manager(
|
||||||
self,
|
self,
|
||||||
mock_print,
|
|
||||||
search_tool,
|
search_tool,
|
||||||
mock_api_wrapper,
|
mock_api_wrapper,
|
||||||
sample_raw_results,
|
sample_raw_results,
|
||||||
@@ -213,52 +200,3 @@ class TestTavilySearchResultsWithImages:
|
|||||||
|
|
||||||
assert result == sample_cleaned_results
|
assert result == sample_cleaned_results
|
||||||
assert raw == sample_raw_results
|
assert raw == sample_raw_results
|
||||||
|
|
||||||
@patch("builtins.print")
|
|
||||||
def test_print_output_format(
|
|
||||||
self,
|
|
||||||
mock_print,
|
|
||||||
search_tool,
|
|
||||||
mock_api_wrapper,
|
|
||||||
sample_raw_results,
|
|
||||||
sample_cleaned_results,
|
|
||||||
):
|
|
||||||
"""Test that print outputs correctly formatted JSON."""
|
|
||||||
mock_api_wrapper.raw_results.return_value = sample_raw_results
|
|
||||||
mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results
|
|
||||||
|
|
||||||
search_tool._run("test query")
|
|
||||||
|
|
||||||
# Verify print was called with expected format
|
|
||||||
call_args = mock_print.call_args[0]
|
|
||||||
assert call_args[0] == "sync"
|
|
||||||
assert isinstance(call_args[1], str) # Should be JSON string
|
|
||||||
|
|
||||||
# Verify it's valid JSON
|
|
||||||
json_data = json.loads(call_args[1])
|
|
||||||
assert json_data == sample_cleaned_results
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("builtins.print")
|
|
||||||
async def test_async_print_output_format(
|
|
||||||
self,
|
|
||||||
mock_print,
|
|
||||||
search_tool,
|
|
||||||
mock_api_wrapper,
|
|
||||||
sample_raw_results,
|
|
||||||
sample_cleaned_results,
|
|
||||||
):
|
|
||||||
"""Test that async print outputs correctly formatted JSON."""
|
|
||||||
mock_api_wrapper.raw_results_async = AsyncMock(return_value=sample_raw_results)
|
|
||||||
mock_api_wrapper.clean_results_with_images.return_value = sample_cleaned_results
|
|
||||||
|
|
||||||
await search_tool._arun("test query")
|
|
||||||
|
|
||||||
# Verify print was called with expected format
|
|
||||||
call_args = mock_print.call_args[0]
|
|
||||||
assert call_args[0] == "async"
|
|
||||||
assert isinstance(call_args[1], str) # Should be JSON string
|
|
||||||
|
|
||||||
# Verify it's valid JSON
|
|
||||||
json_data = json.loads(call_args[1])
|
|
||||||
assert json_data == sample_cleaned_results
|
|
||||||
|
|||||||
Reference in New Issue
Block a user