feat: add context compress (#590)

* feat:Add context compress

* feat: Add unit test

* feat: add unit test for context manager

* feat: add postprocessor param && code format

* feat: add configuration guide

* fix: fix the configuration_guide

* fix: fix the unit test

* fix: fix the default value

* feat: add test and log for context_manager
This commit is contained in:
Fancy-hjyp
2025-09-27 06:42:22 -07:00
committed by GitHub
parent c214999606
commit 5f4eb38fdb
9 changed files with 1032 additions and 7 deletions

View File

@@ -180,6 +180,20 @@ BASIC_MODEL:
api_key: $AZURE_OPENAI_API_KEY api_key: $AZURE_OPENAI_API_KEY
``` ```
### How to configure context length for different models
Different models have different context length limitations. DeerFlow provides a way to control the context length used for each model: configure a per-model context length in the `conf.yaml` file. For example:
```yaml
BASIC_MODEL:
base_url: https://ark.cn-beijing.volces.com/api/v3
model: "doubao-1-5-pro-32k-250115"
api_key: ""
token_limit: 128000
```
This means the context length limit for this model is 128k tokens.
Context management is disabled if `token_limit` is not set.
## About Search Engine ## About Search Engine
### How to control search domains for Tavily? ### How to control search domains for Tavily?
@@ -210,6 +224,28 @@ SEARCH_ENGINE:
include_raw_content: false include_raw_content: false
``` ```
### How to post-process Tavily search results
DeerFlow can post-process Tavily search results:
* Remove duplicate content
* Filter low-quality content: Filter out results with low relevance scores
* Clear base64 encoded images
* Length truncation: Truncate each search result according to the user-configured length
The filtering of low-quality content and length truncation depend on user configuration, providing two configurable parameters:
* min_score_threshold: Minimum relevance score threshold, search results below this threshold will be filtered. If not set, no filtering will be performed;
* max_content_length_per_page: Maximum length limit for each search result content, parts exceeding this length will be truncated. If not set, no truncation will be performed;
These two parameters can be configured in `conf.yaml` as shown below:
```yaml
SEARCH_ENGINE:
engine: tavily
include_images: true
min_score_threshold: 0.4
max_content_length_per_page: 5000
```
This means that search results will be filtered by the minimum relevance score threshold, and each result's content will be truncated to the configured maximum length.
## RAG (Retrieval-Augmented Generation) Configuration ## RAG (Retrieval-Augmented Generation) Configuration
DeerFlow supports multiple RAG providers for document retrieval. Configure the RAG provider by setting environment variables. DeerFlow supports multiple RAG providers for document retrieval. Configure the RAG provider by setting environment variables.
@@ -244,4 +280,4 @@ MILVUS_EMBEDDING_PROVIDER=openai
MILVUS_EMBEDDING_BASE_URL= MILVUS_EMBEDDING_BASE_URL=
MILVUS_EMBEDDING_MODEL= MILVUS_EMBEDDING_MODEL=
MILVUS_EMBEDDING_API_KEY= MILVUS_EMBEDDING_API_KEY=
``` ```

View File

@@ -9,11 +9,18 @@ from src.prompts import apply_prompt_template
# Create agents using configured LLM types # Create agents using configured LLM types
def create_agent(agent_name: str, agent_type: str, tools: list, prompt_template: str): def create_agent(
agent_name: str,
agent_type: str,
tools: list,
prompt_template: str,
pre_model_hook: callable = None,
):
"""Factory function to create agents with consistent configuration.""" """Factory function to create agents with consistent configuration."""
return create_react_agent( return create_react_agent(
name=agent_name, name=agent_name,
model=get_llm_by_type(AGENT_LLM_MAP[agent_type]), model=get_llm_by_type(AGENT_LLM_MAP[agent_type]),
tools=tools, tools=tools,
prompt=lambda state: apply_prompt_template(prompt_template, state), prompt=lambda state: apply_prompt_template(prompt_template, state),
pre_model_hook=pre_model_hook,
) )

View File

@@ -6,16 +6,17 @@ import logging
import os import os
from typing import Annotated, Literal from typing import Annotated, Literal
from langchain_core.messages import AIMessage, HumanMessage from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool from langchain_core.tools import tool
from langchain_mcp_adapters.client import MultiServerMCPClient from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.types import Command, interrupt from langgraph.types import Command, interrupt
from functools import partial
from src.agents import create_agent from src.agents import create_agent
from src.config.agents import AGENT_LLM_MAP from src.config.agents import AGENT_LLM_MAP
from src.config.configuration import Configuration from src.config.configuration import Configuration
from src.llms.llm import get_llm_by_type from src.llms.llm import get_llm_by_type, get_llm_token_limit_by_type
from src.prompts.planner_model import Plan from src.prompts.planner_model import Plan
from src.prompts.template import apply_prompt_template from src.prompts.template import apply_prompt_template
from src.tools import ( from src.tools import (
@@ -26,6 +27,7 @@ from src.tools import (
) )
from src.tools.search import LoggedTavilySearch from src.tools.search import LoggedTavilySearch
from src.utils.json_utils import repair_json_output from src.utils.json_utils import repair_json_output
from src.utils.context_manager import ContextManager
from ..config import SELECTED_SEARCH_ENGINE, SearchEngine from ..config import SELECTED_SEARCH_ENGINE, SearchEngine
from .types import State from .types import State
@@ -283,13 +285,22 @@ def reporter_node(state: State, config: RunnableConfig):
) )
) )
observation_messages = []
for observation in observations: for observation in observations:
invoke_messages.append( observation_messages.append(
HumanMessage( HumanMessage(
content=f"Below are some observations for the research task:\n\n{observation}", content=f"Below are some observations for the research task:\n\n{observation}",
name="observation", name="observation",
) )
) )
# Context compression
llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP["reporter"])
compressed_state = ContextManager(llm_token_limit).compress_messages(
{"messages": observation_messages}
)
invoke_messages += compressed_state.get("messages", [])
logger.debug(f"Current invoke messages: {invoke_messages}") logger.debug(f"Current invoke messages: {invoke_messages}")
response = get_llm_by_type(AGENT_LLM_MAP["reporter"]).invoke(invoke_messages) response = get_llm_by_type(AGENT_LLM_MAP["reporter"]).invoke(invoke_messages)
response_content = response.content response_content = response.content
@@ -469,11 +480,20 @@ async def _setup_and_execute_agent_step(
f"Powered by '{enabled_tools[tool.name]}'.\n{tool.description}" f"Powered by '{enabled_tools[tool.name]}'.\n{tool.description}"
) )
loaded_tools.append(tool) loaded_tools.append(tool)
agent = create_agent(agent_type, agent_type, loaded_tools, agent_type)
llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP[agent_type])
pre_model_hook = partial(ContextManager(llm_token_limit, 3).compress_messages)
agent = create_agent(
agent_type, agent_type, loaded_tools, agent_type, pre_model_hook
)
return await _execute_agent_step(state, agent, agent_type) return await _execute_agent_step(state, agent, agent_type)
else: else:
# Use default tools if no MCP servers are configured # Use default tools if no MCP servers are configured
agent = create_agent(agent_type, agent_type, default_tools, agent_type) llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP[agent_type])
pre_model_hook = partial(ContextManager(llm_token_limit, 3).compress_messages)
agent = create_agent(
agent_type, agent_type, default_tools, agent_type, pre_model_hook
)
return await _execute_agent_step(state, agent, agent_type) return await _execute_agent_step(state, agent, agent_type)

View File

@@ -67,6 +67,10 @@ def _create_llm_use_conf(llm_type: LLMType, conf: Dict[str, Any]) -> BaseChatMod
# Merge configurations, with environment variables taking precedence # Merge configurations, with environment variables taking precedence
merged_conf = {**llm_conf, **env_conf} merged_conf = {**llm_conf, **env_conf}
# Remove unnecessary parameters when initializing the client
if "token_limit" in merged_conf:
merged_conf.pop("token_limit")
if not merged_conf: if not merged_conf:
raise ValueError(f"No configuration found for LLM type: {llm_type}") raise ValueError(f"No configuration found for LLM type: {llm_type}")
@@ -174,6 +178,25 @@ def get_configured_llm_models() -> dict[str, list[str]]:
return {} return {}
def get_llm_token_limit_by_type(llm_type: str) -> int:
"""
Get the maximum token limit for a given LLM type.
Args:
llm_type (str): The type of LLM.
Returns:
int: The maximum token limit for the specified LLM type.
"""
llm_type_config_keys = _get_llm_type_config_keys()
config_key = llm_type_config_keys.get(llm_type)
conf = load_yaml_config(_get_config_file_path())
llm_max_token = conf.get(config_key, {}).get("token_limit")
return llm_max_token
# In the future, we will use reasoning_llm and vl_llm for different purposes # In the future, we will use reasoning_llm and vl_llm for different purposes
# reasoning_llm = get_llm_by_type("reasoning") # reasoning_llm = get_llm_by_type("reasoning")
# vl_llm = get_llm_by_type("vision") # vl_llm = get_llm_by_type("vision")

View File

@@ -0,0 +1,212 @@
# src/tools/search_postprocessor.py
import re
import base64
import logging
from typing import List, Dict, Any
from urllib.parse import urlparse
logger = logging.getLogger(__name__)


class SearchResultPostProcessor:
    """Search result post-processor.

    Pipeline applied by :meth:`process_results`:
      1. de-duplicate by URL (first occurrence wins),
      2. drop "page" results scoring below ``min_score_threshold``,
      3. strip inline base64-encoded images from text fields,
      4. truncate over-long content,
      5. sort by relevance score, descending.
    """

    # Matches inline data-URI images. Kept as a plain string for backward
    # compatibility with any external readers of this attribute.
    base64_pattern = r"data:image/[^;]+;base64,[a-zA-Z0-9+/=]+"
    # Compiled once at class-definition time so the per-result substitutions
    # don't pay a pattern lookup on every call.
    _base64_re = re.compile(base64_pattern)

    def __init__(self, min_score_threshold: float, max_content_length_per_page: int):
        """
        Initialize the post-processor.

        Args:
            min_score_threshold: Minimum relevance score; "page" results below
                it are dropped. ``None`` or ``0`` disables score filtering.
            max_content_length_per_page: Maximum content length per result;
                longer text is truncated with a trailing ``"..."``.
                ``None`` or ``0`` disables truncation.
        """
        self.min_score_threshold = min_score_threshold
        self.max_content_length_per_page = max_content_length_per_page

    def process_results(self, results: List[Dict]) -> List[Dict]:
        """
        Process search results.

        Args:
            results: Original search result list.

        Returns:
            Cleaned result list, sorted by score descending.
        """
        if not results:
            return []

        # Combined processing in a single loop for efficiency.
        cleaned_results = []
        seen_urls = set()
        for result in results:
            # 1. Remove duplicates (keyed on url / image_url).
            cleaned_result = self._remove_duplicates(result, seen_urls)
            if not cleaned_result:
                continue

            # 2. Filter low-quality page results (score filter applies to
            # "page" results only; images carry no comparable score).
            if (
                "page" == cleaned_result.get("type")
                and self.min_score_threshold
                and self.min_score_threshold > 0
                and cleaned_result.get("score", 0) < self.min_score_threshold
            ):
                continue

            # 3. Clean base64 images from content.
            cleaned_result = self._remove_base64_images(cleaned_result)
            if not cleaned_result:
                continue

            # 4. When max_content_length_per_page is set, truncate long content.
            if (
                self.max_content_length_per_page
                and self.max_content_length_per_page > 0
            ):
                cleaned_result = self._truncate_long_content(cleaned_result)

            if cleaned_result:
                cleaned_results.append(cleaned_result)

        # 5. Sort by score, descending (results without a score sort last).
        sorted_results = sorted(
            cleaned_results, key=lambda x: x.get("score", 0), reverse=True
        )
        logger.info(
            f"Search result post-processing: {len(results)} -> {len(sorted_results)}"
        )
        return sorted_results

    def _remove_base64_images(self, result: Dict) -> Dict:
        """Dispatch base64-image cleanup by result type."""
        if "page" == result.get("type"):
            cleaned_result = self.processPage(result)
        elif "image" == result.get("type"):
            cleaned_result = self.processImage(result)
        else:
            # For other types, keep as is.
            cleaned_result = result.copy()
        return cleaned_result

    def processPage(self, result: Dict) -> Dict:
        """Process a page-type result: strip base64 images from its text fields."""
        cleaned_result = result.copy()
        # isinstance guards: Tavily may return raw_content=None when raw
        # content is unavailable; an unconditional re.sub would raise
        # TypeError on non-string values.
        if isinstance(result.get("content"), str):
            original_content = result["content"]
            cleaned_content = self._base64_re.sub(" ", original_content)
            cleaned_result["content"] = cleaned_content
            # Log if significant content was removed.
            if len(cleaned_content) < len(original_content) * 0.8:
                logger.debug(
                    f"Removed base64 images from search content: {result.get('url', 'unknown')}"
                )
        if isinstance(cleaned_result.get("raw_content"), str):
            original_raw_content = cleaned_result["raw_content"]
            cleaned_raw_content = self._base64_re.sub(" ", original_raw_content)
            cleaned_result["raw_content"] = cleaned_raw_content
            # Log if significant content was removed.
            if len(cleaned_raw_content) < len(original_raw_content) * 0.8:
                logger.debug(
                    f"Removed base64 images from search raw content: {result.get('url', 'unknown')}"
                )
        return cleaned_result

    def processImage(self, result: Dict) -> Dict:
        """Process an image-type result: clean up base64 data and long fields.

        Returns an empty dict (i.e. "drop this result") when stripping the
        base64 payload leaves no usable http(s) URL behind.
        """
        cleaned_result = result.copy()
        # Remove base64-encoded data from image_url if present.
        if "image_url" in cleaned_result and isinstance(
            cleaned_result["image_url"], str
        ):
            if "data:image" in cleaned_result["image_url"]:
                original_image_url = cleaned_result["image_url"]
                cleaned_image_url = self._base64_re.sub(" ", original_image_url)
                if len(cleaned_image_url) == 0 or not cleaned_image_url.startswith(
                    "http"
                ):
                    logger.debug(
                        f"Removed base64 data from image_url and the cleaned_image_url is empty or not start with http, origin image_url: {result.get('image_url', 'unknown')}"
                    )
                    return {}
                cleaned_result["image_url"] = cleaned_image_url
                logger.debug(
                    f"Removed base64 data from image_url: {result.get('image_url', 'unknown')}"
                )
        # Truncate very long image descriptions.
        if "image_description" in cleaned_result and isinstance(
            cleaned_result["image_description"], str
        ):
            if (
                self.max_content_length_per_page
                and len(cleaned_result["image_description"])
                > self.max_content_length_per_page
            ):
                cleaned_result["image_description"] = (
                    cleaned_result["image_description"][
                        : self.max_content_length_per_page
                    ]
                    + "..."
                )
                logger.info(
                    f"Truncated long image description from search result: {result.get('image_url', 'unknown')}"
                )
        return cleaned_result

    def _truncate_long_content(self, result: Dict) -> Dict:
        """Truncate over-long content / raw_content fields.

        Only called when ``max_content_length_per_page`` is a positive number.
        Non-string fields (e.g. ``raw_content=None``) are left untouched.
        """
        truncated_result = result.copy()
        # Truncate content length.
        content = truncated_result.get("content")
        if isinstance(content, str) and len(content) > self.max_content_length_per_page:
            truncated_result["content"] = (
                content[: self.max_content_length_per_page] + "..."
            )
            logger.info(
                f"Truncated long content from search result: {result.get('url', 'unknown')}"
            )
        # Truncate raw content length (allowed to be roughly twice as long).
        raw_content = truncated_result.get("raw_content")
        if (
            isinstance(raw_content, str)
            and len(raw_content) > self.max_content_length_per_page * 2
        ):
            truncated_result["raw_content"] = (
                raw_content[: self.max_content_length_per_page * 2] + "..."
            )
            logger.info(
                f"Truncated long raw content from search result: {result.get('url', 'unknown')}"
            )
        return truncated_result

    def _remove_duplicates(self, result: Dict, seen_urls: set) -> Dict:
        """Return a copy of *result* if its URL is new, else an empty dict."""
        url = result.get("url", result.get("image_url", ""))
        if url and url not in seen_urls:
            seen_urls.add(url)
            return result.copy()  # Return a copy to avoid modifying original
        elif not url:
            # Keep results with empty URLs
            return result.copy()  # Return a copy to avoid modifying original
        return {}  # Return empty dict for duplicates

View File

@@ -11,6 +11,14 @@ from langchain_tavily._utilities import TAVILY_API_URL
from langchain_tavily.tavily_search import ( from langchain_tavily.tavily_search import (
TavilySearchAPIWrapper as OriginalTavilySearchAPIWrapper, TavilySearchAPIWrapper as OriginalTavilySearchAPIWrapper,
) )
from src.tools.search_postprocessor import SearchResultPostProcessor
from src.config import load_yaml_config
def get_search_config():
config = load_yaml_config("conf.yaml")
search_config = config.get("SEARCH_ENGINE", {})
return search_config
class EnhancedTavilySearchAPIWrapper(OriginalTavilySearchAPIWrapper): class EnhancedTavilySearchAPIWrapper(OriginalTavilySearchAPIWrapper):
@@ -110,4 +118,13 @@ class EnhancedTavilySearchAPIWrapper(OriginalTavilySearchAPIWrapper):
"image_description": image["description"], "image_description": image["description"],
} }
clean_results.append(clean_result) clean_results.append(clean_result)
search_config = get_search_config()
clean_results = SearchResultPostProcessor(
min_score_threshold=search_config.get("min_score_threshold"),
max_content_length_per_page=search_config.get(
"max_content_length_per_page"
),
).process_results(clean_results)
return clean_results return clean_results

View File

@@ -0,0 +1,265 @@
# src/utils/token_manager.py
from typing import List
from langchain_core.messages import (
BaseMessage,
HumanMessage,
AIMessage,
ToolMessage,
SystemMessage,
)
import logging
import copy
from src.config import load_yaml_config
logger = logging.getLogger(__name__)
def get_search_config():
    # NOTE(review): despite the name, this reads the MODEL_TOKEN_LIMITS
    # section of conf.yaml, not a search-engine section, and it appears
    # unused within this module — consider renaming or removing (verify
    # there are no external callers first).
    config = load_yaml_config("conf.yaml")
    search_config = config.get("MODEL_TOKEN_LIMITS", {})
    return search_config
class ContextManager:
    """Estimate token usage of a message list and compress it to a budget.

    Token counts are character-based heuristics (not a real tokenizer).
    ``compress_messages`` is used both directly (reporter node) and as a
    LangGraph ``pre_model_hook``.
    """

    def __init__(self, token_limit: int, preserve_prefix_message_count: int = 0):
        """
        Initialize ContextManager.

        Args:
            token_limit: Maximum token budget. ``None`` disables compression.
            preserve_prefix_message_count: Number of messages at the head of
                the context (e.g. system prompt, user input) to keep
                preferentially before filling from the tail.
        """
        self.token_limit = token_limit
        self.preserve_prefix_message_count = preserve_prefix_message_count

    def count_tokens(self, messages: List[BaseMessage]) -> int:
        """
        Estimate the total token count of a message list.

        Args:
            messages: List of messages.

        Returns:
            Estimated number of tokens (sum over messages).
        """
        total_tokens = 0
        for message in messages:
            total_tokens += self._count_message_tokens(message)
        return total_tokens

    def _count_message_tokens(self, message: BaseMessage) -> int:
        """
        Estimate tokens in a single message.

        Character-based estimate (see ``_count_text_tokens``) scaled by a
        per-message-type multiplier; a heuristic, not an exact tokenizer.

        Args:
            message: Message object.

        Returns:
            Estimated number of tokens (always at least 1).
        """
        token_count = 0
        # Count tokens in the content field. Only plain-string content is
        # counted; list-style multimodal content contributes nothing here.
        if hasattr(message, "content") and message.content:
            if isinstance(message.content, str):
                token_count += self._count_text_tokens(message.content)
        # Count role-related tokens ("system", "human", "ai", "tool").
        if hasattr(message, "type"):
            token_count += self._count_text_tokens(message.type)
        # Special handling for different message types: multipliers reflect
        # the typical payload density of each role.
        if isinstance(message, SystemMessage):
            # System messages are usually short but important, slightly increase estimate
            token_count = int(token_count * 1.1)
        elif isinstance(message, HumanMessage):
            # Human messages use normal estimation
            pass
        elif isinstance(message, AIMessage):
            # AI messages may contain reasoning content, slightly increase estimate
            token_count = int(token_count * 1.2)
        elif isinstance(message, ToolMessage):
            # Tool messages may contain large amounts of structured data, increase estimate
            token_count = int(token_count * 1.3)
        # Process additional information in additional_kwargs.
        if hasattr(message, "additional_kwargs") and message.additional_kwargs:
            # Simple estimation of extra field tokens via its repr.
            extra_str = str(message.additional_kwargs)
            token_count += self._count_text_tokens(extra_str)
            # If there are tool_calls, add a flat surcharge for call metadata.
            if "tool_calls" in message.additional_kwargs:
                token_count += 50  # Add estimation for function call information
        # Ensure at least 1 token.
        return max(1, token_count)

    def _count_text_tokens(self, text: str) -> int:
        """
        Estimate tokens in text, distinguishing ASCII from non-ASCII.

        English/ASCII characters: 4 characters ≈ 1 token.
        Non-ASCII characters (e.g., Chinese): 1 character ≈ 1 token.

        Args:
            text: Text to count tokens for.

        Returns:
            Estimated number of tokens.
        """
        if not text:
            return 0
        english_chars = 0
        non_english_chars = 0
        for char in text:
            # Check if character is ASCII (English letters, digits, punctuation).
            if ord(char) < 128:
                english_chars += 1
            else:
                non_english_chars += 1
        # Calculate tokens: English at 4 chars/token (floor division),
        # others at 1 char/token.
        english_tokens = english_chars // 4
        non_english_tokens = non_english_chars
        return english_tokens + non_english_tokens

    def is_over_limit(self, messages: List[BaseMessage]) -> bool:
        """
        Check whether the messages exceed the token limit.

        Args:
            messages: List of messages.

        Returns:
            True if the estimated token count exceeds ``token_limit``.
        """
        return self.count_tokens(messages) > self.token_limit

    def compress_messages(self, state: dict) -> dict:
        """
        Compress ``state["messages"]`` to fit within the token limit.

        NOTE: the input ``state`` dict is modified in place (its "messages"
        entry is replaced) and then returned.

        Args:
            state: State dict holding the original messages.

        Returns:
            The same state dict, with messages compressed when the limit
            was exceeded; returned unchanged otherwise.
        """
        # If token_limit is not set, context management is disabled.
        if self.token_limit is None:
            logger.info("No token_limit set, the context management doesn't work.")
            return state
        if not isinstance(state, dict) or "messages" not in state:
            logger.warning("No messages found in state")
            return state
        messages = state["messages"]
        if not self.is_over_limit(messages):
            return state
        # 2. Compress messages
        compressed_messages = self._compress_messages(messages)
        logger.info(
            f"Message compression completed: {self.count_tokens(messages)} -> {self.count_tokens(compressed_messages)} tokens"
        )
        state["messages"] = compressed_messages
        return state

    def _compress_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
        """
        Compress the message list into the token budget.

        Strategy: keep up to ``preserve_prefix_message_count`` head messages
        (system prompt / user input), then fill the remaining budget from the
        tail backwards; middle messages may be discarded entirely.

        Args:
            messages: List of messages to compress.

        Returns:
            Compressed message list.
        """
        available_token = self.token_limit
        prefix_messages = []
        # 1. Preserve head messages of specified length to retain system
        # prompts and user input.
        for i in range(min(self.preserve_prefix_message_count, len(messages))):
            cur_token_cnt = self._count_message_tokens(messages[i])
            if available_token > 0 and available_token >= cur_token_cnt:
                prefix_messages.append(messages[i])
                available_token -= cur_token_cnt
            elif available_token > 0:
                # Budget partially remains: truncate this message to fit and
                # stop — the early return drops everything after it.
                truncated_message = self._truncate_message_content(
                    messages[i], available_token
                )
                prefix_messages.append(truncated_message)
                return prefix_messages
            else:
                break
        # 2. Compress subsequent messages from the tail; middle messages may
        # be discarded when the budget runs out.
        messages = messages[len(prefix_messages) :]
        suffix_messages = []
        for i in range(len(messages) - 1, -1, -1):
            cur_token_cnt = self._count_message_tokens(messages[i])
            if cur_token_cnt > 0 and available_token >= cur_token_cnt:
                suffix_messages = [messages[i]] + suffix_messages
                available_token -= cur_token_cnt
            elif available_token > 0:
                # Budget partially remains: truncate this message, prepend it,
                # and stop — older messages are dropped by the early return.
                truncated_message = self._truncate_message_content(
                    messages[i], available_token
                )
                suffix_messages = [truncated_message] + suffix_messages
                return prefix_messages + suffix_messages
            else:
                break
        return prefix_messages + suffix_messages

    def _truncate_message_content(
        self, message: BaseMessage, max_tokens: int
    ) -> BaseMessage:
        """
        Truncate message content while preserving all other attributes by
        copying the original message and only modifying its content.

        NOTE: the slice is by *characters*, not tokens. Since one token maps
        to at least one character under this estimator, the result never
        exceeds ``max_tokens`` tokens (it may undershoot for ASCII text).

        Args:
            message: The message to truncate.
            max_tokens: Maximum number of tokens to keep.

        Returns:
            New message instance with truncated content.
        """
        # Create a deep copy of the original message to preserve all attributes.
        truncated_message = copy.deepcopy(message)
        # Truncate only the content attribute.
        truncated_message.content = message.content[:max_tokens]
        return truncated_message

    def _create_summary_message(self, messages: List[BaseMessage]) -> BaseMessage:
        """
        Create a summary message for a list of messages.

        Args:
            messages: Messages to summarize.

        Returns:
            Summary message.
        """
        # TODO: summary implementation
        pass

View File

@@ -0,0 +1,262 @@
import pytest
from src.tools.search_postprocessor import SearchResultPostProcessor
class TestSearchResultPostProcessor:
    """Test cases for SearchResultPostProcessor"""

    @pytest.fixture
    def post_processor(self):
        """Create a SearchResultPostProcessor instance for testing"""
        return SearchResultPostProcessor(
            min_score_threshold=0.5, max_content_length_per_page=100
        )

    def test_process_results_empty_input(self, post_processor):
        """Test processing empty results"""
        results = []
        processed = post_processor.process_results(results)
        assert processed == []

    def test_process_results_with_valid_page_results(self, post_processor):
        """Test processing valid page results"""
        results = [
            {
                "type": "page",
                "title": "Test Page",
                "url": "https://example.com",
                "content": "Test content",
                "score": 0.8,
            }
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 1
        assert processed[0]["title"] == "Test Page"
        assert processed[0]["url"] == "https://example.com"
        assert processed[0]["content"] == "Test content"
        assert processed[0]["score"] == 0.8

    def test_process_results_filter_low_score(self, post_processor):
        """Test filtering out low score results"""
        results = [
            {
                "type": "page",
                "title": "Low Score Page",
                "url": "https://example.com/low",
                "content": "Low score content",
                "score": 0.3,  # Below threshold of 0.5
            },
            {
                "type": "page",
                "title": "High Score Page",
                "url": "https://example.com/high",
                "content": "High score content",
                "score": 0.9,
            },
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 1
        assert processed[0]["title"] == "High Score Page"

    def test_process_results_remove_duplicates(self, post_processor):
        """Test removing duplicate URLs"""
        results = [
            {
                "type": "page",
                "title": "Page 1",
                "url": "https://example.com",
                "content": "Content 1",
                "score": 0.8,
            },
            {
                "type": "page",
                "title": "Page 2",
                "url": "https://example.com",  # Duplicate URL
                "content": "Content 2",
                "score": 0.7,
            },
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 1
        assert processed[0]["title"] == "Page 1"  # First one should be kept

    def test_process_results_sort_by_score(self, post_processor):
        """Test sorting results by score in descending order"""
        results = [
            {
                "type": "page",
                "title": "Low Score",
                "url": "https://example.com/low",
                "content": "Low score content",
                "score": 0.3,
            },
            {
                "type": "page",
                "title": "High Score",
                "url": "https://example.com/high",
                "content": "High score content",
                "score": 0.9,
            },
            {
                "type": "page",
                "title": "Medium Score",
                "url": "https://example.com/medium",
                "content": "Medium score content",
                "score": 0.6,
            },
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 2  # Low score filtered out
        # Should be sorted by score descending
        assert processed[0]["title"] == "High Score"
        assert processed[1]["title"] == "Medium Score"

    def test_process_results_truncate_long_content(self, post_processor):
        """Test truncating long content"""
        long_content = "A" * 150  # Longer than max_content_length of 100
        results = [
            {
                "type": "page",
                "title": "Long Content Page",
                "url": "https://example.com",
                "content": long_content,
                "score": 0.8,
            }
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 1
        assert len(processed[0]["content"]) == 103  # 100 + "..."
        assert processed[0]["content"].endswith("...")

    def test_process_results_remove_base64_images(self, post_processor):
        """Test removing base64 images from content"""
        content_with_base64 = (
            "Content with image "
            + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
        )
        results = [
            {
                "type": "page",
                "title": "Page with Base64",
                "url": "https://example.com",
                "content": content_with_base64,
                "score": 0.8,
            }
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 1
        assert processed[0]["content"] == "Content with image "

    def test_process_results_with_image_type(self, post_processor):
        """Test processing image type results"""
        results = [
            {
                "type": "image",
                "image_url": "https://example.com/image.jpg",
                "image_description": "Test image",
            }
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 1
        assert processed[0]["type"] == "image"
        assert processed[0]["image_url"] == "https://example.com/image.jpg"
        assert processed[0]["image_description"] == "Test image"

    def test_process_results_filter_base64_image_urls(self, post_processor):
        """Test filtering out image results with base64 URLs"""
        results = [
            {
                "type": "image",
                "image_url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==",
                "image_description": "Base64 image",
            },
            {
                "type": "image",
                "image_url": "https://example.com/image.jpg",
                "image_description": "Regular image",
            },
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 1
        assert processed[0]["image_url"] == "https://example.com/image.jpg"

    def test_process_results_truncate_long_image_description(self, post_processor):
        """Test truncating long image descriptions"""
        long_description = "A" * 150  # Longer than max_content_length of 100
        results = [
            {
                "type": "image",
                "image_url": "https://example.com/image.jpg",
                "image_description": long_description,
            }
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 1
        assert len(processed[0]["image_description"]) == 103  # 100 + "..."
        assert processed[0]["image_description"].endswith("...")

    def test_process_results_other_types_passthrough(self, post_processor):
        """Test that other result types pass through unchanged"""
        results = [
            {
                "type": "video",
                "title": "Test Video",
                "url": "https://example.com/video.mp4",
                "score": 0.8,
            }
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 1
        assert processed[0]["type"] == "video"
        assert processed[0]["title"] == "Test Video"

    def test_process_results_truncate_long_content_with_no_config(self):
        """Test that content is NOT truncated when no limits are configured"""
        post_processor = SearchResultPostProcessor(None, None)
        long_content = "A" * 150  # No max length configured, must survive intact
        results = [
            {
                "type": "page",
                "title": "Long Content Page",
                "url": "https://example.com",
                "content": long_content,
                "score": 0.8,
            }
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 1
        assert len(processed[0]["content"]) == len("A" * 150)

    def test_process_results_truncate_long_content_with_max_content_length_config(self):
        """Test truncation when only max_content_length_per_page is configured"""
        post_processor = SearchResultPostProcessor(None, 100)
        long_content = "A" * 150  # Longer than max_content_length of 100
        results = [
            {
                "type": "page",
                "title": "Long Content Page",
                "url": "https://example.com",
                "content": long_content,
                "score": 0.8,
            }
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 1
        assert len(processed[0]["content"]) == 103
        assert processed[0]["content"].endswith("...")

    def test_process_results_truncate_long_content_with_min_score_config(self):
        """Test score filtering when only min_score_threshold is configured"""
        post_processor = SearchResultPostProcessor(0.8, None)
        long_content = "A" * 150  # Length irrelevant: score 0.3 < threshold 0.8
        results = [
            {
                "type": "page",
                "title": "Long Content Page",
                "url": "https://example.com",
                "content": long_content,
                "score": 0.3,
            }
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 0

View File

@@ -0,0 +1,183 @@
import pytest
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
from src.utils.context_manager import ContextManager
class TestContextManager:
"""Test cases for ContextManager"""
def test_count_tokens_with_empty_messages(self):
"""Test counting tokens with empty message list"""
context_manager = ContextManager(token_limit=1000)
messages = []
token_count = context_manager.count_tokens(messages)
assert token_count == 0
def test_count_tokens_with_system_message(self):
"""Test counting tokens with system message"""
context_manager = ContextManager(token_limit=1000)
messages = [SystemMessage(content="You are a helpful assistant.")]
token_count = context_manager.count_tokens(messages)
# System message has 28 characters, should be around 8 tokens (28/4 * 1.1)
assert token_count > 7
def test_count_tokens_with_human_message(self):
"""Test counting tokens with human message"""
context_manager = ContextManager(token_limit=1000)
messages = [HumanMessage(content="你好,这是一个测试消息。")]
token_count = context_manager.count_tokens(messages)
assert token_count > 12
def test_count_tokens_with_ai_message(self):
"""Test counting tokens with AI message"""
context_manager = ContextManager(token_limit=1000)
messages = [AIMessage(content="I'm doing well, thank you for asking!")]
token_count = context_manager.count_tokens(messages)
assert token_count >= 10
def test_count_tokens_with_tool_message(self):
"""Test counting tokens with tool message"""
context_manager = ContextManager(token_limit=1000)
messages = [
ToolMessage(content="Tool execution result data here", tool_call_id="test")
]
token_count = context_manager.count_tokens(messages)
# Tool message has about 32 characters, should be around 10 tokens (32/4 * 1.3)
assert token_count > 0
def test_count_tokens_with_multiple_messages(self):
"""Test counting tokens with multiple messages"""
context_manager = ContextManager(token_limit=1000)
messages = [
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="Hello, how are you?"),
AIMessage(content="I'm doing well, thank you for asking!"),
]
token_count = context_manager.count_tokens(messages)
# Should be sum of all individual message tokens
assert token_count > 0
def test_is_over_limit_when_under_limit(self):
"""Test is_over_limit when messages are under token limit"""
context_manager = ContextManager(token_limit=1000)
short_messages = [HumanMessage(content="Short message")]
is_over = context_manager.is_over_limit(short_messages)
assert is_over is False
def test_is_over_limit_when_over_limit(self):
"""Test is_over_limit when messages exceed token limit"""
# Create a context manager with a very low limit
low_limit_cm = ContextManager(token_limit=5)
long_messages = [
HumanMessage(
content="This is a very long message that should exceed the limit"
)
]
is_over = low_limit_cm.is_over_limit(long_messages)
assert is_over is True
def test_compress_messages_when_not_over_limit(self):
"""Test compress_messages when messages are not over limit"""
context_manager = ContextManager(token_limit=1000)
messages = [HumanMessage(content="Short message")]
compressed = context_manager.compress_messages({"messages": messages})
# Should return the same messages when not over limit
assert len(compressed["messages"]) == len(messages)
def test_compress_messages_with_system_message(self):
"""Test compress_messages preserves system message"""
# Create a context manager with limited token capacity
limited_cm = ContextManager(token_limit=200)
messages = [
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="Hello"),
AIMessage(content="Hi there!"),
HumanMessage(
content="Can you tell me a very long story that would exceed token limits? "
* 100
),
]
compressed = limited_cm.compress_messages({"messages": messages})
# Should preserve system message and some recent messages
assert len(compressed["messages"]) == 1
def test_compress_messages_with_preserve_prefix_message(self):
"""Test compress_messages when no system message is present"""
# Create a context manager with limited token capacity
limited_cm = ContextManager(token_limit=100, preserve_prefix_message_count=2)
messages = [
HumanMessage(content="Hello"),
AIMessage(content="Hi there!"),
HumanMessage(
content="Can you tell me a very long story that would exceed token limits? "
* 10
),
]
compressed = limited_cm.compress_messages({"messages": messages})
# Should keep only the most recent messages that fit
assert len(compressed["messages"]) == 3
def test_compress_messages_without_config(self):
"""Test compress_messages preserves system message"""
# Create a context manager with limited token capacity
limited_cm = ContextManager(None)
messages = [
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="Hello"),
AIMessage(content="Hi there!"),
HumanMessage(
content="Can you tell me a very long story that would exceed token limits? "
* 100
),
]
compressed = limited_cm.compress_messages({"messages": messages})
# return the original messages
assert len(compressed["messages"]) == 4
def test_count_message_tokens_with_additional_kwargs(self):
"""Test counting tokens for messages with additional kwargs"""
context_manager = ContextManager(token_limit=1000)
message = ToolMessage(
content="Tool result",
tool_call_id="test",
additional_kwargs={"tool_calls": [{"name": "test_function"}]},
)
token_count = context_manager._count_message_tokens(message)
assert token_count > 0
def test_count_message_tokens_minimum_one_token(self):
"""Test that message token count is at least 1"""
context_manager = ContextManager(token_limit=1000)
message = HumanMessage(content="") # Empty content
token_count = context_manager._count_message_tokens(message)
assert token_count == 1 # Should be at least 1
def test_count_text_tokens_english_only(self):
"""Test counting tokens for English text"""
context_manager = ContextManager(token_limit=1000)
# 16 English characters should result in 4 tokens (16/4)
text = "This is a test."
token_count = context_manager._count_text_tokens(text)
assert token_count > 0
def test_count_text_tokens_chinese_only(self):
"""Test counting tokens for Chinese text"""
context_manager = ContextManager(token_limit=1000)
# 8 Chinese characters should result in 8 tokens (1:1 ratio)
text = "这是一个测试文本"
token_count = context_manager._count_text_tokens(text)
assert token_count == 8
def test_count_text_tokens_mixed_content(self):
"""Test counting tokens for mixed English and Chinese text"""
context_manager = ContextManager(token_limit=1000)
text = "Hello world 这是一些中文"
token_count = context_manager._count_text_tokens(text)
assert token_count > 6