From 5f4eb38fdbf5ede5b45aca08284f644194143bd7 Mon Sep 17 00:00:00 2001 From: Fancy-hjyp <53164956+Fancy-hjyp@users.noreply.github.com> Date: Sat, 27 Sep 2025 06:42:22 -0700 Subject: [PATCH] feat: add context compress (#590) * feat:Add context compress * feat: Add unit test * feat: add unit test for context manager * feat: add postprocessor param && code format * feat: add configuration guide * fix: fix the configuration_guide * fix: fix the unit test * fix: fix the default value * feat: add test and log for context_manager --- docs/configuration_guide.md | 38 ++- src/agents/agents.py | 9 +- src/graph/nodes.py | 30 +- src/llms/llm.py | 23 ++ src/tools/search_postprocessor.py | 212 ++++++++++++++ .../tavily_search_api_wrapper.py | 17 ++ src/utils/context_manager.py | 265 ++++++++++++++++++ tests/unit/tools/test_search_postprocessor.py | 262 +++++++++++++++++ tests/unit/utils/test_context_manager.py | 183 ++++++++++++ 9 files changed, 1032 insertions(+), 7 deletions(-) create mode 100644 src/tools/search_postprocessor.py create mode 100644 src/utils/context_manager.py create mode 100644 tests/unit/tools/test_search_postprocessor.py create mode 100644 tests/unit/utils/test_context_manager.py diff --git a/docs/configuration_guide.md b/docs/configuration_guide.md index f02f53a..d7ec677 100644 --- a/docs/configuration_guide.md +++ b/docs/configuration_guide.md @@ -180,6 +180,20 @@ BASIC_MODEL: api_key: $AZURE_OPENAI_API_KEY ``` +### How to configure context length for different models + +Different models have different context length limitations. DeerFlow provides a method to control the context length between different models. You can configure the context length between different models in the `conf.yaml` file. For example: +```yaml +BASIC_MODEL: + base_url: https://ark.cn-beijing.volces.com/api/v3 + model: "doubao-1-5-pro-32k-250115" + api_key: "" + token_limit: 128000 +``` +This means that the context length limit using this model is 128k. 
+ +The context management doesn't work if the token_limit is not set. + ## About Search Engine ### How to control search domains for Tavily? @@ -210,6 +224,28 @@ SEARCH_ENGINE: include_raw_content: false ``` +### How to post-process Tavily search results + +DeerFlow can post-process Tavily search results: +* Remove duplicate content +* Filter low-quality content: Filter out results with low relevance scores +* Clear base64 encoded images +* Length truncation: Truncate each search result according to the user-configured length + +The filtering of low-quality content and length truncation depend on user configuration, providing two configurable parameters: +* min_score_threshold: Minimum relevance score threshold, search results below this threshold will be filtered. If not set, no filtering will be performed; +* max_content_length_per_page: Maximum length limit for each search result content, parts exceeding this length will be truncated. If not set, no truncation will be performed; + +These two parameters can be configured in `conf.yaml` as shown below: +```yaml +SEARCH_ENGINE: + engine: tavily + include_images: true + min_score_threshold: 0.4 + max_content_length_per_page: 5000 +``` +This means that the search results will be filtered based on the minimum relevance score threshold and truncated to the maximum length limit for each search result content. + ## RAG (Retrieval-Augmented Generation) Configuration DeerFlow supports multiple RAG providers for document retrieval. Configure the RAG provider by setting environment variables. 
@@ -244,4 +280,4 @@ MILVUS_EMBEDDING_PROVIDER=openai MILVUS_EMBEDDING_BASE_URL= MILVUS_EMBEDDING_MODEL= MILVUS_EMBEDDING_API_KEY= -``` +``` \ No newline at end of file diff --git a/src/agents/agents.py b/src/agents/agents.py index 53c10f5..0b4f54d 100644 --- a/src/agents/agents.py +++ b/src/agents/agents.py @@ -9,11 +9,18 @@ from src.prompts import apply_prompt_template # Create agents using configured LLM types -def create_agent(agent_name: str, agent_type: str, tools: list, prompt_template: str): +def create_agent( + agent_name: str, + agent_type: str, + tools: list, + prompt_template: str, + pre_model_hook: callable = None, +): """Factory function to create agents with consistent configuration.""" return create_react_agent( name=agent_name, model=get_llm_by_type(AGENT_LLM_MAP[agent_type]), tools=tools, prompt=lambda state: apply_prompt_template(prompt_template, state), + pre_model_hook=pre_model_hook, ) diff --git a/src/graph/nodes.py b/src/graph/nodes.py index 12dc9bf..179cb8b 100644 --- a/src/graph/nodes.py +++ b/src/graph/nodes.py @@ -6,16 +6,17 @@ import logging import os from typing import Annotated, Literal -from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.runnables import RunnableConfig from langchain_core.tools import tool from langchain_mcp_adapters.client import MultiServerMCPClient from langgraph.types import Command, interrupt +from functools import partial from src.agents import create_agent from src.config.agents import AGENT_LLM_MAP from src.config.configuration import Configuration -from src.llms.llm import get_llm_by_type +from src.llms.llm import get_llm_by_type, get_llm_token_limit_by_type from src.prompts.planner_model import Plan from src.prompts.template import apply_prompt_template from src.tools import ( @@ -26,6 +27,7 @@ from src.tools import ( ) from src.tools.search import LoggedTavilySearch from src.utils.json_utils import 
repair_json_output +from src.utils.context_manager import ContextManager from ..config import SELECTED_SEARCH_ENGINE, SearchEngine from .types import State @@ -283,13 +285,22 @@ def reporter_node(state: State, config: RunnableConfig): ) ) + observation_messages = [] for observation in observations: - invoke_messages.append( + observation_messages.append( HumanMessage( content=f"Below are some observations for the research task:\n\n{observation}", name="observation", ) ) + + # Context compression + llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP["reporter"]) + compressed_state = ContextManager(llm_token_limit).compress_messages( + {"messages": observation_messages} + ) + invoke_messages += compressed_state.get("messages", []) + logger.debug(f"Current invoke messages: {invoke_messages}") response = get_llm_by_type(AGENT_LLM_MAP["reporter"]).invoke(invoke_messages) response_content = response.content @@ -469,11 +480,20 @@ async def _setup_and_execute_agent_step( f"Powered by '{enabled_tools[tool.name]}'.\n{tool.description}" ) loaded_tools.append(tool) - agent = create_agent(agent_type, agent_type, loaded_tools, agent_type) + + llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP[agent_type]) + pre_model_hook = partial(ContextManager(llm_token_limit, 3).compress_messages) + agent = create_agent( + agent_type, agent_type, loaded_tools, agent_type, pre_model_hook + ) return await _execute_agent_step(state, agent, agent_type) else: # Use default tools if no MCP servers are configured - agent = create_agent(agent_type, agent_type, default_tools, agent_type) + llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP[agent_type]) + pre_model_hook = partial(ContextManager(llm_token_limit, 3).compress_messages) + agent = create_agent( + agent_type, agent_type, default_tools, agent_type, pre_model_hook + ) return await _execute_agent_step(state, agent, agent_type) diff --git a/src/llms/llm.py b/src/llms/llm.py index 7fd1f54..809bebe 100644 --- 
a/src/llms/llm.py +++ b/src/llms/llm.py @@ -67,6 +67,10 @@ def _create_llm_use_conf(llm_type: LLMType, conf: Dict[str, Any]) -> BaseChatMod # Merge configurations, with environment variables taking precedence merged_conf = {**llm_conf, **env_conf} + # Remove unnecessary parameters when initializing the client + if "token_limit" in merged_conf: + merged_conf.pop("token_limit") + if not merged_conf: raise ValueError(f"No configuration found for LLM type: {llm_type}") @@ -174,6 +178,25 @@ def get_configured_llm_models() -> dict[str, list[str]]: return {} +def get_llm_token_limit_by_type(llm_type: str) -> int: + """ + Get the maximum token limit for a given LLM type. + + Args: + llm_type (str): The type of LLM. + + Returns: + int: The maximum token limit for the specified LLM type. + """ + + llm_type_config_keys = _get_llm_type_config_keys() + config_key = llm_type_config_keys.get(llm_type) + + conf = load_yaml_config(_get_config_file_path()) + llm_max_token = conf.get(config_key, {}).get("token_limit") + return llm_max_token + + # In the future, we will use reasoning_llm and vl_llm for different purposes # reasoning_llm = get_llm_by_type("reasoning") # vl_llm = get_llm_by_type("vision") diff --git a/src/tools/search_postprocessor.py b/src/tools/search_postprocessor.py new file mode 100644 index 0000000..0f7719e --- /dev/null +++ b/src/tools/search_postprocessor.py @@ -0,0 +1,212 @@ +# src/tools/search_postprocessor.py +import re +import base64 +import logging +from typing import List, Dict, Any +from urllib.parse import urlparse + +logger = logging.getLogger(__name__) + + +class SearchResultPostProcessor: + """Search result post-processor""" + + base64_pattern = r"data:image/[^;]+;base64,[a-zA-Z0-9+/=]+" + + def __init__(self, min_score_threshold: float, max_content_length_per_page: int): + """ + Initialize the post-processor + + Args: + min_score_threshold: Minimum relevance score threshold + max_content_length_per_page: Maximum content length + """ + 
self.min_score_threshold = min_score_threshold + self.max_content_length_per_page = max_content_length_per_page + + def process_results(self, results: List[Dict]) -> List[Dict]: + """ + Process search results + + Args: + results: Original search result list + + Returns: + Processed result list + """ + if not results: + return [] + + # Combined processing in a single loop for efficiency + cleaned_results = [] + seen_urls = set() + + for result in results: + # 1. Remove duplicates + cleaned_result = self._remove_duplicates(result, seen_urls) + if not cleaned_result: + continue + + # 2. Filter low quality results + if ( + "page" == cleaned_result.get("type") + and self.min_score_threshold + and self.min_score_threshold > 0 + and cleaned_result.get("score", 0) < self.min_score_threshold + ): + continue + + # 3. Clean base64 images from content + cleaned_result = self._remove_base64_images(cleaned_result) + if not cleaned_result: + continue + + # 4. When max_content_length_per_page is set, truncate long content + if ( + self.max_content_length_per_page + and self.max_content_length_per_page > 0 + ): + cleaned_result = self._truncate_long_content(cleaned_result) + + if cleaned_result: + cleaned_results.append(cleaned_result) + + # 5. 
Sort (by score descending) + sorted_results = sorted( + cleaned_results, key=lambda x: x.get("score", 0), reverse=True + ) + + logger.info( + f"Search result post-processing: {len(results)} -> {len(sorted_results)}" + ) + return sorted_results + + def _remove_base64_images(self, result: Dict) -> Dict: + """Remove base64 encoded images from content""" + + if "page" == result.get("type"): + cleaned_result = self.processPage(result) + elif "image" == result.get("type"): + cleaned_result = self.processImage(result) + else: + # For other types, keep as is + cleaned_result = result.copy() + + return cleaned_result + + def processPage(self, result: Dict) -> Dict: + """Process page type result""" + # Clean base64 images from content + cleaned_result = result.copy() + + if "content" in result: + original_content = result["content"] + cleaned_content = re.sub(self.base64_pattern, " ", original_content) + cleaned_result["content"] = cleaned_content + + # Log if significant content was removed + if len(cleaned_content) < len(original_content) * 0.8: + logger.debug( + f"Removed base64 images from search content: {result.get('url', 'unknown')}" + ) + + # Clean base64 images from raw content + if "raw_content" in cleaned_result: + original_raw_content = cleaned_result["raw_content"] + cleaned_raw_content = re.sub(self.base64_pattern, " ", original_raw_content) + cleaned_result["raw_content"] = cleaned_raw_content + + # Log if significant content was removed + if len(cleaned_raw_content) < len(original_raw_content) * 0.8: + logger.debug( + f"Removed base64 images from search raw content: {result.get('url', 'unknown')}" + ) + + return cleaned_result + + def processImage(self, result: Dict) -> Dict: + """Process image type result - clean up base64 data and long fields""" + cleaned_result = result.copy() + + # Remove base64 encoded data from image_url if present + if "image_url" in cleaned_result and isinstance( + cleaned_result["image_url"], str + ): + # Check if image_url contains 
base64 data + if "data:image" in cleaned_result["image_url"]: + original_image_url = cleaned_result["image_url"] + cleaned_image_url = re.sub(self.base64_pattern, " ", original_image_url) + if len(cleaned_image_url) == 0 or not cleaned_image_url.startswith( + "http" + ): + logger.debug( + f"Removed base64 data from image_url and the cleaned_image_url is empty or not start with http, origin image_url: {result.get('image_url', 'unknown')}" + ) + return {} + cleaned_result["image_url"] = cleaned_image_url + logger.debug( + f"Removed base64 data from image_url: {result.get('image_url', 'unknown')}" + ) + + # Truncate very long image descriptions + if "image_description" in cleaned_result and isinstance( + cleaned_result["image_description"], str + ): + if ( + self.max_content_length_per_page + and len(cleaned_result["image_description"]) + > self.max_content_length_per_page + ): + cleaned_result["image_description"] = ( + cleaned_result["image_description"][ + : self.max_content_length_per_page + ] + + "..." + ) + logger.info( + f"Truncated long image description from search result: {result.get('image_url', 'unknown')}" + ) + + return cleaned_result + + def _truncate_long_content(self, result: Dict) -> Dict: + """Truncate long content""" + + truncated_result = result.copy() + + # Truncate content length + if "content" in truncated_result: + content = truncated_result["content"] + if len(content) > self.max_content_length_per_page: + truncated_result["content"] = ( + content[: self.max_content_length_per_page] + "..." + ) + logger.info( + f"Truncated long content from search result: {result.get('url', 'unknown')}" + ) + + # Truncate raw content length (can be slightly longer) + if "raw_content" in truncated_result: + raw_content = truncated_result["raw_content"] + if len(raw_content) > self.max_content_length_per_page * 2: + truncated_result["raw_content"] = ( + raw_content[: self.max_content_length_per_page * 2] + "..." 
+ ) + logger.info( + f"Truncated long raw content from search result: {result.get('url', 'unknown')}" + ) + + return truncated_result + + def _remove_duplicates(self, result: Dict, seen_urls: set) -> Dict: + """Remove duplicate results""" + + url = result.get("url", result.get("image_url", "")) + if url and url not in seen_urls: + seen_urls.add(url) + return result.copy() # Return a copy to avoid modifying original + elif not url: + # Keep results with empty URLs + return result.copy() # Return a copy to avoid modifying original + + return {} # Return empty dict for duplicates diff --git a/src/tools/tavily_search/tavily_search_api_wrapper.py b/src/tools/tavily_search/tavily_search_api_wrapper.py index f1945a5..f42aa35 100644 --- a/src/tools/tavily_search/tavily_search_api_wrapper.py +++ b/src/tools/tavily_search/tavily_search_api_wrapper.py @@ -11,6 +11,14 @@ from langchain_tavily._utilities import TAVILY_API_URL from langchain_tavily.tavily_search import ( TavilySearchAPIWrapper as OriginalTavilySearchAPIWrapper, ) +from src.tools.search_postprocessor import SearchResultPostProcessor +from src.config import load_yaml_config + + +def get_search_config(): + config = load_yaml_config("conf.yaml") + search_config = config.get("SEARCH_ENGINE", {}) + return search_config class EnhancedTavilySearchAPIWrapper(OriginalTavilySearchAPIWrapper): @@ -110,4 +118,13 @@ class EnhancedTavilySearchAPIWrapper(OriginalTavilySearchAPIWrapper): "image_description": image["description"], } clean_results.append(clean_result) + + search_config = get_search_config() + clean_results = SearchResultPostProcessor( + min_score_threshold=search_config.get("min_score_threshold"), + max_content_length_per_page=search_config.get( + "max_content_length_per_page" + ), + ).process_results(clean_results) + return clean_results diff --git a/src/utils/context_manager.py b/src/utils/context_manager.py new file mode 100644 index 0000000..8015a1a --- /dev/null +++ b/src/utils/context_manager.py @@ -0,0 
+1,265 @@ +# src/utils/token_manager.py +from typing import List +from langchain_core.messages import ( + BaseMessage, + HumanMessage, + AIMessage, + ToolMessage, + SystemMessage, +) +import logging +import copy + +from src.config import load_yaml_config + +logger = logging.getLogger(__name__) + + +def get_search_config(): + config = load_yaml_config("conf.yaml") + search_config = config.get("MODEL_TOKEN_LIMITS", {}) + return search_config + + +class ContextManager: + """Context manager and compression class""" + + def __init__(self, token_limit: int, preserve_prefix_message_count: int = 0): + """ + Initialize ContextManager + + Args: + token_limit: Maximum token limit + preserve_prefix_message_count: Number of messages to preserve at the beginning of the context + """ + self.token_limit = token_limit + self.preserve_prefix_message_count = preserve_prefix_message_count + + def count_tokens(self, messages: List[BaseMessage]) -> int: + """ + Count tokens in message list + + Args: + messages: List of messages + + Returns: + Number of tokens + """ + total_tokens = 0 + for message in messages: + total_tokens += self._count_message_tokens(message) + return total_tokens + + def _count_message_tokens(self, message: BaseMessage) -> int: + """ + Count tokens in a single message + + Args: + message: Message object + + Returns: + Number of tokens + """ + # Estimate token count based on character length (different calculation for English and non-English) + token_count = 0 + + # Count tokens in content field + if hasattr(message, "content") and message.content: + # Handle different content types + if isinstance(message.content, str): + token_count += self._count_text_tokens(message.content) + + # Count role-related tokens + if hasattr(message, "type"): + token_count += self._count_text_tokens(message.type) + + # Special handling for different message types + if isinstance(message, SystemMessage): + # System messages are usually short but important, slightly increase estimate + 
token_count = int(token_count * 1.1) + elif isinstance(message, HumanMessage): + # Human messages use normal estimation + pass + elif isinstance(message, AIMessage): + # AI messages may contain reasoning content, slightly increase estimate + token_count = int(token_count * 1.2) + elif isinstance(message, ToolMessage): + # Tool messages may contain large amounts of structured data, increase estimate + token_count = int(token_count * 1.3) + + # Process additional information in additional_kwargs + if hasattr(message, "additional_kwargs") and message.additional_kwargs: + # Simple estimation of extra field tokens + extra_str = str(message.additional_kwargs) + token_count += self._count_text_tokens(extra_str) + + # If there are tool_calls, add estimation + if "tool_calls" in message.additional_kwargs: + token_count += 50 # Add estimation for function call information + + # Ensure at least 1 token + return max(1, token_count) + + def _count_text_tokens(self, text: str) -> int: + """ + Count tokens in text with different calculations for English and non-English characters. 
+ English characters: 4 characters ≈ 1 token + Non-English characters (e.g., Chinese): 1 character ≈ 1 token + + Args: + text: Text to count tokens for + + Returns: + Number of tokens + """ + if not text: + return 0 + + english_chars = 0 + non_english_chars = 0 + + for char in text: + # Check if character is ASCII (English letters, digits, punctuation) + if ord(char) < 128: + english_chars += 1 + else: + non_english_chars += 1 + + # Calculate tokens: English at 4 chars/token, others at 1 char/token + english_tokens = english_chars // 4 + non_english_tokens = non_english_chars + + return english_tokens + non_english_tokens + + def is_over_limit(self, messages: List[BaseMessage]) -> bool: + """ + Check if messages exceed token limit + + Args: + messages: List of messages + + Returns: + Whether limit is exceeded + """ + return self.count_tokens(messages) > self.token_limit + + def compress_messages(self, state: dict) -> List[BaseMessage]: + """ + Compress messages to fit within token limit + + Args: + state: state with original messages + + Returns: + Compressed state with compressed messages + """ + # If not set token_limit, return original state + if self.token_limit is None: + logger.info("No token_limit set, the context management doesn't work.") + return state + + if not isinstance(state, dict) or "messages" not in state: + logger.warning("No messages found in state") + return state + + messages = state["messages"] + + if not self.is_over_limit(messages): + return state + + # 2. 
Compress messages + compressed_messages = self._compress_messages(messages) + + logger.info( + f"Message compression completed: {self.count_tokens(messages)} -> {self.count_tokens(compressed_messages)} tokens" + ) + + state["messages"] = compressed_messages + return state + + def _compress_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]: + """ + Compress compressible messages + + Args: + messages: List of messages to compress + + Returns: + Compressed message list + """ + + available_token = self.token_limit + prefix_messages = [] + + # 1. Preserve head messages of specified length to retain system prompts and user input + for i in range(min(self.preserve_prefix_message_count, len(messages))): + cur_token_cnt = self._count_message_tokens(messages[i]) + if available_token > 0 and available_token >= cur_token_cnt: + prefix_messages.append(messages[i]) + available_token -= cur_token_cnt + elif available_token > 0: + # Truncate content to fit available tokens + truncated_message = self._truncate_message_content( + messages[i], available_token + ) + prefix_messages.append(truncated_message) + return prefix_messages + else: + break + + # 2. 
Compress subsequent messages from the tail, some messages may be discarded + messages = messages[len(prefix_messages) :] + suffix_messages = [] + for i in range(len(messages) - 1, -1, -1): + cur_token_cnt = self._count_message_tokens(messages[i]) + + if cur_token_cnt > 0 and available_token >= cur_token_cnt: + suffix_messages = [messages[i]] + suffix_messages + available_token -= cur_token_cnt + elif available_token > 0: + # Truncate content to fit available tokens + truncated_message = self._truncate_message_content( + messages[i], available_token + ) + suffix_messages = [truncated_message] + suffix_messages + return prefix_messages + suffix_messages + else: + break + + return prefix_messages + suffix_messages + + def _truncate_message_content( + self, message: BaseMessage, max_tokens: int + ) -> BaseMessage: + """ + Truncate message content while preserving all other attributes by copying the original message + and only modifying its content attribute. + + Args: + message: The message to truncate + max_tokens: Maximum number of tokens to keep + + Returns: + New message instance with truncated content + """ + + # Create a deep copy of the original message to preserve all attributes + truncated_message = copy.deepcopy(message) + + # Truncate only the content attribute + truncated_message.content = message.content[:max_tokens] + + return truncated_message + + def _create_summary_message(self, messages: List[BaseMessage]) -> BaseMessage: + """ + Create summary for messages + + Args: + messages: Messages to summarize + + Returns: + Summary message + """ + # TODO: summary implementation + pass diff --git a/tests/unit/tools/test_search_postprocessor.py b/tests/unit/tools/test_search_postprocessor.py new file mode 100644 index 0000000..5064b25 --- /dev/null +++ b/tests/unit/tools/test_search_postprocessor.py @@ -0,0 +1,262 @@ +import pytest +from src.tools.search_postprocessor import SearchResultPostProcessor + + +class TestSearchResultPostProcessor: + """Test cases for 
SearchResultPostProcessor""" + + @pytest.fixture + def post_processor(self): + """Create a SearchResultPostProcessor instance for testing""" + return SearchResultPostProcessor( + min_score_threshold=0.5, max_content_length_per_page=100 + ) + + def test_process_results_empty_input(self, post_processor): + """Test processing empty results""" + results = [] + processed = post_processor.process_results(results) + assert processed == [] + + def test_process_results_with_valid_page_results(self, post_processor): + """Test processing valid page results""" + results = [ + { + "type": "page", + "title": "Test Page", + "url": "https://example.com", + "content": "Test content", + "score": 0.8, + } + ] + processed = post_processor.process_results(results) + assert len(processed) == 1 + assert processed[0]["title"] == "Test Page" + assert processed[0]["url"] == "https://example.com" + assert processed[0]["content"] == "Test content" + assert processed[0]["score"] == 0.8 + + def test_process_results_filter_low_score(self, post_processor): + """Test filtering out low score results""" + results = [ + { + "type": "page", + "title": "Low Score Page", + "url": "https://example.com/low", + "content": "Low score content", + "score": 0.3, # Below threshold of 0.5 + }, + { + "type": "page", + "title": "High Score Page", + "url": "https://example.com/high", + "content": "High score content", + "score": 0.9, + }, + ] + processed = post_processor.process_results(results) + assert len(processed) == 1 + assert processed[0]["title"] == "High Score Page" + + def test_process_results_remove_duplicates(self, post_processor): + """Test removing duplicate URLs""" + results = [ + { + "type": "page", + "title": "Page 1", + "url": "https://example.com", + "content": "Content 1", + "score": 0.8, + }, + { + "type": "page", + "title": "Page 2", + "url": "https://example.com", # Duplicate URL + "content": "Content 2", + "score": 0.7, + }, + ] + processed = post_processor.process_results(results) + assert 
len(processed) == 1 + assert processed[0]["title"] == "Page 1" # First one should be kept + + def test_process_results_sort_by_score(self, post_processor): + """Test sorting results by score in descending order""" + results = [ + { + "type": "page", + "title": "Low Score", + "url": "https://example.com/low", + "content": "Low score content", + "score": 0.3, + }, + { + "type": "page", + "title": "High Score", + "url": "https://example.com/high", + "content": "High score content", + "score": 0.9, + }, + { + "type": "page", + "title": "Medium Score", + "url": "https://example.com/medium", + "content": "Medium score content", + "score": 0.6, + }, + ] + processed = post_processor.process_results(results) + assert len(processed) == 2 # Low score filtered out + # Should be sorted by score descending + assert processed[0]["title"] == "High Score" + assert processed[1]["title"] == "Medium Score" + + def test_process_results_truncate_long_content(self, post_processor): + """Test truncating long content""" + long_content = "A" * 150 # Longer than max_content_length of 100 + results = [ + { + "type": "page", + "title": "Long Content Page", + "url": "https://example.com", + "content": long_content, + "score": 0.8, + } + ] + processed = post_processor.process_results(results) + assert len(processed) == 1 + assert len(processed[0]["content"]) == 103 # 100 + "..." 
+ assert processed[0]["content"].endswith("...") + + def test_process_results_remove_base64_images(self, post_processor): + """Test removing base64 images from content""" + content_with_base64 = ( + "Content with image " + + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==" + ) + results = [ + { + "type": "page", + "title": "Page with Base64", + "url": "https://example.com", + "content": content_with_base64, + "score": 0.8, + } + ] + processed = post_processor.process_results(results) + assert len(processed) == 1 + assert processed[0]["content"] == "Content with image " + + def test_process_results_with_image_type(self, post_processor): + """Test processing image type results""" + results = [ + { + "type": "image", + "image_url": "https://example.com/image.jpg", + "image_description": "Test image", + } + ] + processed = post_processor.process_results(results) + assert len(processed) == 1 + assert processed[0]["type"] == "image" + assert processed[0]["image_url"] == "https://example.com/image.jpg" + assert processed[0]["image_description"] == "Test image" + + def test_process_results_filter_base64_image_urls(self, post_processor): + """Test filtering out image results with base64 URLs""" + results = [ + { + "type": "image", + "image_url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==", + "image_description": "Base64 image", + }, + { + "type": "image", + "image_url": "https://example.com/image.jpg", + "image_description": "Regular image", + }, + ] + processed = post_processor.process_results(results) + assert len(processed) == 1 + assert processed[0]["image_url"] == "https://example.com/image.jpg" + + def test_process_results_truncate_long_image_description(self, post_processor): + """Test truncating long image descriptions""" + long_description = "A" * 150 # Longer than max_content_length of 100 + results = [ + { + "type": 
"image", + "image_url": "https://example.com/image.jpg", + "image_description": long_description, + } + ] + processed = post_processor.process_results(results) + assert len(processed) == 1 + assert len(processed[0]["image_description"]) == 103 # 100 + "..." + assert processed[0]["image_description"].endswith("...") + + def test_process_results_other_types_passthrough(self, post_processor): + """Test that other result types pass through unchanged""" + results = [ + { + "type": "video", + "title": "Test Video", + "url": "https://example.com/video.mp4", + "score": 0.8, + } + ] + processed = post_processor.process_results(results) + assert len(processed) == 1 + assert processed[0]["type"] == "video" + assert processed[0]["title"] == "Test Video" + + def test_process_results_truncate_long_content_with_no_config(self): + """Test truncating long content""" + post_processor = SearchResultPostProcessor(None, None) + long_content = "A" * 150 # Longer than max_content_length of 100 + results = [ + { + "type": "page", + "title": "Long Content Page", + "url": "https://example.com", + "content": long_content, + "score": 0.8, + } + ] + processed = post_processor.process_results(results) + assert len(processed) == 1 + assert len(processed[0]["content"]) == len("A" * 150) + + def test_process_results_truncate_long_content_with_max_content_length_config(self): + """Test truncating long content""" + post_processor = SearchResultPostProcessor(None, 100) + long_content = "A" * 150 # Longer than max_content_length of 100 + results = [ + { + "type": "page", + "title": "Long Content Page", + "url": "https://example.com", + "content": long_content, + "score": 0.8, + } + ] + processed = post_processor.process_results(results) + assert len(processed) == 1 + assert len(processed[0]["content"]) == 103 + assert processed[0]["content"].endswith("...") + + def test_process_results_truncate_long_content_with_min_score_config(self): + """Test truncating long content""" + post_processor = 
SearchResultPostProcessor(0.8, None) + long_content = "A" * 150 # Longer than max_content_length of 100 + results = [ + { + "type": "page", + "title": "Long Content Page", + "url": "https://example.com", + "content": long_content, + "score": 0.3, + } + ] + processed = post_processor.process_results(results) + assert len(processed) == 0 diff --git a/tests/unit/utils/test_context_manager.py b/tests/unit/utils/test_context_manager.py new file mode 100644 index 0000000..ed213ec --- /dev/null +++ b/tests/unit/utils/test_context_manager.py @@ -0,0 +1,183 @@ +import pytest +from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage +from src.utils.context_manager import ContextManager + + +class TestContextManager: + """Test cases for ContextManager""" + + def test_count_tokens_with_empty_messages(self): + """Test counting tokens with empty message list""" + context_manager = ContextManager(token_limit=1000) + messages = [] + token_count = context_manager.count_tokens(messages) + assert token_count == 0 + + def test_count_tokens_with_system_message(self): + """Test counting tokens with system message""" + context_manager = ContextManager(token_limit=1000) + messages = [SystemMessage(content="You are a helpful assistant.")] + token_count = context_manager.count_tokens(messages) + # System message has 28 characters, should be around 8 tokens (28/4 * 1.1) + assert token_count > 7 + + def test_count_tokens_with_human_message(self): + """Test counting tokens with human message""" + context_manager = ContextManager(token_limit=1000) + messages = [HumanMessage(content="你好,这是一个测试消息。")] + token_count = context_manager.count_tokens(messages) + assert token_count > 12 + + def test_count_tokens_with_ai_message(self): + """Test counting tokens with AI message""" + context_manager = ContextManager(token_limit=1000) + messages = [AIMessage(content="I'm doing well, thank you for asking!")] + token_count = context_manager.count_tokens(messages) + assert 
token_count >= 10 + + def test_count_tokens_with_tool_message(self): + """Test counting tokens with tool message""" + context_manager = ContextManager(token_limit=1000) + messages = [ + ToolMessage(content="Tool execution result data here", tool_call_id="test") + ] + token_count = context_manager.count_tokens(messages) + # Tool message has about 32 characters, should be around 10 tokens (32/4 * 1.3) + assert token_count > 0 + + def test_count_tokens_with_multiple_messages(self): + """Test counting tokens with multiple messages""" + context_manager = ContextManager(token_limit=1000) + messages = [ + SystemMessage(content="You are a helpful assistant."), + HumanMessage(content="Hello, how are you?"), + AIMessage(content="I'm doing well, thank you for asking!"), + ] + token_count = context_manager.count_tokens(messages) + # Should be sum of all individual message tokens + assert token_count > 0 + + def test_is_over_limit_when_under_limit(self): + """Test is_over_limit when messages are under token limit""" + context_manager = ContextManager(token_limit=1000) + short_messages = [HumanMessage(content="Short message")] + is_over = context_manager.is_over_limit(short_messages) + assert is_over is False + + def test_is_over_limit_when_over_limit(self): + """Test is_over_limit when messages exceed token limit""" + # Create a context manager with a very low limit + low_limit_cm = ContextManager(token_limit=5) + long_messages = [ + HumanMessage( + content="This is a very long message that should exceed the limit" + ) + ] + is_over = low_limit_cm.is_over_limit(long_messages) + assert is_over is True + + def test_compress_messages_when_not_over_limit(self): + """Test compress_messages when messages are not over limit""" + context_manager = ContextManager(token_limit=1000) + messages = [HumanMessage(content="Short message")] + compressed = context_manager.compress_messages({"messages": messages}) + # Should return the same messages when not over limit + assert 
len(compressed["messages"]) == len(messages) + + def test_compress_messages_with_system_message(self): + """Test compress_messages preserves system message""" + # Create a context manager with limited token capacity + limited_cm = ContextManager(token_limit=200) + + messages = [ + SystemMessage(content="You are a helpful assistant."), + HumanMessage(content="Hello"), + AIMessage(content="Hi there!"), + HumanMessage( + content="Can you tell me a very long story that would exceed token limits? " + * 100 + ), + ] + + compressed = limited_cm.compress_messages({"messages": messages}) + # Should preserve system message and some recent messages + assert len(compressed["messages"]) == 1 + + def test_compress_messages_with_preserve_prefix_message(self): + """Test compress_messages when no system message is present""" + # Create a context manager with limited token capacity + limited_cm = ContextManager(token_limit=100, preserve_prefix_message_count=2) + + messages = [ + HumanMessage(content="Hello"), + AIMessage(content="Hi there!"), + HumanMessage( + content="Can you tell me a very long story that would exceed token limits? " + * 10 + ), + ] + + compressed = limited_cm.compress_messages({"messages": messages}) + # Should keep only the most recent messages that fit + assert len(compressed["messages"]) == 3 + + def test_compress_messages_without_config(self): + """Test compress_messages preserves system message""" + # Create a context manager with limited token capacity + limited_cm = ContextManager(None) + + messages = [ + SystemMessage(content="You are a helpful assistant."), + HumanMessage(content="Hello"), + AIMessage(content="Hi there!"), + HumanMessage( + content="Can you tell me a very long story that would exceed token limits? 
" + * 100 + ), + ] + + compressed = limited_cm.compress_messages({"messages": messages}) + # return the original messages + assert len(compressed["messages"]) == 4 + + + def test_count_message_tokens_with_additional_kwargs(self): + """Test counting tokens for messages with additional kwargs""" + context_manager = ContextManager(token_limit=1000) + message = ToolMessage( + content="Tool result", + tool_call_id="test", + additional_kwargs={"tool_calls": [{"name": "test_function"}]}, + ) + token_count = context_manager._count_message_tokens(message) + assert token_count > 0 + + def test_count_message_tokens_minimum_one_token(self): + """Test that message token count is at least 1""" + context_manager = ContextManager(token_limit=1000) + message = HumanMessage(content="") # Empty content + token_count = context_manager._count_message_tokens(message) + assert token_count == 1 # Should be at least 1 + + def test_count_text_tokens_english_only(self): + """Test counting tokens for English text""" + context_manager = ContextManager(token_limit=1000) + # 16 English characters should result in 4 tokens (16/4) + text = "This is a test." + token_count = context_manager._count_text_tokens(text) + assert token_count > 0 + + def test_count_text_tokens_chinese_only(self): + """Test counting tokens for Chinese text""" + context_manager = ContextManager(token_limit=1000) + # 8 Chinese characters should result in 8 tokens (1:1 ratio) + text = "这是一个测试文本" + token_count = context_manager._count_text_tokens(text) + assert token_count == 8 + + def test_count_text_tokens_mixed_content(self): + """Test counting tokens for mixed English and Chinese text""" + context_manager = ContextManager(token_limit=1000) + text = "Hello world 这是一些中文" + token_count = context_manager._count_text_tokens(text) + assert token_count > 6