feat: add context compress (#590)

* feat: Add context compress

* feat: Add unit test

* feat: add unit test for context manager

* feat: add postprocessor param && code format

* feat: add configuration guide

* fix: fix the configuration_guide

* fix: fix the unit test

* fix: fix the default value

* feat: add test and log for context_manager
This commit is contained in:
Fancy-hjyp
2025-09-27 06:42:22 -07:00
committed by GitHub
parent c214999606
commit 5f4eb38fdb
9 changed files with 1032 additions and 7 deletions

View File

@@ -9,11 +9,18 @@ from src.prompts import apply_prompt_template
# Create agents using configured LLM types
def create_agent(agent_name: str, agent_type: str, tools: list, prompt_template: str):
def create_agent(
    agent_name: str,
    agent_type: str,
    tools: list,
    prompt_template: str,
    pre_model_hook: callable = None,
):
    """Build a ReAct agent wired to the LLM configured for ``agent_type``.

    Args:
        agent_name: Name given to the created agent.
        agent_type: Key looked up in ``AGENT_LLM_MAP`` to choose the LLM.
        tools: Tools handed to the agent.
        prompt_template: Template name rendered against the agent state on
            every prompt build.
        pre_model_hook: Optional callable forwarded unchanged to
            ``create_react_agent`` as its ``pre_model_hook``.

    Returns:
        Whatever ``create_react_agent`` returns for this configuration.
    """

    def _render_prompt(state):
        # Rendered lazily so the prompt always reflects the current state.
        return apply_prompt_template(prompt_template, state)

    llm = get_llm_by_type(AGENT_LLM_MAP[agent_type])
    return create_react_agent(
        name=agent_name,
        model=llm,
        tools=tools,
        prompt=_render_prompt,
        pre_model_hook=pre_model_hook,
    )

View File

@@ -6,16 +6,17 @@ import logging
import os
from typing import Annotated, Literal
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.types import Command, interrupt
from functools import partial
from src.agents import create_agent
from src.config.agents import AGENT_LLM_MAP
from src.config.configuration import Configuration
from src.llms.llm import get_llm_by_type
from src.llms.llm import get_llm_by_type, get_llm_token_limit_by_type
from src.prompts.planner_model import Plan
from src.prompts.template import apply_prompt_template
from src.tools import (
@@ -26,6 +27,7 @@ from src.tools import (
)
from src.tools.search import LoggedTavilySearch
from src.utils.json_utils import repair_json_output
from src.utils.context_manager import ContextManager
from ..config import SELECTED_SEARCH_ENGINE, SearchEngine
from .types import State
@@ -283,13 +285,22 @@ def reporter_node(state: State, config: RunnableConfig):
)
)
observation_messages = []
for observation in observations:
invoke_messages.append(
observation_messages.append(
HumanMessage(
content=f"Below are some observations for the research task:\n\n{observation}",
name="observation",
)
)
# Context compression
llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP["reporter"])
compressed_state = ContextManager(llm_token_limit).compress_messages(
{"messages": observation_messages}
)
invoke_messages += compressed_state.get("messages", [])
logger.debug(f"Current invoke messages: {invoke_messages}")
response = get_llm_by_type(AGENT_LLM_MAP["reporter"]).invoke(invoke_messages)
response_content = response.content
@@ -469,11 +480,20 @@ async def _setup_and_execute_agent_step(
f"Powered by '{enabled_tools[tool.name]}'.\n{tool.description}"
)
loaded_tools.append(tool)
agent = create_agent(agent_type, agent_type, loaded_tools, agent_type)
llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP[agent_type])
pre_model_hook = partial(ContextManager(llm_token_limit, 3).compress_messages)
agent = create_agent(
agent_type, agent_type, loaded_tools, agent_type, pre_model_hook
)
return await _execute_agent_step(state, agent, agent_type)
else:
# Use default tools if no MCP servers are configured
agent = create_agent(agent_type, agent_type, default_tools, agent_type)
llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP[agent_type])
pre_model_hook = partial(ContextManager(llm_token_limit, 3).compress_messages)
agent = create_agent(
agent_type, agent_type, default_tools, agent_type, pre_model_hook
)
return await _execute_agent_step(state, agent, agent_type)

View File

@@ -67,6 +67,10 @@ def _create_llm_use_conf(llm_type: LLMType, conf: Dict[str, Any]) -> BaseChatMod
# Merge configurations, with environment variables taking precedence
merged_conf = {**llm_conf, **env_conf}
# Remove unnecessary parameters when initializing the client
if "token_limit" in merged_conf:
merged_conf.pop("token_limit")
if not merged_conf:
raise ValueError(f"No configuration found for LLM type: {llm_type}")
@@ -174,6 +178,25 @@ def get_configured_llm_models() -> dict[str, list[str]]:
return {}
def get_llm_token_limit_by_type(llm_type: str) -> "int | None":
    """Look up the configured ``token_limit`` for an LLM type.

    Args:
        llm_type: The type of LLM (a key of the LLM-type -> config-key map).

    Returns:
        The ``token_limit`` value from that model's section of the YAML
        configuration, or ``None`` when the LLM type is unknown, the section
        is missing/empty, or no ``token_limit`` is set.
    """
    config_key = _get_llm_type_config_keys().get(llm_type)
    if config_key is None:
        # Unknown LLM type: nothing to look up, so skip loading the file.
        return None

    conf = load_yaml_config(_get_config_file_path()) or {}
    # A section that is present but empty in YAML loads as None, not {}.
    section = conf.get(config_key) or {}
    if not isinstance(section, dict):
        return None
    return section.get("token_limit")
# In the future, we will use reasoning_llm and vl_llm for different purposes
# reasoning_llm = get_llm_by_type("reasoning")
# vl_llm = get_llm_by_type("vision")

View File

@@ -0,0 +1,212 @@
# src/tools/search_postprocessor.py
import re
import base64
import logging
from typing import List, Dict, Any
from urllib.parse import urlparse
logger = logging.getLogger(__name__)


class SearchResultPostProcessor:
    """Deduplicate, filter, sanitize, truncate and rank search results.

    Results are plain dicts. The processor looks at these keys when present:
    ``type`` (``"page"`` / ``"image"`` / anything else), ``url``,
    ``image_url``, ``content``, ``raw_content``, ``score`` and
    ``image_description``. Input dicts are never modified; every returned
    dict is a shallow copy.
    """

    # Matches inline ``data:image/<subtype>;base64,<payload>`` blobs.
    base64_pattern = r"data:image/[^;]+;base64,[a-zA-Z0-9+/=]+"

    def __init__(self, min_score_threshold: float, max_content_length_per_page: int):
        """
        Initialize the post-processor

        Args:
            min_score_threshold: Minimum relevance score for ``"page"``
                results; falsy or non-positive disables score filtering.
            max_content_length_per_page: Maximum content length in
                characters; falsy or non-positive disables truncation.
        """
        self.min_score_threshold = min_score_threshold
        self.max_content_length_per_page = max_content_length_per_page

    def process_results(self, results: List[Dict]) -> List[Dict]:
        """
        Process search results

        Each result goes through, in order: duplicate removal, score
        filtering (pages only), base64 image stripping and, when a length
        limit is configured, truncation. Survivors are sorted by score,
        highest first; results without a numeric score sort as 0.

        Args:
            results: Original search result list

        Returns:
            Processed result list
        """
        if not results:
            return []

        truncation_enabled = bool(self._content_limit())
        cleaned_results = []
        seen_urls = set()

        for result in results:
            # 1. Remove duplicates (an empty dict means "duplicate, skip")
            cleaned_result = self._remove_duplicates(result, seen_urls)
            if not cleaned_result:
                continue

            # 2. Filter low quality results
            if self._is_below_score_threshold(cleaned_result):
                continue

            # 3. Clean base64 images; an empty dict means the result was
            #    nothing but an inline image and should be dropped
            cleaned_result = self._remove_base64_images(cleaned_result)
            if not cleaned_result:
                continue

            # 4. Truncate long content when a limit is configured
            if truncation_enabled:
                cleaned_result = self._truncate_long_content(cleaned_result)

            cleaned_results.append(cleaned_result)

        # 5. Sort (by score descending); sorted() is stable, so equal scores
        #    keep their original relative order
        sorted_results = sorted(cleaned_results, key=self._score_of, reverse=True)

        logger.info(
            "Search result post-processing: %d -> %d",
            len(results),
            len(sorted_results),
        )
        return sorted_results

    @staticmethod
    def _score_of(result: Dict) -> float:
        """Return the result's numeric score, or 0 when missing or not a number."""
        score = result.get("score")
        return score if isinstance(score, (int, float)) else 0

    def _content_limit(self) -> int:
        """Return the configured length limit, or 0 when truncation is disabled."""
        limit = self.max_content_length_per_page
        # A negative limit must not reach a slice: text[:-5] cuts from the end.
        return int(limit) if limit and limit > 0 else 0

    def _is_below_score_threshold(self, result: Dict) -> bool:
        """Tell whether a ``"page"`` result scores under the configured threshold."""
        threshold = self.min_score_threshold
        if result.get("type") != "page" or not threshold or threshold <= 0:
            return False
        return self._score_of(result) < threshold

    def _remove_base64_images(self, result: Dict) -> Dict:
        """Remove base64 encoded images from content"""
        result_type = result.get("type")
        if result_type == "page":
            return self.processPage(result)
        if result_type == "image":
            return self.processImage(result)
        # For other types, keep as is
        return result.copy()

    def _strip_base64_field(
        self, target: Dict, key: str, log_template: str, url: Any
    ) -> None:
        """Replace inline base64 images in ``target[key]`` with a space, in place.

        Non-string values (including a missing key or ``None``) are left
        untouched.
        """
        original = target.get(key)
        if not isinstance(original, str):
            return
        cleaned = re.sub(self.base64_pattern, " ", original)
        target[key] = cleaned
        # Log only if a significant share (>20%) of the text was removed
        if len(cleaned) < len(original) * 0.8:
            logger.debug(log_template, url)

    def processPage(self, result: Dict) -> Dict:
        """Process page type result: strip base64 images from its text fields.

        Args:
            result: A search result dict.

        Returns:
            A copy of ``result`` with cleaned ``content`` / ``raw_content``.
        """
        cleaned_result = result.copy()
        url = result.get("url", "unknown")
        self._strip_base64_field(
            cleaned_result,
            "content",
            "Removed base64 images from search content: %s",
            url,
        )
        self._strip_base64_field(
            cleaned_result,
            "raw_content",
            "Removed base64 images from search raw content: %s",
            url,
        )
        return cleaned_result

    def processImage(self, result: Dict) -> Dict:
        """Process image type result - clean up base64 data and long fields

        Args:
            result: A search result dict.

        Returns:
            A cleaned copy of ``result``, or an empty dict when removing the
            base64 payload leaves no ``http``-prefixed ``image_url``.
        """
        cleaned_result = result.copy()

        # Remove base64 encoded data from image_url if present
        image_url = cleaned_result.get("image_url")
        if isinstance(image_url, str) and "data:image" in image_url:
            # strip(): the placeholder space must not be left at the edges
            # of a URL
            stripped_url = re.sub(self.base64_pattern, " ", image_url).strip()
            if not stripped_url.startswith("http"):
                logger.debug(
                    "Removed base64 data from image_url and the cleaned_image_url "
                    "is empty or not start with http, origin image_url: %s",
                    image_url,
                )
                return {}
            cleaned_result["image_url"] = stripped_url
            logger.debug("Removed base64 data from image_url: %s", image_url)

        # Truncate very long image descriptions
        limit = self._content_limit()
        description = cleaned_result.get("image_description")
        if limit and isinstance(description, str) and len(description) > limit:
            cleaned_result["image_description"] = description[:limit] + "..."
            logger.info(
                "Truncated long image description from search result: %s",
                result.get("image_url", "unknown"),
            )

        return cleaned_result

    def _truncate_long_content(self, result: Dict) -> Dict:
        """Truncate long content

        ``content`` is capped at the configured limit and ``raw_content`` at
        twice that limit; truncated text gets a trailing ``"..."``. Returns
        an unmodified copy when no positive limit is configured.
        """
        truncated_result = result.copy()
        limit = self._content_limit()
        if not limit:
            return truncated_result

        url = result.get("url", "unknown")
        # Raw content is allowed to be twice as long as the summary content
        for key, cap, label in (
            ("content", limit, "content"),
            ("raw_content", limit * 2, "raw content"),
        ):
            value = truncated_result.get(key)
            if isinstance(value, str) and len(value) > cap:
                truncated_result[key] = value[:cap] + "..."
                logger.info("Truncated long %s from search result: %s", label, url)

        return truncated_result

    def _remove_duplicates(self, result: Dict, seen_urls: set) -> Dict:
        """Remove duplicate results

        The identity of a result is its ``url`` or, when that key is absent,
        its ``image_url``. Results with no/empty URL are always kept.

        Args:
            result: A search result dict.
            seen_urls: URLs already accepted; updated in place.

        Returns:
            A copy of ``result``, or an empty dict for a duplicate.
        """
        url = result.get("url", result.get("image_url", ""))
        if url:
            if url in seen_urls:
                return {}  # Return empty dict for duplicates
            seen_urls.add(url)
        return result.copy()  # Return a copy to avoid modifying original

View File

@@ -11,6 +11,14 @@ from langchain_tavily._utilities import TAVILY_API_URL
from langchain_tavily.tavily_search import (
TavilySearchAPIWrapper as OriginalTavilySearchAPIWrapper,
)
from src.tools.search_postprocessor import SearchResultPostProcessor
from src.config import load_yaml_config
def get_search_config():
    """Return the ``SEARCH_ENGINE`` section of ``conf.yaml``.

    Returns:
        The section's value, or an empty dict when the file yields nothing,
        the key is missing, or the key is present but empty (null in YAML) --
        so callers can always use ``.get`` on the result.
    """
    config = load_yaml_config("conf.yaml") or {}
    return config.get("SEARCH_ENGINE") or {}
class EnhancedTavilySearchAPIWrapper(OriginalTavilySearchAPIWrapper):
@@ -110,4 +118,13 @@ class EnhancedTavilySearchAPIWrapper(OriginalTavilySearchAPIWrapper):
"image_description": image["description"],
}
clean_results.append(clean_result)
search_config = get_search_config()
clean_results = SearchResultPostProcessor(
min_score_threshold=search_config.get("min_score_threshold"),
max_content_length_per_page=search_config.get(
"max_content_length_per_page"
),
).process_results(clean_results)
return clean_results

View File

@@ -0,0 +1,265 @@
# src/utils/context_manager.py
from typing import List
from langchain_core.messages import (
BaseMessage,
HumanMessage,
AIMessage,
ToolMessage,
SystemMessage,
)
import logging
import copy
from src.config import load_yaml_config
logger = logging.getLogger(__name__)
def get_search_config():
    """Return the ``MODEL_TOKEN_LIMITS`` section of ``conf.yaml``.

    NOTE: despite its name this helper does not read search settings; the
    name is kept so existing callers keep working.

    Returns:
        The section's value, or an empty dict when the file yields nothing,
        the key is missing, or the key is present but empty (null in YAML).
    """
    config = load_yaml_config("conf.yaml") or {}
    return config.get("MODEL_TOKEN_LIMITS") or {}
class ContextManager:
    """Context manager and compression class

    Keeps a message list within a token budget by preserving a configurable
    number of leading messages and then as many trailing messages as fit.
    Token counts are character-based estimates, not real tokenizer output.
    """

    def __init__(self, token_limit: int, preserve_prefix_message_count: int = 0):
        """
        Initialize ContextManager

        Args:
            token_limit: Maximum token limit; ``None`` disables compression
            preserve_prefix_message_count: Number of messages to preserve at
                the beginning of the context
        """
        self.token_limit = token_limit
        self.preserve_prefix_message_count = preserve_prefix_message_count

    def count_tokens(self, messages: List[BaseMessage]) -> int:
        """
        Count tokens in message list

        Args:
            messages: List of messages

        Returns:
            Sum of the per-message estimates (0 for an empty list)
        """
        return sum(self._count_message_tokens(message) for message in messages)

    def _count_message_tokens(self, message: BaseMessage) -> int:
        """
        Count tokens in a single message

        Args:
            message: Message object

        Returns:
            Estimated number of tokens, never less than 1
        """
        token_count = 0

        # Only string content is counted; any other content type adds nothing
        content = getattr(message, "content", None)
        if content and isinstance(content, str):
            token_count += self._count_text_tokens(content)

        # Count role-related tokens
        if hasattr(message, "type"):
            token_count += self._count_text_tokens(message.type)

        # Scale the estimate by message kind (human messages are not scaled)
        if isinstance(message, SystemMessage):
            # System messages are usually short but important
            token_count = int(token_count * 1.1)
        elif isinstance(message, HumanMessage):
            pass
        elif isinstance(message, AIMessage):
            # AI messages may contain reasoning content
            token_count = int(token_count * 1.2)
        elif isinstance(message, ToolMessage):
            # Tool messages may contain large amounts of structured data
            token_count = int(token_count * 1.3)

        # additional_kwargs is estimated from its string form, unscaled
        additional_kwargs = getattr(message, "additional_kwargs", None)
        if additional_kwargs:
            token_count += self._count_text_tokens(str(additional_kwargs))
            if "tool_calls" in additional_kwargs:
                token_count += 50  # Flat allowance for function call information

        # Ensure at least 1 token
        return max(1, token_count)

    def _count_text_tokens(self, text: str) -> int:
        """
        Count tokens in text with different calculations for English and non-English characters.
        English characters: 4 characters ≈ 1 token
        Non-English characters (e.g., Chinese): 1 character ≈ 1 token

        Args:
            text: Text to count tokens for

        Returns:
            Number of tokens
        """
        if not text:
            return 0

        # ASCII code points are treated as "English"
        ascii_chars = sum(1 for char in text if ord(char) < 128)
        non_ascii_chars = len(text) - ascii_chars
        return ascii_chars // 4 + non_ascii_chars

    def is_over_limit(self, messages: List[BaseMessage]) -> bool:
        """
        Check if messages exceed token limit

        Args:
            messages: List of messages

        Returns:
            Whether limit is exceeded
        """
        return self.count_tokens(messages) > self.token_limit

    def compress_messages(self, state: dict) -> dict:
        """
        Compress messages to fit within token limit

        Args:
            state: state with original messages under the ``"messages"`` key

        Returns:
            ``state`` itself when nothing has to change (no limit set, no
            messages, or already within the limit); otherwise a shallow copy
            of ``state`` whose ``"messages"`` entry is the compressed list.
            The caller's dict is never modified.
        """
        # If not set token_limit, return original state
        if self.token_limit is None:
            logger.info("No token_limit set, the context management doesn't work.")
            return state

        if not isinstance(state, dict) or "messages" not in state:
            logger.warning("No messages found in state")
            return state

        messages = state["messages"]
        original_tokens = self.count_tokens(messages)
        if original_tokens <= self.token_limit:
            return state

        compressed_messages = self._compress_messages(messages)
        logger.info(
            "Message compression completed: %d -> %d tokens",
            original_tokens,
            self.count_tokens(compressed_messages),
        )

        # NOTE(review): when this is used as a LangGraph pre_model_hook, the
        # returned "messages" presumably pass through the state's reducer;
        # confirm that dropped messages really disappear from what the model
        # sees rather than being merged back -- TODO verify against the
        # pre_model_hook documentation.
        return {**state, "messages": compressed_messages}

    def _compress_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
        """
        Compress compressible messages

        Args:
            messages: List of messages to compress

        Returns:
            Compressed message list: the preserved head followed by the most
            recent messages that fit; messages in between are dropped
        """
        available_token = self.token_limit
        prefix_messages = []

        # 1. Preserve head messages of specified length to retain system
        #    prompts and user input. max(0, ...) keeps a negative count from
        #    turning into a slice from the end.
        head_count = max(0, min(self.preserve_prefix_message_count, len(messages)))
        for message in messages[:head_count]:
            if available_token <= 0:
                break
            cur_token_cnt = self._count_message_tokens(message)
            if cur_token_cnt <= available_token:
                prefix_messages.append(message)
                available_token -= cur_token_cnt
            else:
                # The head alone exhausts the budget: cut this message to
                # what is left and stop, nothing after it can fit
                prefix_messages.append(
                    self._truncate_message_content(message, available_token)
                )
                return prefix_messages

        # 2. Compress subsequent messages from the tail, some messages may be
        #    discarded. Collected newest-first and reversed once at the end.
        remaining = messages[len(prefix_messages) :]
        newest_first = []
        for message in reversed(remaining):
            cur_token_cnt = self._count_message_tokens(message)
            if cur_token_cnt <= available_token:
                newest_first.append(message)
                available_token -= cur_token_cnt
                continue
            if available_token > 0:
                # Partially fits: keep a truncated copy as the oldest entry
                newest_first.append(
                    self._truncate_message_content(message, available_token)
                )
            break

        newest_first.reverse()
        return prefix_messages + newest_first

    def _truncate_message_content(
        self, message: BaseMessage, max_tokens: int
    ) -> BaseMessage:
        """
        Truncate message content while preserving all other attributes by copying the original message
        and only modifying its content attribute.

        The cut is made at ``max_tokens`` *characters* (items, for non-string
        content), not estimated tokens.

        Args:
            message: The message to truncate
            max_tokens: Maximum number of tokens to keep

        Returns:
            New message instance with truncated content
        """
        # Create a deep copy of the original message to preserve all attributes
        truncated_message = copy.deepcopy(message)

        # Truncate only the content attribute; a negative budget keeps nothing
        truncated_message.content = message.content[: max(0, max_tokens)]

        return truncated_message

    def _create_summary_message(self, messages: List[BaseMessage]) -> BaseMessage:
        """
        Create summary for messages

        Not implemented yet: currently returns ``None``.

        Args:
            messages: Messages to summarize

        Returns:
            Summary message
        """
        # TODO: summary implementation
        pass