mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-19 04:14:46 +08:00
feat: add context compress (#590)
* feat:Add context compress * feat: Add unit test * feat: add unit test for context manager * feat: add postprocessor param && code format * feat: add configuration guide * fix: fix the configuration_guide * fix: fix the unit test * fix: fix the default value * feat: add test and log for context_manager
This commit is contained in:
@@ -180,6 +180,20 @@ BASIC_MODEL:
|
|||||||
api_key: $AZURE_OPENAI_API_KEY
|
api_key: $AZURE_OPENAI_API_KEY
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### How to configure context length for different models
|
||||||
|
|
||||||
|
Different models have different context length limitations. DeerFlow provides a method to control the context length between different models. You can configure the context length between different models in the `conf.yaml` file. For example:
|
||||||
|
```yaml
|
||||||
|
BASIC_MODEL:
|
||||||
|
base_url: https://ark.cn-beijing.volces.com/api/v3
|
||||||
|
model: "doubao-1-5-pro-32k-250115"
|
||||||
|
api_key: ""
|
||||||
|
token_limit: 128000
|
||||||
|
```
|
||||||
|
This means that the context length limit using this model is 128k.
|
||||||
|
|
||||||
|
The context management doesn't work if the token_limit is not set.
|
||||||
|
|
||||||
## About Search Engine
|
## About Search Engine
|
||||||
|
|
||||||
### How to control search domains for Tavily?
|
### How to control search domains for Tavily?
|
||||||
@@ -210,6 +224,28 @@ SEARCH_ENGINE:
|
|||||||
include_raw_content: false
|
include_raw_content: false
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### How to post-process Tavily search results
|
||||||
|
|
||||||
|
DeerFlow can post-process Tavily search results:
|
||||||
|
* Remove duplicate content
|
||||||
|
* Filter low-quality content: Filter out results with low relevance scores
|
||||||
|
* Clear base64 encoded images
|
||||||
|
* Length truncation: Truncate each search result according to the user-configured length
|
||||||
|
|
||||||
|
The filtering of low-quality content and length truncation depend on user configuration, providing two configurable parameters:
|
||||||
|
* min_score_threshold: Minimum relevance score threshold, search results below this threshold will be filtered. If not set, no filtering will be performed;
|
||||||
|
* max_content_length_per_page: Maximum length limit for each search result content, parts exceeding this length will be truncated. If not set, no truncation will be performed;
|
||||||
|
|
||||||
|
These two parameters can be configured in `conf.yaml` as shown below:
|
||||||
|
```yaml
|
||||||
|
SEARCH_ENGINE:
|
||||||
|
engine: tavily
|
||||||
|
include_images: true
|
||||||
|
min_score_threshold: 0.4
|
||||||
|
max_content_length_per_page: 5000
|
||||||
|
```
|
||||||
|
That's meaning that the search results will be filtered based on the minimum relevance score threshold and truncated to the maximum length limit for each search result content.
|
||||||
|
|
||||||
## RAG (Retrieval-Augmented Generation) Configuration
|
## RAG (Retrieval-Augmented Generation) Configuration
|
||||||
|
|
||||||
DeerFlow supports multiple RAG providers for document retrieval. Configure the RAG provider by setting environment variables.
|
DeerFlow supports multiple RAG providers for document retrieval. Configure the RAG provider by setting environment variables.
|
||||||
@@ -244,4 +280,4 @@ MILVUS_EMBEDDING_PROVIDER=openai
|
|||||||
MILVUS_EMBEDDING_BASE_URL=
|
MILVUS_EMBEDDING_BASE_URL=
|
||||||
MILVUS_EMBEDDING_MODEL=
|
MILVUS_EMBEDDING_MODEL=
|
||||||
MILVUS_EMBEDDING_API_KEY=
|
MILVUS_EMBEDDING_API_KEY=
|
||||||
```
|
```
|
||||||
@@ -9,11 +9,18 @@ from src.prompts import apply_prompt_template
|
|||||||
|
|
||||||
|
|
||||||
# Create agents using configured LLM types
|
# Create agents using configured LLM types
|
||||||
def create_agent(agent_name: str, agent_type: str, tools: list, prompt_template: str):
|
def create_agent(
|
||||||
|
agent_name: str,
|
||||||
|
agent_type: str,
|
||||||
|
tools: list,
|
||||||
|
prompt_template: str,
|
||||||
|
pre_model_hook: callable = None,
|
||||||
|
):
|
||||||
"""Factory function to create agents with consistent configuration."""
|
"""Factory function to create agents with consistent configuration."""
|
||||||
return create_react_agent(
|
return create_react_agent(
|
||||||
name=agent_name,
|
name=agent_name,
|
||||||
model=get_llm_by_type(AGENT_LLM_MAP[agent_type]),
|
model=get_llm_by_type(AGENT_LLM_MAP[agent_type]),
|
||||||
tools=tools,
|
tools=tools,
|
||||||
prompt=lambda state: apply_prompt_template(prompt_template, state),
|
prompt=lambda state: apply_prompt_template(prompt_template, state),
|
||||||
|
pre_model_hook=pre_model_hook,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,16 +6,17 @@ import logging
|
|||||||
import os
|
import os
|
||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||||
from langgraph.types import Command, interrupt
|
from langgraph.types import Command, interrupt
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from src.agents import create_agent
|
from src.agents import create_agent
|
||||||
from src.config.agents import AGENT_LLM_MAP
|
from src.config.agents import AGENT_LLM_MAP
|
||||||
from src.config.configuration import Configuration
|
from src.config.configuration import Configuration
|
||||||
from src.llms.llm import get_llm_by_type
|
from src.llms.llm import get_llm_by_type, get_llm_token_limit_by_type
|
||||||
from src.prompts.planner_model import Plan
|
from src.prompts.planner_model import Plan
|
||||||
from src.prompts.template import apply_prompt_template
|
from src.prompts.template import apply_prompt_template
|
||||||
from src.tools import (
|
from src.tools import (
|
||||||
@@ -26,6 +27,7 @@ from src.tools import (
|
|||||||
)
|
)
|
||||||
from src.tools.search import LoggedTavilySearch
|
from src.tools.search import LoggedTavilySearch
|
||||||
from src.utils.json_utils import repair_json_output
|
from src.utils.json_utils import repair_json_output
|
||||||
|
from src.utils.context_manager import ContextManager
|
||||||
|
|
||||||
from ..config import SELECTED_SEARCH_ENGINE, SearchEngine
|
from ..config import SELECTED_SEARCH_ENGINE, SearchEngine
|
||||||
from .types import State
|
from .types import State
|
||||||
@@ -283,13 +285,22 @@ def reporter_node(state: State, config: RunnableConfig):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
observation_messages = []
|
||||||
for observation in observations:
|
for observation in observations:
|
||||||
invoke_messages.append(
|
observation_messages.append(
|
||||||
HumanMessage(
|
HumanMessage(
|
||||||
content=f"Below are some observations for the research task:\n\n{observation}",
|
content=f"Below are some observations for the research task:\n\n{observation}",
|
||||||
name="observation",
|
name="observation",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Context compression
|
||||||
|
llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP["reporter"])
|
||||||
|
compressed_state = ContextManager(llm_token_limit).compress_messages(
|
||||||
|
{"messages": observation_messages}
|
||||||
|
)
|
||||||
|
invoke_messages += compressed_state.get("messages", [])
|
||||||
|
|
||||||
logger.debug(f"Current invoke messages: {invoke_messages}")
|
logger.debug(f"Current invoke messages: {invoke_messages}")
|
||||||
response = get_llm_by_type(AGENT_LLM_MAP["reporter"]).invoke(invoke_messages)
|
response = get_llm_by_type(AGENT_LLM_MAP["reporter"]).invoke(invoke_messages)
|
||||||
response_content = response.content
|
response_content = response.content
|
||||||
@@ -469,11 +480,20 @@ async def _setup_and_execute_agent_step(
|
|||||||
f"Powered by '{enabled_tools[tool.name]}'.\n{tool.description}"
|
f"Powered by '{enabled_tools[tool.name]}'.\n{tool.description}"
|
||||||
)
|
)
|
||||||
loaded_tools.append(tool)
|
loaded_tools.append(tool)
|
||||||
agent = create_agent(agent_type, agent_type, loaded_tools, agent_type)
|
|
||||||
|
llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP[agent_type])
|
||||||
|
pre_model_hook = partial(ContextManager(llm_token_limit, 3).compress_messages)
|
||||||
|
agent = create_agent(
|
||||||
|
agent_type, agent_type, loaded_tools, agent_type, pre_model_hook
|
||||||
|
)
|
||||||
return await _execute_agent_step(state, agent, agent_type)
|
return await _execute_agent_step(state, agent, agent_type)
|
||||||
else:
|
else:
|
||||||
# Use default tools if no MCP servers are configured
|
# Use default tools if no MCP servers are configured
|
||||||
agent = create_agent(agent_type, agent_type, default_tools, agent_type)
|
llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP[agent_type])
|
||||||
|
pre_model_hook = partial(ContextManager(llm_token_limit, 3).compress_messages)
|
||||||
|
agent = create_agent(
|
||||||
|
agent_type, agent_type, default_tools, agent_type, pre_model_hook
|
||||||
|
)
|
||||||
return await _execute_agent_step(state, agent, agent_type)
|
return await _execute_agent_step(state, agent, agent_type)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -67,6 +67,10 @@ def _create_llm_use_conf(llm_type: LLMType, conf: Dict[str, Any]) -> BaseChatMod
|
|||||||
# Merge configurations, with environment variables taking precedence
|
# Merge configurations, with environment variables taking precedence
|
||||||
merged_conf = {**llm_conf, **env_conf}
|
merged_conf = {**llm_conf, **env_conf}
|
||||||
|
|
||||||
|
# Remove unnecessary parameters when initializing the client
|
||||||
|
if "token_limit" in merged_conf:
|
||||||
|
merged_conf.pop("token_limit")
|
||||||
|
|
||||||
if not merged_conf:
|
if not merged_conf:
|
||||||
raise ValueError(f"No configuration found for LLM type: {llm_type}")
|
raise ValueError(f"No configuration found for LLM type: {llm_type}")
|
||||||
|
|
||||||
@@ -174,6 +178,25 @@ def get_configured_llm_models() -> dict[str, list[str]]:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_llm_token_limit_by_type(llm_type: str) -> int:
|
||||||
|
"""
|
||||||
|
Get the maximum token limit for a given LLM type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
llm_type (str): The type of LLM.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The maximum token limit for the specified LLM type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
llm_type_config_keys = _get_llm_type_config_keys()
|
||||||
|
config_key = llm_type_config_keys.get(llm_type)
|
||||||
|
|
||||||
|
conf = load_yaml_config(_get_config_file_path())
|
||||||
|
llm_max_token = conf.get(config_key, {}).get("token_limit")
|
||||||
|
return llm_max_token
|
||||||
|
|
||||||
|
|
||||||
# In the future, we will use reasoning_llm and vl_llm for different purposes
|
# In the future, we will use reasoning_llm and vl_llm for different purposes
|
||||||
# reasoning_llm = get_llm_by_type("reasoning")
|
# reasoning_llm = get_llm_by_type("reasoning")
|
||||||
# vl_llm = get_llm_by_type("vision")
|
# vl_llm = get_llm_by_type("vision")
|
||||||
|
|||||||
212
src/tools/search_postprocessor.py
Normal file
212
src/tools/search_postprocessor.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
# src/tools/search_postprocessor.py
|
||||||
|
import re
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SearchResultPostProcessor:
|
||||||
|
"""Search result post-processor"""
|
||||||
|
|
||||||
|
base64_pattern = r"data:image/[^;]+;base64,[a-zA-Z0-9+/=]+"
|
||||||
|
|
||||||
|
def __init__(self, min_score_threshold: float, max_content_length_per_page: int):
|
||||||
|
"""
|
||||||
|
Initialize the post-processor
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_score_threshold: Minimum relevance score threshold
|
||||||
|
max_content_length_per_page: Maximum content length
|
||||||
|
"""
|
||||||
|
self.min_score_threshold = min_score_threshold
|
||||||
|
self.max_content_length_per_page = max_content_length_per_page
|
||||||
|
|
||||||
|
def process_results(self, results: List[Dict]) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Process search results
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: Original search result list
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Processed result list
|
||||||
|
"""
|
||||||
|
if not results:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Combined processing in a single loop for efficiency
|
||||||
|
cleaned_results = []
|
||||||
|
seen_urls = set()
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
# 1. Remove duplicates
|
||||||
|
cleaned_result = self._remove_duplicates(result, seen_urls)
|
||||||
|
if not cleaned_result:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 2. Filter low quality results
|
||||||
|
if (
|
||||||
|
"page" == cleaned_result.get("type")
|
||||||
|
and self.min_score_threshold
|
||||||
|
and self.min_score_threshold > 0
|
||||||
|
and cleaned_result.get("score", 0) < self.min_score_threshold
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 3. Clean base64 images from content
|
||||||
|
cleaned_result = self._remove_base64_images(cleaned_result)
|
||||||
|
if not cleaned_result:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 4. When max_content_length_per_page is set, truncate long content
|
||||||
|
if (
|
||||||
|
self.max_content_length_per_page
|
||||||
|
and self.max_content_length_per_page > 0
|
||||||
|
):
|
||||||
|
cleaned_result = self._truncate_long_content(cleaned_result)
|
||||||
|
|
||||||
|
if cleaned_result:
|
||||||
|
cleaned_results.append(cleaned_result)
|
||||||
|
|
||||||
|
# 5. Sort (by score descending)
|
||||||
|
sorted_results = sorted(
|
||||||
|
cleaned_results, key=lambda x: x.get("score", 0), reverse=True
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Search result post-processing: {len(results)} -> {len(sorted_results)}"
|
||||||
|
)
|
||||||
|
return sorted_results
|
||||||
|
|
||||||
|
def _remove_base64_images(self, result: Dict) -> Dict:
|
||||||
|
"""Remove base64 encoded images from content"""
|
||||||
|
|
||||||
|
if "page" == result.get("type"):
|
||||||
|
cleaned_result = self.processPage(result)
|
||||||
|
elif "image" == result.get("type"):
|
||||||
|
cleaned_result = self.processImage(result)
|
||||||
|
else:
|
||||||
|
# For other types, keep as is
|
||||||
|
cleaned_result = result.copy()
|
||||||
|
|
||||||
|
return cleaned_result
|
||||||
|
|
||||||
|
def processPage(self, result: Dict) -> Dict:
|
||||||
|
"""Process page type result"""
|
||||||
|
# Clean base64 images from content
|
||||||
|
cleaned_result = result.copy()
|
||||||
|
|
||||||
|
if "content" in result:
|
||||||
|
original_content = result["content"]
|
||||||
|
cleaned_content = re.sub(self.base64_pattern, " ", original_content)
|
||||||
|
cleaned_result["content"] = cleaned_content
|
||||||
|
|
||||||
|
# Log if significant content was removed
|
||||||
|
if len(cleaned_content) < len(original_content) * 0.8:
|
||||||
|
logger.debug(
|
||||||
|
f"Removed base64 images from search content: {result.get('url', 'unknown')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clean base64 images from raw content
|
||||||
|
if "raw_content" in cleaned_result:
|
||||||
|
original_raw_content = cleaned_result["raw_content"]
|
||||||
|
cleaned_raw_content = re.sub(self.base64_pattern, " ", original_raw_content)
|
||||||
|
cleaned_result["raw_content"] = cleaned_raw_content
|
||||||
|
|
||||||
|
# Log if significant content was removed
|
||||||
|
if len(cleaned_raw_content) < len(original_raw_content) * 0.8:
|
||||||
|
logger.debug(
|
||||||
|
f"Removed base64 images from search raw content: {result.get('url', 'unknown')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return cleaned_result
|
||||||
|
|
||||||
|
def processImage(self, result: Dict) -> Dict:
|
||||||
|
"""Process image type result - clean up base64 data and long fields"""
|
||||||
|
cleaned_result = result.copy()
|
||||||
|
|
||||||
|
# Remove base64 encoded data from image_url if present
|
||||||
|
if "image_url" in cleaned_result and isinstance(
|
||||||
|
cleaned_result["image_url"], str
|
||||||
|
):
|
||||||
|
# Check if image_url contains base64 data
|
||||||
|
if "data:image" in cleaned_result["image_url"]:
|
||||||
|
original_image_url = cleaned_result["image_url"]
|
||||||
|
cleaned_image_url = re.sub(self.base64_pattern, " ", original_image_url)
|
||||||
|
if len(cleaned_image_url) == 0 or not cleaned_image_url.startswith(
|
||||||
|
"http"
|
||||||
|
):
|
||||||
|
logger.debug(
|
||||||
|
f"Removed base64 data from image_url and the cleaned_image_url is empty or not start with http, origin image_url: {result.get('image_url', 'unknown')}"
|
||||||
|
)
|
||||||
|
return {}
|
||||||
|
cleaned_result["image_url"] = cleaned_image_url
|
||||||
|
logger.debug(
|
||||||
|
f"Removed base64 data from image_url: {result.get('image_url', 'unknown')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Truncate very long image descriptions
|
||||||
|
if "image_description" in cleaned_result and isinstance(
|
||||||
|
cleaned_result["image_description"], str
|
||||||
|
):
|
||||||
|
if (
|
||||||
|
self.max_content_length_per_page
|
||||||
|
and len(cleaned_result["image_description"])
|
||||||
|
> self.max_content_length_per_page
|
||||||
|
):
|
||||||
|
cleaned_result["image_description"] = (
|
||||||
|
cleaned_result["image_description"][
|
||||||
|
: self.max_content_length_per_page
|
||||||
|
]
|
||||||
|
+ "..."
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Truncated long image description from search result: {result.get('image_url', 'unknown')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return cleaned_result
|
||||||
|
|
||||||
|
def _truncate_long_content(self, result: Dict) -> Dict:
|
||||||
|
"""Truncate long content"""
|
||||||
|
|
||||||
|
truncated_result = result.copy()
|
||||||
|
|
||||||
|
# Truncate content length
|
||||||
|
if "content" in truncated_result:
|
||||||
|
content = truncated_result["content"]
|
||||||
|
if len(content) > self.max_content_length_per_page:
|
||||||
|
truncated_result["content"] = (
|
||||||
|
content[: self.max_content_length_per_page] + "..."
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Truncated long content from search result: {result.get('url', 'unknown')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Truncate raw content length (can be slightly longer)
|
||||||
|
if "raw_content" in truncated_result:
|
||||||
|
raw_content = truncated_result["raw_content"]
|
||||||
|
if len(raw_content) > self.max_content_length_per_page * 2:
|
||||||
|
truncated_result["raw_content"] = (
|
||||||
|
raw_content[: self.max_content_length_per_page * 2] + "..."
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Truncated long raw content from search result: {result.get('url', 'unknown')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return truncated_result
|
||||||
|
|
||||||
|
def _remove_duplicates(self, result: Dict, seen_urls: set) -> Dict:
|
||||||
|
"""Remove duplicate results"""
|
||||||
|
|
||||||
|
url = result.get("url", result.get("image_url", ""))
|
||||||
|
if url and url not in seen_urls:
|
||||||
|
seen_urls.add(url)
|
||||||
|
return result.copy() # Return a copy to avoid modifying original
|
||||||
|
elif not url:
|
||||||
|
# Keep results with empty URLs
|
||||||
|
return result.copy() # Return a copy to avoid modifying original
|
||||||
|
|
||||||
|
return {} # Return empty dict for duplicates
|
||||||
@@ -11,6 +11,14 @@ from langchain_tavily._utilities import TAVILY_API_URL
|
|||||||
from langchain_tavily.tavily_search import (
|
from langchain_tavily.tavily_search import (
|
||||||
TavilySearchAPIWrapper as OriginalTavilySearchAPIWrapper,
|
TavilySearchAPIWrapper as OriginalTavilySearchAPIWrapper,
|
||||||
)
|
)
|
||||||
|
from src.tools.search_postprocessor import SearchResultPostProcessor
|
||||||
|
from src.config import load_yaml_config
|
||||||
|
|
||||||
|
|
||||||
|
def get_search_config():
|
||||||
|
config = load_yaml_config("conf.yaml")
|
||||||
|
search_config = config.get("SEARCH_ENGINE", {})
|
||||||
|
return search_config
|
||||||
|
|
||||||
|
|
||||||
class EnhancedTavilySearchAPIWrapper(OriginalTavilySearchAPIWrapper):
|
class EnhancedTavilySearchAPIWrapper(OriginalTavilySearchAPIWrapper):
|
||||||
@@ -110,4 +118,13 @@ class EnhancedTavilySearchAPIWrapper(OriginalTavilySearchAPIWrapper):
|
|||||||
"image_description": image["description"],
|
"image_description": image["description"],
|
||||||
}
|
}
|
||||||
clean_results.append(clean_result)
|
clean_results.append(clean_result)
|
||||||
|
|
||||||
|
search_config = get_search_config()
|
||||||
|
clean_results = SearchResultPostProcessor(
|
||||||
|
min_score_threshold=search_config.get("min_score_threshold"),
|
||||||
|
max_content_length_per_page=search_config.get(
|
||||||
|
"max_content_length_per_page"
|
||||||
|
),
|
||||||
|
).process_results(clean_results)
|
||||||
|
|
||||||
return clean_results
|
return clean_results
|
||||||
|
|||||||
265
src/utils/context_manager.py
Normal file
265
src/utils/context_manager.py
Normal file
@@ -0,0 +1,265 @@
|
|||||||
|
# src/utils/token_manager.py
|
||||||
|
from typing import List
|
||||||
|
from langchain_core.messages import (
|
||||||
|
BaseMessage,
|
||||||
|
HumanMessage,
|
||||||
|
AIMessage,
|
||||||
|
ToolMessage,
|
||||||
|
SystemMessage,
|
||||||
|
)
|
||||||
|
import logging
|
||||||
|
import copy
|
||||||
|
|
||||||
|
from src.config import load_yaml_config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_search_config():
|
||||||
|
config = load_yaml_config("conf.yaml")
|
||||||
|
search_config = config.get("MODEL_TOKEN_LIMITS", {})
|
||||||
|
return search_config
|
||||||
|
|
||||||
|
|
||||||
|
class ContextManager:
|
||||||
|
"""Context manager and compression class"""
|
||||||
|
|
||||||
|
def __init__(self, token_limit: int, preserve_prefix_message_count: int = 0):
|
||||||
|
"""
|
||||||
|
Initialize ContextManager
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_limit: Maximum token limit
|
||||||
|
preserve_prefix_message_count: Number of messages to preserve at the beginning of the context
|
||||||
|
"""
|
||||||
|
self.token_limit = token_limit
|
||||||
|
self.preserve_prefix_message_count = preserve_prefix_message_count
|
||||||
|
|
||||||
|
def count_tokens(self, messages: List[BaseMessage]) -> int:
|
||||||
|
"""
|
||||||
|
Count tokens in message list
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of tokens
|
||||||
|
"""
|
||||||
|
total_tokens = 0
|
||||||
|
for message in messages:
|
||||||
|
total_tokens += self._count_message_tokens(message)
|
||||||
|
return total_tokens
|
||||||
|
|
||||||
|
def _count_message_tokens(self, message: BaseMessage) -> int:
|
||||||
|
"""
|
||||||
|
Count tokens in a single message
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Message object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of tokens
|
||||||
|
"""
|
||||||
|
# Estimate token count based on character length (different calculation for English and non-English)
|
||||||
|
token_count = 0
|
||||||
|
|
||||||
|
# Count tokens in content field
|
||||||
|
if hasattr(message, "content") and message.content:
|
||||||
|
# Handle different content types
|
||||||
|
if isinstance(message.content, str):
|
||||||
|
token_count += self._count_text_tokens(message.content)
|
||||||
|
|
||||||
|
# Count role-related tokens
|
||||||
|
if hasattr(message, "type"):
|
||||||
|
token_count += self._count_text_tokens(message.type)
|
||||||
|
|
||||||
|
# Special handling for different message types
|
||||||
|
if isinstance(message, SystemMessage):
|
||||||
|
# System messages are usually short but important, slightly increase estimate
|
||||||
|
token_count = int(token_count * 1.1)
|
||||||
|
elif isinstance(message, HumanMessage):
|
||||||
|
# Human messages use normal estimation
|
||||||
|
pass
|
||||||
|
elif isinstance(message, AIMessage):
|
||||||
|
# AI messages may contain reasoning content, slightly increase estimate
|
||||||
|
token_count = int(token_count * 1.2)
|
||||||
|
elif isinstance(message, ToolMessage):
|
||||||
|
# Tool messages may contain large amounts of structured data, increase estimate
|
||||||
|
token_count = int(token_count * 1.3)
|
||||||
|
|
||||||
|
# Process additional information in additional_kwargs
|
||||||
|
if hasattr(message, "additional_kwargs") and message.additional_kwargs:
|
||||||
|
# Simple estimation of extra field tokens
|
||||||
|
extra_str = str(message.additional_kwargs)
|
||||||
|
token_count += self._count_text_tokens(extra_str)
|
||||||
|
|
||||||
|
# If there are tool_calls, add estimation
|
||||||
|
if "tool_calls" in message.additional_kwargs:
|
||||||
|
token_count += 50 # Add estimation for function call information
|
||||||
|
|
||||||
|
# Ensure at least 1 token
|
||||||
|
return max(1, token_count)
|
||||||
|
|
||||||
|
def _count_text_tokens(self, text: str) -> int:
|
||||||
|
"""
|
||||||
|
Count tokens in text with different calculations for English and non-English characters.
|
||||||
|
English characters: 4 characters ≈ 1 token
|
||||||
|
Non-English characters (e.g., Chinese): 1 character ≈ 1 token
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to count tokens for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of tokens
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
english_chars = 0
|
||||||
|
non_english_chars = 0
|
||||||
|
|
||||||
|
for char in text:
|
||||||
|
# Check if character is ASCII (English letters, digits, punctuation)
|
||||||
|
if ord(char) < 128:
|
||||||
|
english_chars += 1
|
||||||
|
else:
|
||||||
|
non_english_chars += 1
|
||||||
|
|
||||||
|
# Calculate tokens: English at 4 chars/token, others at 1 char/token
|
||||||
|
english_tokens = english_chars // 4
|
||||||
|
non_english_tokens = non_english_chars
|
||||||
|
|
||||||
|
return english_tokens + non_english_tokens
|
||||||
|
|
||||||
|
def is_over_limit(self, messages: List[BaseMessage]) -> bool:
|
||||||
|
"""
|
||||||
|
Check if messages exceed token limit
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Whether limit is exceeded
|
||||||
|
"""
|
||||||
|
return self.count_tokens(messages) > self.token_limit
|
||||||
|
|
||||||
|
def compress_messages(self, state: dict) -> List[BaseMessage]:
|
||||||
|
"""
|
||||||
|
Compress messages to fit within token limit
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: state with original messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compressed state with compressed messages
|
||||||
|
"""
|
||||||
|
# If not set token_limit, return original state
|
||||||
|
if self.token_limit is None:
|
||||||
|
logger.info("No token_limit set, the context management doesn't work.")
|
||||||
|
return state
|
||||||
|
|
||||||
|
if not isinstance(state, dict) or "messages" not in state:
|
||||||
|
logger.warning("No messages found in state")
|
||||||
|
return state
|
||||||
|
|
||||||
|
messages = state["messages"]
|
||||||
|
|
||||||
|
if not self.is_over_limit(messages):
|
||||||
|
return state
|
||||||
|
|
||||||
|
# 2. Compress messages
|
||||||
|
compressed_messages = self._compress_messages(messages)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Message compression completed: {self.count_tokens(messages)} -> {self.count_tokens(compressed_messages)} tokens"
|
||||||
|
)
|
||||||
|
|
||||||
|
state["messages"] = compressed_messages
|
||||||
|
return state
|
||||||
|
|
||||||
|
def _compress_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
|
||||||
|
"""
|
||||||
|
Compress compressible messages
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of messages to compress
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compressed message list
|
||||||
|
"""
|
||||||
|
|
||||||
|
available_token = self.token_limit
|
||||||
|
prefix_messages = []
|
||||||
|
|
||||||
|
# 1. Preserve head messages of specified length to retain system prompts and user input
|
||||||
|
for i in range(min(self.preserve_prefix_message_count, len(messages))):
|
||||||
|
cur_token_cnt = self._count_message_tokens(messages[i])
|
||||||
|
if available_token > 0 and available_token >= cur_token_cnt:
|
||||||
|
prefix_messages.append(messages[i])
|
||||||
|
available_token -= cur_token_cnt
|
||||||
|
elif available_token > 0:
|
||||||
|
# Truncate content to fit available tokens
|
||||||
|
truncated_message = self._truncate_message_content(
|
||||||
|
messages[i], available_token
|
||||||
|
)
|
||||||
|
prefix_messages.append(truncated_message)
|
||||||
|
return prefix_messages
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 2. Compress subsequent messages from the tail, some messages may be discarded
|
||||||
|
messages = messages[len(prefix_messages) :]
|
||||||
|
suffix_messages = []
|
||||||
|
for i in range(len(messages) - 1, -1, -1):
|
||||||
|
cur_token_cnt = self._count_message_tokens(messages[i])
|
||||||
|
|
||||||
|
if cur_token_cnt > 0 and available_token >= cur_token_cnt:
|
||||||
|
suffix_messages = [messages[i]] + suffix_messages
|
||||||
|
available_token -= cur_token_cnt
|
||||||
|
elif available_token > 0:
|
||||||
|
# Truncate content to fit available tokens
|
||||||
|
truncated_message = self._truncate_message_content(
|
||||||
|
messages[i], available_token
|
||||||
|
)
|
||||||
|
suffix_messages = [truncated_message] + suffix_messages
|
||||||
|
return prefix_messages + suffix_messages
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
return prefix_messages + suffix_messages
|
||||||
|
|
||||||
|
def _truncate_message_content(
|
||||||
|
self, message: BaseMessage, max_tokens: int
|
||||||
|
) -> BaseMessage:
|
||||||
|
"""
|
||||||
|
Truncate message content while preserving all other attributes by copying the original message
|
||||||
|
and only modifying its content attribute.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: The message to truncate
|
||||||
|
max_tokens: Maximum number of tokens to keep
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New message instance with truncated content
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create a deep copy of the original message to preserve all attributes
|
||||||
|
truncated_message = copy.deepcopy(message)
|
||||||
|
|
||||||
|
# Truncate only the content attribute
|
||||||
|
truncated_message.content = message.content[:max_tokens]
|
||||||
|
|
||||||
|
return truncated_message
|
||||||
|
|
||||||
|
def _create_summary_message(self, messages: List[BaseMessage]) -> BaseMessage:
|
||||||
|
"""
|
||||||
|
Create summary for messages
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: Messages to summarize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Summary message
|
||||||
|
"""
|
||||||
|
# TODO: summary implementation
|
||||||
|
pass
|
||||||
262
tests/unit/tools/test_search_postprocessor.py
Normal file
262
tests/unit/tools/test_search_postprocessor.py
Normal file
@@ -0,0 +1,262 @@
|
|||||||
|
import pytest
|
||||||
|
from src.tools.search_postprocessor import SearchResultPostProcessor
|
||||||
|
|
||||||
|
|
||||||
|
class TestSearchResultPostProcessor:
    """Test cases for SearchResultPostProcessor.

    Covers score filtering, URL de-duplication, score-descending ordering,
    content truncation, base64-image stripping, and pass-through of result
    types the post-processor does not handle.
    """

    @pytest.fixture
    def post_processor(self):
        """Create a SearchResultPostProcessor instance for testing"""
        return SearchResultPostProcessor(
            min_score_threshold=0.5, max_content_length_per_page=100
        )

    def test_process_results_empty_input(self, post_processor):
        """Test processing empty results"""
        results = []
        processed = post_processor.process_results(results)
        assert processed == []

    def test_process_results_with_valid_page_results(self, post_processor):
        """Test processing valid page results"""
        results = [
            {
                "type": "page",
                "title": "Test Page",
                "url": "https://example.com",
                "content": "Test content",
                "score": 0.8,
            }
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 1
        assert processed[0]["title"] == "Test Page"
        assert processed[0]["url"] == "https://example.com"
        assert processed[0]["content"] == "Test content"
        assert processed[0]["score"] == 0.8

    def test_process_results_filter_low_score(self, post_processor):
        """Test filtering out low score results"""
        results = [
            {
                "type": "page",
                "title": "Low Score Page",
                "url": "https://example.com/low",
                "content": "Low score content",
                "score": 0.3,  # Below threshold of 0.5
            },
            {
                "type": "page",
                "title": "High Score Page",
                "url": "https://example.com/high",
                "content": "High score content",
                "score": 0.9,
            },
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 1
        assert processed[0]["title"] == "High Score Page"

    def test_process_results_remove_duplicates(self, post_processor):
        """Test removing duplicate URLs"""
        results = [
            {
                "type": "page",
                "title": "Page 1",
                "url": "https://example.com",
                "content": "Content 1",
                "score": 0.8,
            },
            {
                "type": "page",
                "title": "Page 2",
                "url": "https://example.com",  # Duplicate URL
                "content": "Content 2",
                "score": 0.7,
            },
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 1
        assert processed[0]["title"] == "Page 1"  # First one should be kept

    def test_process_results_sort_by_score(self, post_processor):
        """Test sorting results by score in descending order"""
        results = [
            {
                "type": "page",
                "title": "Low Score",
                "url": "https://example.com/low",
                "content": "Low score content",
                "score": 0.3,
            },
            {
                "type": "page",
                "title": "High Score",
                "url": "https://example.com/high",
                "content": "High score content",
                "score": 0.9,
            },
            {
                "type": "page",
                "title": "Medium Score",
                "url": "https://example.com/medium",
                "content": "Medium score content",
                "score": 0.6,
            },
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 2  # Low score filtered out
        # Should be sorted by score descending
        assert processed[0]["title"] == "High Score"
        assert processed[1]["title"] == "Medium Score"

    def test_process_results_truncate_long_content(self, post_processor):
        """Test truncating long content"""
        long_content = "A" * 150  # Longer than max_content_length of 100
        results = [
            {
                "type": "page",
                "title": "Long Content Page",
                "url": "https://example.com",
                "content": long_content,
                "score": 0.8,
            }
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 1
        assert len(processed[0]["content"]) == 103  # 100 + "..."
        assert processed[0]["content"].endswith("...")

    def test_process_results_remove_base64_images(self, post_processor):
        """Test removing base64 images from content"""
        content_with_base64 = (
            "Content with image "
            + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
        )
        results = [
            {
                "type": "page",
                "title": "Page with Base64",
                "url": "https://example.com",
                "content": content_with_base64,
                "score": 0.8,
            }
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 1
        # The base64 payload is stripped, leaving only the surrounding text
        assert processed[0]["content"] == "Content with image "

    def test_process_results_with_image_type(self, post_processor):
        """Test processing image type results"""
        results = [
            {
                "type": "image",
                "image_url": "https://example.com/image.jpg",
                "image_description": "Test image",
            }
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 1
        assert processed[0]["type"] == "image"
        assert processed[0]["image_url"] == "https://example.com/image.jpg"
        assert processed[0]["image_description"] == "Test image"

    def test_process_results_filter_base64_image_urls(self, post_processor):
        """Test filtering out image results with base64 URLs"""
        results = [
            {
                "type": "image",
                "image_url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==",
                "image_description": "Base64 image",
            },
            {
                "type": "image",
                "image_url": "https://example.com/image.jpg",
                "image_description": "Regular image",
            },
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 1
        assert processed[0]["image_url"] == "https://example.com/image.jpg"

    def test_process_results_truncate_long_image_description(self, post_processor):
        """Test truncating long image descriptions"""
        long_description = "A" * 150  # Longer than max_content_length of 100
        results = [
            {
                "type": "image",
                "image_url": "https://example.com/image.jpg",
                "image_description": long_description,
            }
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 1
        assert len(processed[0]["image_description"]) == 103  # 100 + "..."
        assert processed[0]["image_description"].endswith("...")

    def test_process_results_other_types_passthrough(self, post_processor):
        """Test that other result types pass through unchanged"""
        results = [
            {
                "type": "video",
                "title": "Test Video",
                "url": "https://example.com/video.mp4",
                "score": 0.8,
            }
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 1
        assert processed[0]["type"] == "video"
        assert processed[0]["title"] == "Test Video"

    def test_process_results_truncate_long_content_with_no_config(self):
        """No thresholds configured: content passes through untruncated"""
        post_processor = SearchResultPostProcessor(None, None)
        long_content = "A" * 150  # Longer than max_content_length of 100
        results = [
            {
                "type": "page",
                "title": "Long Content Page",
                "url": "https://example.com",
                "content": long_content,
                "score": 0.8,
            }
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 1
        assert len(processed[0]["content"]) == len("A" * 150)

    def test_process_results_truncate_long_content_with_max_content_length_config(self):
        """Only max_content_length configured: content is truncated, no score filtering"""
        post_processor = SearchResultPostProcessor(None, 100)
        long_content = "A" * 150  # Longer than max_content_length of 100
        results = [
            {
                "type": "page",
                "title": "Long Content Page",
                "url": "https://example.com",
                "content": long_content,
                "score": 0.8,
            }
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 1
        assert len(processed[0]["content"]) == 103
        assert processed[0]["content"].endswith("...")

    def test_process_results_truncate_long_content_with_min_score_config(self):
        """Only min_score configured: low-score pages are filtered out entirely"""
        post_processor = SearchResultPostProcessor(0.8, None)
        long_content = "A" * 150  # Longer than max_content_length of 100
        results = [
            {
                "type": "page",
                "title": "Long Content Page",
                "url": "https://example.com",
                "content": long_content,
                "score": 0.3,
            }
        ]
        processed = post_processor.process_results(results)
        assert len(processed) == 0
|
||||||
183
tests/unit/utils/test_context_manager.py
Normal file
183
tests/unit/utils/test_context_manager.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
import pytest
|
||||||
|
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
|
||||||
|
from src.utils.context_manager import ContextManager
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextManager:
    """Test cases for ContextManager.

    Covers heuristic token counting (per message type, English vs Chinese
    text), the token-limit check, and message compression behavior with and
    without a configured ``token_limit``.
    """

    def test_count_tokens_with_empty_messages(self):
        """Test counting tokens with empty message list"""
        context_manager = ContextManager(token_limit=1000)
        messages = []
        token_count = context_manager.count_tokens(messages)
        assert token_count == 0

    def test_count_tokens_with_system_message(self):
        """Test counting tokens with system message"""
        context_manager = ContextManager(token_limit=1000)
        messages = [SystemMessage(content="You are a helpful assistant.")]
        token_count = context_manager.count_tokens(messages)
        # System message has 28 characters, should be around 8 tokens (28/4 * 1.1)
        assert token_count > 7

    def test_count_tokens_with_human_message(self):
        """Test counting tokens with human message"""
        context_manager = ContextManager(token_limit=1000)
        messages = [HumanMessage(content="你好,这是一个测试消息。")]
        token_count = context_manager.count_tokens(messages)
        # Chinese characters count roughly 1:1, so 12 CJK chars > 12 tokens
        assert token_count > 12

    def test_count_tokens_with_ai_message(self):
        """Test counting tokens with AI message"""
        context_manager = ContextManager(token_limit=1000)
        messages = [AIMessage(content="I'm doing well, thank you for asking!")]
        token_count = context_manager.count_tokens(messages)
        assert token_count >= 10

    def test_count_tokens_with_tool_message(self):
        """Test counting tokens with tool message"""
        context_manager = ContextManager(token_limit=1000)
        messages = [
            ToolMessage(content="Tool execution result data here", tool_call_id="test")
        ]
        token_count = context_manager.count_tokens(messages)
        # Tool message has about 32 characters, should be around 10 tokens (32/4 * 1.3)
        assert token_count > 0

    def test_count_tokens_with_multiple_messages(self):
        """Test counting tokens with multiple messages"""
        context_manager = ContextManager(token_limit=1000)
        messages = [
            SystemMessage(content="You are a helpful assistant."),
            HumanMessage(content="Hello, how are you?"),
            AIMessage(content="I'm doing well, thank you for asking!"),
        ]
        token_count = context_manager.count_tokens(messages)
        # Should be sum of all individual message tokens
        assert token_count > 0

    def test_is_over_limit_when_under_limit(self):
        """Test is_over_limit when messages are under token limit"""
        context_manager = ContextManager(token_limit=1000)
        short_messages = [HumanMessage(content="Short message")]
        is_over = context_manager.is_over_limit(short_messages)
        assert is_over is False

    def test_is_over_limit_when_over_limit(self):
        """Test is_over_limit when messages exceed token limit"""
        # Create a context manager with a very low limit
        low_limit_cm = ContextManager(token_limit=5)
        long_messages = [
            HumanMessage(
                content="This is a very long message that should exceed the limit"
            )
        ]
        is_over = low_limit_cm.is_over_limit(long_messages)
        assert is_over is True

    def test_compress_messages_when_not_over_limit(self):
        """Test compress_messages when messages are not over limit"""
        context_manager = ContextManager(token_limit=1000)
        messages = [HumanMessage(content="Short message")]
        compressed = context_manager.compress_messages({"messages": messages})
        # Should return the same messages when not over limit
        assert len(compressed["messages"]) == len(messages)

    def test_compress_messages_with_system_message(self):
        """Test compress_messages preserves system message"""
        # Create a context manager with limited token capacity
        limited_cm = ContextManager(token_limit=200)

        messages = [
            SystemMessage(content="You are a helpful assistant."),
            HumanMessage(content="Hello"),
            AIMessage(content="Hi there!"),
            HumanMessage(
                content="Can you tell me a very long story that would exceed token limits? "
                * 100
            ),
        ]

        compressed = limited_cm.compress_messages({"messages": messages})
        # Only a single message survives within the 200-token budget
        assert len(compressed["messages"]) == 1

    def test_compress_messages_with_preserve_prefix_message(self):
        """Test compress_messages keeps the configured number of prefix messages"""
        # Create a context manager with limited token capacity
        limited_cm = ContextManager(token_limit=100, preserve_prefix_message_count=2)

        messages = [
            HumanMessage(content="Hello"),
            AIMessage(content="Hi there!"),
            HumanMessage(
                content="Can you tell me a very long story that would exceed token limits? "
                * 10
            ),
        ]

        compressed = limited_cm.compress_messages({"messages": messages})
        # The 2 prefix messages plus the suffix survive compression
        assert len(compressed["messages"]) == 3

    def test_compress_messages_without_config(self):
        """Test compress_messages returns messages unchanged when token_limit is unset"""
        # No token_limit configured: compression is disabled
        limited_cm = ContextManager(None)

        messages = [
            SystemMessage(content="You are a helpful assistant."),
            HumanMessage(content="Hello"),
            AIMessage(content="Hi there!"),
            HumanMessage(
                content="Can you tell me a very long story that would exceed token limits? "
                * 100
            ),
        ]

        compressed = limited_cm.compress_messages({"messages": messages})
        # return the original messages
        assert len(compressed["messages"]) == 4

    def test_count_message_tokens_with_additional_kwargs(self):
        """Test counting tokens for messages with additional kwargs"""
        context_manager = ContextManager(token_limit=1000)
        message = ToolMessage(
            content="Tool result",
            tool_call_id="test",
            additional_kwargs={"tool_calls": [{"name": "test_function"}]},
        )
        token_count = context_manager._count_message_tokens(message)
        assert token_count > 0

    def test_count_message_tokens_minimum_one_token(self):
        """Test that message token count is at least 1"""
        context_manager = ContextManager(token_limit=1000)
        message = HumanMessage(content="")  # Empty content
        token_count = context_manager._count_message_tokens(message)
        assert token_count == 1  # Should be at least 1

    def test_count_text_tokens_english_only(self):
        """Test counting tokens for English text"""
        context_manager = ContextManager(token_limit=1000)
        # 16 English characters should result in 4 tokens (16/4)
        text = "This is a test."
        token_count = context_manager._count_text_tokens(text)
        assert token_count > 0

    def test_count_text_tokens_chinese_only(self):
        """Test counting tokens for Chinese text"""
        context_manager = ContextManager(token_limit=1000)
        # 8 Chinese characters should result in 8 tokens (1:1 ratio)
        text = "这是一个测试文本"
        token_count = context_manager._count_text_tokens(text)
        assert token_count == 8

    def test_count_text_tokens_mixed_content(self):
        """Test counting tokens for mixed English and Chinese text"""
        context_manager = ContextManager(token_limit=1000)
        text = "Hello world 这是一些中文"
        token_count = context_manager._count_text_tokens(text)
        assert token_count > 6
|
||||||
Reference in New Issue
Block a user