mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-26 15:24:48 +08:00
feat: add context compress (#590)
* feat:Add context compress * feat: Add unit test * feat: add unit test for context manager * feat: add postprocessor param && code format * feat: add configuration guide * fix: fix the configuration_guide * fix: fix the unit test * fix: fix the default value * feat: add test and log for context_manager
This commit is contained in:
262
tests/unit/tools/test_search_postprocessor.py
Normal file
262
tests/unit/tools/test_search_postprocessor.py
Normal file
@@ -0,0 +1,262 @@
|
||||
import pytest
|
||||
from src.tools.search_postprocessor import SearchResultPostProcessor
|
||||
|
||||
|
||||
class TestSearchResultPostProcessor:
|
||||
"""Test cases for SearchResultPostProcessor"""
|
||||
|
||||
@pytest.fixture
|
||||
def post_processor(self):
|
||||
"""Create a SearchResultPostProcessor instance for testing"""
|
||||
return SearchResultPostProcessor(
|
||||
min_score_threshold=0.5, max_content_length_per_page=100
|
||||
)
|
||||
|
||||
def test_process_results_empty_input(self, post_processor):
|
||||
"""Test processing empty results"""
|
||||
results = []
|
||||
processed = post_processor.process_results(results)
|
||||
assert processed == []
|
||||
|
||||
def test_process_results_with_valid_page_results(self, post_processor):
|
||||
"""Test processing valid page results"""
|
||||
results = [
|
||||
{
|
||||
"type": "page",
|
||||
"title": "Test Page",
|
||||
"url": "https://example.com",
|
||||
"content": "Test content",
|
||||
"score": 0.8,
|
||||
}
|
||||
]
|
||||
processed = post_processor.process_results(results)
|
||||
assert len(processed) == 1
|
||||
assert processed[0]["title"] == "Test Page"
|
||||
assert processed[0]["url"] == "https://example.com"
|
||||
assert processed[0]["content"] == "Test content"
|
||||
assert processed[0]["score"] == 0.8
|
||||
|
||||
def test_process_results_filter_low_score(self, post_processor):
|
||||
"""Test filtering out low score results"""
|
||||
results = [
|
||||
{
|
||||
"type": "page",
|
||||
"title": "Low Score Page",
|
||||
"url": "https://example.com/low",
|
||||
"content": "Low score content",
|
||||
"score": 0.3, # Below threshold of 0.5
|
||||
},
|
||||
{
|
||||
"type": "page",
|
||||
"title": "High Score Page",
|
||||
"url": "https://example.com/high",
|
||||
"content": "High score content",
|
||||
"score": 0.9,
|
||||
},
|
||||
]
|
||||
processed = post_processor.process_results(results)
|
||||
assert len(processed) == 1
|
||||
assert processed[0]["title"] == "High Score Page"
|
||||
|
||||
def test_process_results_remove_duplicates(self, post_processor):
|
||||
"""Test removing duplicate URLs"""
|
||||
results = [
|
||||
{
|
||||
"type": "page",
|
||||
"title": "Page 1",
|
||||
"url": "https://example.com",
|
||||
"content": "Content 1",
|
||||
"score": 0.8,
|
||||
},
|
||||
{
|
||||
"type": "page",
|
||||
"title": "Page 2",
|
||||
"url": "https://example.com", # Duplicate URL
|
||||
"content": "Content 2",
|
||||
"score": 0.7,
|
||||
},
|
||||
]
|
||||
processed = post_processor.process_results(results)
|
||||
assert len(processed) == 1
|
||||
assert processed[0]["title"] == "Page 1" # First one should be kept
|
||||
|
||||
def test_process_results_sort_by_score(self, post_processor):
|
||||
"""Test sorting results by score in descending order"""
|
||||
results = [
|
||||
{
|
||||
"type": "page",
|
||||
"title": "Low Score",
|
||||
"url": "https://example.com/low",
|
||||
"content": "Low score content",
|
||||
"score": 0.3,
|
||||
},
|
||||
{
|
||||
"type": "page",
|
||||
"title": "High Score",
|
||||
"url": "https://example.com/high",
|
||||
"content": "High score content",
|
||||
"score": 0.9,
|
||||
},
|
||||
{
|
||||
"type": "page",
|
||||
"title": "Medium Score",
|
||||
"url": "https://example.com/medium",
|
||||
"content": "Medium score content",
|
||||
"score": 0.6,
|
||||
},
|
||||
]
|
||||
processed = post_processor.process_results(results)
|
||||
assert len(processed) == 2 # Low score filtered out
|
||||
# Should be sorted by score descending
|
||||
assert processed[0]["title"] == "High Score"
|
||||
assert processed[1]["title"] == "Medium Score"
|
||||
|
||||
def test_process_results_truncate_long_content(self, post_processor):
|
||||
"""Test truncating long content"""
|
||||
long_content = "A" * 150 # Longer than max_content_length of 100
|
||||
results = [
|
||||
{
|
||||
"type": "page",
|
||||
"title": "Long Content Page",
|
||||
"url": "https://example.com",
|
||||
"content": long_content,
|
||||
"score": 0.8,
|
||||
}
|
||||
]
|
||||
processed = post_processor.process_results(results)
|
||||
assert len(processed) == 1
|
||||
assert len(processed[0]["content"]) == 103 # 100 + "..."
|
||||
assert processed[0]["content"].endswith("...")
|
||||
|
||||
def test_process_results_remove_base64_images(self, post_processor):
|
||||
"""Test removing base64 images from content"""
|
||||
content_with_base64 = (
|
||||
"Content with image "
|
||||
+ "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
|
||||
)
|
||||
results = [
|
||||
{
|
||||
"type": "page",
|
||||
"title": "Page with Base64",
|
||||
"url": "https://example.com",
|
||||
"content": content_with_base64,
|
||||
"score": 0.8,
|
||||
}
|
||||
]
|
||||
processed = post_processor.process_results(results)
|
||||
assert len(processed) == 1
|
||||
assert processed[0]["content"] == "Content with image "
|
||||
|
||||
def test_process_results_with_image_type(self, post_processor):
|
||||
"""Test processing image type results"""
|
||||
results = [
|
||||
{
|
||||
"type": "image",
|
||||
"image_url": "https://example.com/image.jpg",
|
||||
"image_description": "Test image",
|
||||
}
|
||||
]
|
||||
processed = post_processor.process_results(results)
|
||||
assert len(processed) == 1
|
||||
assert processed[0]["type"] == "image"
|
||||
assert processed[0]["image_url"] == "https://example.com/image.jpg"
|
||||
assert processed[0]["image_description"] == "Test image"
|
||||
|
||||
def test_process_results_filter_base64_image_urls(self, post_processor):
|
||||
"""Test filtering out image results with base64 URLs"""
|
||||
results = [
|
||||
{
|
||||
"type": "image",
|
||||
"image_url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==",
|
||||
"image_description": "Base64 image",
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"image_url": "https://example.com/image.jpg",
|
||||
"image_description": "Regular image",
|
||||
},
|
||||
]
|
||||
processed = post_processor.process_results(results)
|
||||
assert len(processed) == 1
|
||||
assert processed[0]["image_url"] == "https://example.com/image.jpg"
|
||||
|
||||
def test_process_results_truncate_long_image_description(self, post_processor):
|
||||
"""Test truncating long image descriptions"""
|
||||
long_description = "A" * 150 # Longer than max_content_length of 100
|
||||
results = [
|
||||
{
|
||||
"type": "image",
|
||||
"image_url": "https://example.com/image.jpg",
|
||||
"image_description": long_description,
|
||||
}
|
||||
]
|
||||
processed = post_processor.process_results(results)
|
||||
assert len(processed) == 1
|
||||
assert len(processed[0]["image_description"]) == 103 # 100 + "..."
|
||||
assert processed[0]["image_description"].endswith("...")
|
||||
|
||||
def test_process_results_other_types_passthrough(self, post_processor):
|
||||
"""Test that other result types pass through unchanged"""
|
||||
results = [
|
||||
{
|
||||
"type": "video",
|
||||
"title": "Test Video",
|
||||
"url": "https://example.com/video.mp4",
|
||||
"score": 0.8,
|
||||
}
|
||||
]
|
||||
processed = post_processor.process_results(results)
|
||||
assert len(processed) == 1
|
||||
assert processed[0]["type"] == "video"
|
||||
assert processed[0]["title"] == "Test Video"
|
||||
|
||||
def test_process_results_truncate_long_content_with_no_config(self):
|
||||
"""Test truncating long content"""
|
||||
post_processor = SearchResultPostProcessor(None, None)
|
||||
long_content = "A" * 150 # Longer than max_content_length of 100
|
||||
results = [
|
||||
{
|
||||
"type": "page",
|
||||
"title": "Long Content Page",
|
||||
"url": "https://example.com",
|
||||
"content": long_content,
|
||||
"score": 0.8,
|
||||
}
|
||||
]
|
||||
processed = post_processor.process_results(results)
|
||||
assert len(processed) == 1
|
||||
assert len(processed[0]["content"]) == len("A" * 150)
|
||||
|
||||
def test_process_results_truncate_long_content_with_max_content_length_config(self):
|
||||
"""Test truncating long content"""
|
||||
post_processor = SearchResultPostProcessor(None, 100)
|
||||
long_content = "A" * 150 # Longer than max_content_length of 100
|
||||
results = [
|
||||
{
|
||||
"type": "page",
|
||||
"title": "Long Content Page",
|
||||
"url": "https://example.com",
|
||||
"content": long_content,
|
||||
"score": 0.8,
|
||||
}
|
||||
]
|
||||
processed = post_processor.process_results(results)
|
||||
assert len(processed) == 1
|
||||
assert len(processed[0]["content"]) == 103
|
||||
assert processed[0]["content"].endswith("...")
|
||||
|
||||
def test_process_results_truncate_long_content_with_min_score_config(self):
|
||||
"""Test truncating long content"""
|
||||
post_processor = SearchResultPostProcessor(0.8, None)
|
||||
long_content = "A" * 150 # Longer than max_content_length of 100
|
||||
results = [
|
||||
{
|
||||
"type": "page",
|
||||
"title": "Long Content Page",
|
||||
"url": "https://example.com",
|
||||
"content": long_content,
|
||||
"score": 0.3,
|
||||
}
|
||||
]
|
||||
processed = post_processor.process_results(results)
|
||||
assert len(processed) == 0
|
||||
183
tests/unit/utils/test_context_manager.py
Normal file
183
tests/unit/utils/test_context_manager.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import pytest
|
||||
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
|
||||
from src.utils.context_manager import ContextManager
|
||||
|
||||
|
||||
class TestContextManager:
|
||||
"""Test cases for ContextManager"""
|
||||
|
||||
def test_count_tokens_with_empty_messages(self):
|
||||
"""Test counting tokens with empty message list"""
|
||||
context_manager = ContextManager(token_limit=1000)
|
||||
messages = []
|
||||
token_count = context_manager.count_tokens(messages)
|
||||
assert token_count == 0
|
||||
|
||||
def test_count_tokens_with_system_message(self):
|
||||
"""Test counting tokens with system message"""
|
||||
context_manager = ContextManager(token_limit=1000)
|
||||
messages = [SystemMessage(content="You are a helpful assistant.")]
|
||||
token_count = context_manager.count_tokens(messages)
|
||||
# System message has 28 characters, should be around 8 tokens (28/4 * 1.1)
|
||||
assert token_count > 7
|
||||
|
||||
def test_count_tokens_with_human_message(self):
|
||||
"""Test counting tokens with human message"""
|
||||
context_manager = ContextManager(token_limit=1000)
|
||||
messages = [HumanMessage(content="你好,这是一个测试消息。")]
|
||||
token_count = context_manager.count_tokens(messages)
|
||||
assert token_count > 12
|
||||
|
||||
def test_count_tokens_with_ai_message(self):
|
||||
"""Test counting tokens with AI message"""
|
||||
context_manager = ContextManager(token_limit=1000)
|
||||
messages = [AIMessage(content="I'm doing well, thank you for asking!")]
|
||||
token_count = context_manager.count_tokens(messages)
|
||||
assert token_count >= 10
|
||||
|
||||
def test_count_tokens_with_tool_message(self):
|
||||
"""Test counting tokens with tool message"""
|
||||
context_manager = ContextManager(token_limit=1000)
|
||||
messages = [
|
||||
ToolMessage(content="Tool execution result data here", tool_call_id="test")
|
||||
]
|
||||
token_count = context_manager.count_tokens(messages)
|
||||
# Tool message has about 32 characters, should be around 10 tokens (32/4 * 1.3)
|
||||
assert token_count > 0
|
||||
|
||||
def test_count_tokens_with_multiple_messages(self):
|
||||
"""Test counting tokens with multiple messages"""
|
||||
context_manager = ContextManager(token_limit=1000)
|
||||
messages = [
|
||||
SystemMessage(content="You are a helpful assistant."),
|
||||
HumanMessage(content="Hello, how are you?"),
|
||||
AIMessage(content="I'm doing well, thank you for asking!"),
|
||||
]
|
||||
token_count = context_manager.count_tokens(messages)
|
||||
# Should be sum of all individual message tokens
|
||||
assert token_count > 0
|
||||
|
||||
def test_is_over_limit_when_under_limit(self):
|
||||
"""Test is_over_limit when messages are under token limit"""
|
||||
context_manager = ContextManager(token_limit=1000)
|
||||
short_messages = [HumanMessage(content="Short message")]
|
||||
is_over = context_manager.is_over_limit(short_messages)
|
||||
assert is_over is False
|
||||
|
||||
def test_is_over_limit_when_over_limit(self):
|
||||
"""Test is_over_limit when messages exceed token limit"""
|
||||
# Create a context manager with a very low limit
|
||||
low_limit_cm = ContextManager(token_limit=5)
|
||||
long_messages = [
|
||||
HumanMessage(
|
||||
content="This is a very long message that should exceed the limit"
|
||||
)
|
||||
]
|
||||
is_over = low_limit_cm.is_over_limit(long_messages)
|
||||
assert is_over is True
|
||||
|
||||
def test_compress_messages_when_not_over_limit(self):
|
||||
"""Test compress_messages when messages are not over limit"""
|
||||
context_manager = ContextManager(token_limit=1000)
|
||||
messages = [HumanMessage(content="Short message")]
|
||||
compressed = context_manager.compress_messages({"messages": messages})
|
||||
# Should return the same messages when not over limit
|
||||
assert len(compressed["messages"]) == len(messages)
|
||||
|
||||
def test_compress_messages_with_system_message(self):
|
||||
"""Test compress_messages preserves system message"""
|
||||
# Create a context manager with limited token capacity
|
||||
limited_cm = ContextManager(token_limit=200)
|
||||
|
||||
messages = [
|
||||
SystemMessage(content="You are a helpful assistant."),
|
||||
HumanMessage(content="Hello"),
|
||||
AIMessage(content="Hi there!"),
|
||||
HumanMessage(
|
||||
content="Can you tell me a very long story that would exceed token limits? "
|
||||
* 100
|
||||
),
|
||||
]
|
||||
|
||||
compressed = limited_cm.compress_messages({"messages": messages})
|
||||
# Should preserve system message and some recent messages
|
||||
assert len(compressed["messages"]) == 1
|
||||
|
||||
def test_compress_messages_with_preserve_prefix_message(self):
|
||||
"""Test compress_messages when no system message is present"""
|
||||
# Create a context manager with limited token capacity
|
||||
limited_cm = ContextManager(token_limit=100, preserve_prefix_message_count=2)
|
||||
|
||||
messages = [
|
||||
HumanMessage(content="Hello"),
|
||||
AIMessage(content="Hi there!"),
|
||||
HumanMessage(
|
||||
content="Can you tell me a very long story that would exceed token limits? "
|
||||
* 10
|
||||
),
|
||||
]
|
||||
|
||||
compressed = limited_cm.compress_messages({"messages": messages})
|
||||
# Should keep only the most recent messages that fit
|
||||
assert len(compressed["messages"]) == 3
|
||||
|
||||
def test_compress_messages_without_config(self):
|
||||
"""Test compress_messages preserves system message"""
|
||||
# Create a context manager with limited token capacity
|
||||
limited_cm = ContextManager(None)
|
||||
|
||||
messages = [
|
||||
SystemMessage(content="You are a helpful assistant."),
|
||||
HumanMessage(content="Hello"),
|
||||
AIMessage(content="Hi there!"),
|
||||
HumanMessage(
|
||||
content="Can you tell me a very long story that would exceed token limits? "
|
||||
* 100
|
||||
),
|
||||
]
|
||||
|
||||
compressed = limited_cm.compress_messages({"messages": messages})
|
||||
# return the original messages
|
||||
assert len(compressed["messages"]) == 4
|
||||
|
||||
|
||||
def test_count_message_tokens_with_additional_kwargs(self):
|
||||
"""Test counting tokens for messages with additional kwargs"""
|
||||
context_manager = ContextManager(token_limit=1000)
|
||||
message = ToolMessage(
|
||||
content="Tool result",
|
||||
tool_call_id="test",
|
||||
additional_kwargs={"tool_calls": [{"name": "test_function"}]},
|
||||
)
|
||||
token_count = context_manager._count_message_tokens(message)
|
||||
assert token_count > 0
|
||||
|
||||
def test_count_message_tokens_minimum_one_token(self):
|
||||
"""Test that message token count is at least 1"""
|
||||
context_manager = ContextManager(token_limit=1000)
|
||||
message = HumanMessage(content="") # Empty content
|
||||
token_count = context_manager._count_message_tokens(message)
|
||||
assert token_count == 1 # Should be at least 1
|
||||
|
||||
def test_count_text_tokens_english_only(self):
|
||||
"""Test counting tokens for English text"""
|
||||
context_manager = ContextManager(token_limit=1000)
|
||||
# 16 English characters should result in 4 tokens (16/4)
|
||||
text = "This is a test."
|
||||
token_count = context_manager._count_text_tokens(text)
|
||||
assert token_count > 0
|
||||
|
||||
def test_count_text_tokens_chinese_only(self):
|
||||
"""Test counting tokens for Chinese text"""
|
||||
context_manager = ContextManager(token_limit=1000)
|
||||
# 8 Chinese characters should result in 8 tokens (1:1 ratio)
|
||||
text = "这是一个测试文本"
|
||||
token_count = context_manager._count_text_tokens(text)
|
||||
assert token_count == 8
|
||||
|
||||
def test_count_text_tokens_mixed_content(self):
|
||||
"""Test counting tokens for mixed English and Chinese text"""
|
||||
context_manager = ContextManager(token_limit=1000)
|
||||
text = "Hello world 这是一些中文"
|
||||
token_count = context_manager._count_text_tokens(text)
|
||||
assert token_count > 6
|
||||
Reference in New Issue
Block a user