mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-29 00:34:47 +08:00
feat(context): decrease token in web_search AIMessage (#827)
This PR addresses token limit issues when web_search is enabled with include_raw_content by implementing a two-pronged approach: changing the default behavior to exclude raw content and adding compression logic for when raw content is included.
This commit is contained in:
@@ -37,13 +37,13 @@ def crawl_tool(
|
|||||||
"error": "PDF files cannot be crawled directly. Please download and view the PDF manually.",
|
"error": "PDF files cannot be crawled directly. Please download and view the PDF manually.",
|
||||||
"crawled_content": None,
|
"crawled_content": None,
|
||||||
"is_pdf": True
|
"is_pdf": True
|
||||||
})
|
}, ensure_ascii=False)
|
||||||
return pdf_message
|
return pdf_message
|
||||||
|
|
||||||
try:
|
try:
|
||||||
crawler = Crawler()
|
crawler = Crawler()
|
||||||
article = crawler.crawl(url)
|
article = crawler.crawl(url)
|
||||||
return json.dumps({"url": url, "crawled_content": article.to_markdown()[:1000]})
|
return json.dumps({"url": url, "crawled_content": article.to_markdown()[:1000]}, ensure_ascii=False)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
error_msg = f"Failed to crawl. Error: {repr(e)}"
|
error_msg = f"Failed to crawl. Error: {repr(e)}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ def get_web_search_tool(max_search_results: int):
|
|||||||
exclude_domains: Optional[List[str]] = search_config.get("exclude_domains", [])
|
exclude_domains: Optional[List[str]] = search_config.get("exclude_domains", [])
|
||||||
include_answer: bool = search_config.get("include_answer", False)
|
include_answer: bool = search_config.get("include_answer", False)
|
||||||
search_depth: str = search_config.get("search_depth", "advanced")
|
search_depth: str = search_config.get("search_depth", "advanced")
|
||||||
include_raw_content: bool = search_config.get("include_raw_content", True)
|
include_raw_content: bool = search_config.get("include_raw_content", False)
|
||||||
include_images: bool = search_config.get("include_images", True)
|
include_images: bool = search_config.get("include_images", True)
|
||||||
include_image_descriptions: bool = include_images and search_config.get(
|
include_image_descriptions: bool = include_images and search_config.get(
|
||||||
"include_image_descriptions", True
|
"include_image_descriptions", True
|
||||||
|
|||||||
@@ -188,77 +188,86 @@ class ContextManager:
|
|||||||
|
|
||||||
def _compress_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
|
def _compress_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
|
||||||
"""
|
"""
|
||||||
Compress compressible messages
|
Compress messages to fit within token limit through two strategies:
|
||||||
|
1. First, compress web_search ToolMessage raw_content by truncating to 1024 chars
|
||||||
|
2. If still over limit, drop oldest messages while preserving prefix messages and system messages
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: List of messages to compress
|
messages: List of messages to compress
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Compressed message list
|
List of messages with compressed content and/or dropped messages
|
||||||
"""
|
"""
|
||||||
|
# Create a deep copy to avoid mutating original messages
|
||||||
|
compressed = copy.deepcopy(messages)
|
||||||
|
|
||||||
|
# Step 1: Compress raw_content in web_search ToolMessages
|
||||||
|
for msg in compressed:
|
||||||
|
# Only compress ToolMessage with name 'web_search'
|
||||||
|
if isinstance(msg, ToolMessage) and getattr(msg, "name", None) == "web_search":
|
||||||
|
try:
|
||||||
|
# Determine content type and check if compression is needed
|
||||||
|
if isinstance(msg.content, str):
|
||||||
|
# Early exit if content is small enough (avoid JSON parsing overhead)
|
||||||
|
# A heuristic: if string is less than 2KB, raw_content likely doesn't need truncation
|
||||||
|
if len(msg.content) < 2048:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
content_data = json.loads(msg.content)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"Failed to parse JSON content in web_search ToolMessage: {e}. Content: {msg.content[:200]}")
|
||||||
|
continue
|
||||||
|
elif isinstance(msg.content, list):
|
||||||
|
content_data = copy.deepcopy(msg.content)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
available_token = self.token_limit
|
# Compress raw_content in the content (item by item processing)
|
||||||
prefix_messages = []
|
# Track if any modifications were made
|
||||||
|
modified = False
|
||||||
|
if isinstance(content_data, list):
|
||||||
|
for item in content_data:
|
||||||
|
if isinstance(item, dict) and "raw_content" in item:
|
||||||
|
raw_content = item.get("raw_content")
|
||||||
|
if raw_content and isinstance(raw_content, str) and len(raw_content) > 1024:
|
||||||
|
item["raw_content"] = raw_content[:1024]
|
||||||
|
modified = True
|
||||||
|
|
||||||
|
# Update message content with modified data only if changes were made
|
||||||
|
if modified:
|
||||||
|
msg.content = json.dumps(content_data, ensure_ascii=False)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error during message compression: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
# 1. Preserve head messages of specified length to retain system prompts and user input
|
# Step 2: If still over limit after raw_content compression, drop oldest messages
|
||||||
for i in range(min(self.preserve_prefix_message_count, len(messages))):
|
# while preserving prefix messages (e.g., system message) and recent messages
|
||||||
cur_token_cnt = self._count_message_tokens(messages[i])
|
if self.is_over_limit(compressed):
|
||||||
if available_token > 0 and available_token >= cur_token_cnt:
|
# Identify messages to preserve at the beginning
|
||||||
prefix_messages.append(messages[i])
|
preserved_count = self.preserve_prefix_message_count
|
||||||
available_token -= cur_token_cnt
|
preserved_messages = compressed[:preserved_count]
|
||||||
elif available_token > 0:
|
remaining_messages = compressed[preserved_count:]
|
||||||
# Truncate content to fit available tokens
|
|
||||||
truncated_message = self._truncate_message_content(
|
# Drop messages from the middle, keeping the most recent ones
|
||||||
messages[i], available_token
|
result_messages = preserved_messages
|
||||||
)
|
for msg in reversed(remaining_messages):
|
||||||
prefix_messages.append(truncated_message)
|
result_messages.insert(len(preserved_messages), msg)
|
||||||
return prefix_messages
|
if not self.is_over_limit(result_messages):
|
||||||
else:
|
break
|
||||||
break
|
|
||||||
|
compressed = result_messages
|
||||||
|
|
||||||
# 2. Compress subsequent messages from the tail, some messages may be discarded
|
# Step 3: Verify that compression was successful and log warning if needed
|
||||||
messages = messages[len(prefix_messages) :]
|
if self.is_over_limit(compressed):
|
||||||
suffix_messages = []
|
current_tokens = self.count_tokens(compressed)
|
||||||
for i in range(len(messages) - 1, -1, -1):
|
logger.warning(
|
||||||
cur_token_cnt = self._count_message_tokens(messages[i])
|
f"Message compression failed to bring tokens below limit: "
|
||||||
|
f"{current_tokens} > {self.token_limit} tokens. "
|
||||||
|
f"Total messages: {len(compressed)}. "
|
||||||
|
f"Consider increasing token_limit or preserve_prefix_message_count."
|
||||||
|
)
|
||||||
|
|
||||||
if cur_token_cnt > 0 and available_token >= cur_token_cnt:
|
return compressed
|
||||||
suffix_messages = [messages[i]] + suffix_messages
|
|
||||||
available_token -= cur_token_cnt
|
|
||||||
elif available_token > 0:
|
|
||||||
# Truncate content to fit available tokens
|
|
||||||
truncated_message = self._truncate_message_content(
|
|
||||||
messages[i], available_token
|
|
||||||
)
|
|
||||||
suffix_messages = [truncated_message] + suffix_messages
|
|
||||||
return prefix_messages + suffix_messages
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
return prefix_messages + suffix_messages
|
|
||||||
|
|
||||||
def _truncate_message_content(
|
|
||||||
self, message: BaseMessage, max_tokens: int
|
|
||||||
) -> BaseMessage:
|
|
||||||
"""
|
|
||||||
Truncate message content while preserving all other attributes by copying the original message
|
|
||||||
and only modifying its content attribute.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: The message to truncate
|
|
||||||
max_tokens: Maximum number of tokens to keep
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
New message instance with truncated content
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Create a deep copy of the original message to preserve all attributes
|
|
||||||
truncated_message = copy.deepcopy(message)
|
|
||||||
|
|
||||||
# Truncate only the content attribute
|
|
||||||
truncated_message.content = message.content[:max_tokens]
|
|
||||||
|
|
||||||
return truncated_message
|
|
||||||
|
|
||||||
def _create_summary_message(self, messages: List[BaseMessage]) -> BaseMessage:
|
def _create_summary_message(self, messages: List[BaseMessage]) -> BaseMessage:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ class TestGetWebSearchTool:
|
|||||||
tool = get_web_search_tool(max_search_results=5)
|
tool = get_web_search_tool(max_search_results=5)
|
||||||
assert tool.name == "web_search"
|
assert tool.name == "web_search"
|
||||||
assert tool.max_results == 5
|
assert tool.max_results == 5
|
||||||
assert tool.include_raw_content is True
|
assert tool.include_raw_content is False
|
||||||
assert tool.include_images is True
|
assert tool.include_images is True
|
||||||
assert tool.include_image_descriptions is True
|
assert tool.include_image_descriptions is True
|
||||||
assert tool.include_answer is False
|
assert tool.include_answer is False
|
||||||
@@ -79,7 +79,7 @@ class TestGetWebSearchTool:
|
|||||||
"SEARCH_ENGINE": {
|
"SEARCH_ENGINE": {
|
||||||
"include_answer": True,
|
"include_answer": True,
|
||||||
"search_depth": "basic",
|
"search_depth": "basic",
|
||||||
"include_raw_content": False,
|
"include_raw_content": True,
|
||||||
"include_images": False,
|
"include_images": False,
|
||||||
"include_image_descriptions": True,
|
"include_image_descriptions": True,
|
||||||
"include_domains": ["example.com"],
|
"include_domains": ["example.com"],
|
||||||
@@ -91,7 +91,7 @@ class TestGetWebSearchTool:
|
|||||||
assert tool.max_results == 5
|
assert tool.max_results == 5
|
||||||
assert tool.include_answer is True
|
assert tool.include_answer is True
|
||||||
assert tool.search_depth == "basic"
|
assert tool.search_depth == "basic"
|
||||||
assert tool.include_raw_content is False
|
assert tool.include_raw_content is True
|
||||||
assert tool.include_images is False
|
assert tool.include_images is False
|
||||||
# include_image_descriptions should be False because include_images is False
|
# include_image_descriptions should be False because include_images is False
|
||||||
assert tool.include_image_descriptions is False
|
assert tool.include_image_descriptions is False
|
||||||
@@ -108,7 +108,7 @@ class TestGetWebSearchTool:
|
|||||||
assert tool.max_results == 10
|
assert tool.max_results == 10
|
||||||
assert tool.include_answer is False
|
assert tool.include_answer is False
|
||||||
assert tool.search_depth == "advanced"
|
assert tool.search_depth == "advanced"
|
||||||
assert tool.include_raw_content is True
|
assert tool.include_raw_content is False
|
||||||
assert tool.include_images is True
|
assert tool.include_images is True
|
||||||
assert tool.include_image_descriptions is True
|
assert tool.include_image_descriptions is True
|
||||||
assert tool.include_domains == []
|
assert tool.include_domains == []
|
||||||
@@ -143,7 +143,7 @@ class TestGetWebSearchTool:
|
|||||||
tool = get_web_search_tool(max_search_results=3)
|
tool = get_web_search_tool(max_search_results=3)
|
||||||
assert tool.include_answer is True
|
assert tool.include_answer is True
|
||||||
assert tool.search_depth == "advanced" # default
|
assert tool.search_depth == "advanced" # default
|
||||||
assert tool.include_raw_content is True # default
|
assert tool.include_raw_content is False # default
|
||||||
assert tool.include_domains == ["trusted.com"]
|
assert tool.include_domains == ["trusted.com"]
|
||||||
assert tool.exclude_domains == [] # default
|
assert tool.exclude_domains == [] # default
|
||||||
|
|
||||||
@@ -157,7 +157,7 @@ class TestGetWebSearchTool:
|
|||||||
assert tool.max_results == 5
|
assert tool.max_results == 5
|
||||||
assert tool.include_answer is False
|
assert tool.include_answer is False
|
||||||
assert tool.search_depth == "advanced"
|
assert tool.search_depth == "advanced"
|
||||||
assert tool.include_raw_content is True
|
assert tool.include_raw_content is False
|
||||||
assert tool.include_images is True
|
assert tool.include_images is True
|
||||||
|
|
||||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
|
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
|
||||||
@@ -184,7 +184,7 @@ class TestGetWebSearchTool:
|
|||||||
assert tool.max_results == 5
|
assert tool.max_results == 5
|
||||||
assert tool.include_answer is False
|
assert tool.include_answer is False
|
||||||
assert tool.search_depth == "advanced"
|
assert tool.search_depth == "advanced"
|
||||||
assert tool.include_raw_content is True
|
assert tool.include_raw_content is False
|
||||||
assert tool.include_images is True
|
assert tool.include_images is True
|
||||||
assert tool.include_domains == []
|
assert tool.include_domains == []
|
||||||
assert tool.exclude_domains == []
|
assert tool.exclude_domains == []
|
||||||
@@ -199,7 +199,7 @@ class TestGetWebSearchTool:
|
|||||||
assert tool.max_results == 5
|
assert tool.max_results == 5
|
||||||
assert tool.include_answer is False
|
assert tool.include_answer is False
|
||||||
assert tool.search_depth == "advanced"
|
assert tool.search_depth == "advanced"
|
||||||
assert tool.include_raw_content is True
|
assert tool.include_raw_content is False
|
||||||
assert tool.include_images is True
|
assert tool.include_images is True
|
||||||
|
|
||||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
|
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
|
||||||
@@ -210,7 +210,7 @@ class TestGetWebSearchTool:
|
|||||||
tool = get_web_search_tool(max_search_results=5)
|
tool = get_web_search_tool(max_search_results=5)
|
||||||
assert tool.include_answer is True
|
assert tool.include_answer is True
|
||||||
assert tool.search_depth == "advanced"
|
assert tool.search_depth == "advanced"
|
||||||
assert tool.include_raw_content is True
|
assert tool.include_raw_content is False
|
||||||
assert tool.include_images is True
|
assert tool.include_images is True
|
||||||
|
|
||||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
|
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
|
||||||
@@ -221,7 +221,7 @@ class TestGetWebSearchTool:
|
|||||||
tool = get_web_search_tool(max_search_results=5)
|
tool = get_web_search_tool(max_search_results=5)
|
||||||
assert tool.search_depth == "basic"
|
assert tool.search_depth == "basic"
|
||||||
assert tool.include_answer is False
|
assert tool.include_answer is False
|
||||||
assert tool.include_raw_content is True
|
assert tool.include_raw_content is False
|
||||||
assert tool.include_images is True
|
assert tool.include_images is True
|
||||||
|
|
||||||
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
|
@patch("src.tools.search.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY.value)
|
||||||
@@ -286,6 +286,6 @@ class TestGetWebSearchTool:
|
|||||||
tool.include_image_descriptions is False
|
tool.include_image_descriptions is False
|
||||||
) # should be False since include_images is False
|
) # should be False since include_images is False
|
||||||
assert tool.search_depth == "advanced" # default
|
assert tool.search_depth == "advanced" # default
|
||||||
assert tool.include_raw_content is True # default
|
assert tool.include_raw_content is False # default
|
||||||
assert tool.include_domains == [] # default
|
assert tool.include_domains == [] # default
|
||||||
assert tool.exclude_domains == [] # default
|
assert tool.exclude_domains == [] # default
|
||||||
|
|||||||
@@ -85,8 +85,8 @@ class TestContextManager:
|
|||||||
# Should return the same messages when not over limit
|
# Should return the same messages when not over limit
|
||||||
assert len(compressed["messages"]) == len(messages)
|
assert len(compressed["messages"]) == len(messages)
|
||||||
|
|
||||||
def test_compress_messages_with_system_message(self):
|
def test_compress_messages_with_tool_message(self):
|
||||||
"""Test compress_messages preserves system message"""
|
"""Test compress_messages preserves system message and compresses raw_content"""
|
||||||
# Create a context manager with limited token capacity
|
# Create a context manager with limited token capacity
|
||||||
limited_cm = ContextManager(token_limit=200)
|
limited_cm = ContextManager(token_limit=200)
|
||||||
|
|
||||||
@@ -94,15 +94,26 @@ class TestContextManager:
|
|||||||
SystemMessage(content="You are a helpful assistant."),
|
SystemMessage(content="You are a helpful assistant."),
|
||||||
HumanMessage(content="Hello"),
|
HumanMessage(content="Hello"),
|
||||||
AIMessage(content="Hi there!"),
|
AIMessage(content="Hi there!"),
|
||||||
HumanMessage(
|
ToolMessage(
|
||||||
content="Can you tell me a very long story that would exceed token limits? "
|
name="web_search",
|
||||||
* 100
|
content='[{"title": "Test Result", "url": "https://example.com", "raw_content": "' + ("This is a test content that should be compressed if it exceeds 1024 characters. " * 2000) + '"}]',
|
||||||
),
|
tool_call_id="test_search",
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
compressed = limited_cm.compress_messages({"messages": messages})
|
compressed = limited_cm.compress_messages({"messages": messages})
|
||||||
# Should preserve system message and some recent messages
|
# Should preserve system message and some recent messages
|
||||||
assert len(compressed["messages"]) == 1
|
assert len(compressed["messages"]) == 4
|
||||||
|
|
||||||
|
# Verify raw_content was compressed to 1024 characters
|
||||||
|
import json
|
||||||
|
for msg in compressed["messages"]:
|
||||||
|
if isinstance(msg, ToolMessage) and getattr(msg, "name", None) == "web_search":
|
||||||
|
content_data = json.loads(msg.content)
|
||||||
|
if isinstance(content_data, list):
|
||||||
|
for item in content_data:
|
||||||
|
if isinstance(item, dict) and "raw_content" in item:
|
||||||
|
assert len(item["raw_content"]) == 1024
|
||||||
|
|
||||||
def test_compress_messages_with_preserve_prefix_message(self):
|
def test_compress_messages_with_preserve_prefix_message(self):
|
||||||
"""Test compress_messages when no system message is present"""
|
"""Test compress_messages when no system message is present"""
|
||||||
@@ -201,9 +212,24 @@ class TestContextManager:
|
|||||||
HumanMessage(
|
HumanMessage(
|
||||||
content="Can you tell me a very long story that would exceed token limits? " * 100
|
content="Can you tell me a very long story that would exceed token limits? " * 100
|
||||||
),
|
),
|
||||||
|
ToolMessage(
|
||||||
|
name="web_search",
|
||||||
|
content='[{"title": "Test Result", "url": "https://example.com", "raw_content": "' + ("This is a test content that should be compressed if it exceeds 1024 characters. " * 2000) + '"}]',
|
||||||
|
tool_call_id="test_search",
|
||||||
|
)
|
||||||
]
|
]
|
||||||
compressed = limited_cm.compress_messages({"messages": messages}, runtime=object())
|
compressed = limited_cm.compress_messages({"messages": messages}, runtime=object())
|
||||||
assert isinstance(compressed, dict)
|
assert isinstance(compressed, dict)
|
||||||
assert "messages" in compressed
|
assert "messages" in compressed
|
||||||
# Should preserve only what fits; with this setup we expect heavy compression
|
# Should preserve only what fits; with this setup we expect heavy compression
|
||||||
assert len(compressed["messages"]) == 1
|
assert len(compressed["messages"]) == 5
|
||||||
|
|
||||||
|
# Verify raw_content was compressed to 1024 characters
|
||||||
|
import json
|
||||||
|
for msg in compressed["messages"]:
|
||||||
|
if isinstance(msg, ToolMessage) and getattr(msg, "name", None) == "web_search":
|
||||||
|
content_data = json.loads(msg.content)
|
||||||
|
if isinstance(content_data, list):
|
||||||
|
for item in content_data:
|
||||||
|
if isinstance(item, dict) and "raw_content" in item:
|
||||||
|
assert len(item["raw_content"]) == 1024
|
||||||
|
|||||||
Reference in New Issue
Block a user