feat: add context compression (#590)

* feat: add context compression

* feat: Add unit test

* feat: add unit test for context manager

* feat: add postprocessor param and format code

* feat: add configuration guide

* fix: fix the configuration_guide

* fix: fix the unit test

* fix: fix the default value

* feat: add test and log for context_manager
This commit is contained in:
Fancy-hjyp
2025-09-27 06:42:22 -07:00
committed by GitHub
parent c214999606
commit 5f4eb38fdb
9 changed files with 1032 additions and 7 deletions

View File

@@ -6,16 +6,17 @@ import logging
import os
from typing import Annotated, Literal
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.types import Command, interrupt
from functools import partial
from src.agents import create_agent
from src.config.agents import AGENT_LLM_MAP
from src.config.configuration import Configuration
from src.llms.llm import get_llm_by_type
from src.llms.llm import get_llm_by_type, get_llm_token_limit_by_type
from src.prompts.planner_model import Plan
from src.prompts.template import apply_prompt_template
from src.tools import (
@@ -26,6 +27,7 @@ from src.tools import (
)
from src.tools.search import LoggedTavilySearch
from src.utils.json_utils import repair_json_output
from src.utils.context_manager import ContextManager
from ..config import SELECTED_SEARCH_ENGINE, SearchEngine
from .types import State
@@ -283,13 +285,22 @@ def reporter_node(state: State, config: RunnableConfig):
)
)
observation_messages = []
for observation in observations:
invoke_messages.append(
observation_messages.append(
HumanMessage(
content=f"Below are some observations for the research task:\n\n{observation}",
name="observation",
)
)
# Context compression
llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP["reporter"])
compressed_state = ContextManager(llm_token_limit).compress_messages(
{"messages": observation_messages}
)
invoke_messages += compressed_state.get("messages", [])
logger.debug(f"Current invoke messages: {invoke_messages}")
response = get_llm_by_type(AGENT_LLM_MAP["reporter"]).invoke(invoke_messages)
response_content = response.content
@@ -469,11 +480,20 @@ async def _setup_and_execute_agent_step(
f"Powered by '{enabled_tools[tool.name]}'.\n{tool.description}"
)
loaded_tools.append(tool)
agent = create_agent(agent_type, agent_type, loaded_tools, agent_type)
llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP[agent_type])
pre_model_hook = partial(ContextManager(llm_token_limit, 3).compress_messages)
agent = create_agent(
agent_type, agent_type, loaded_tools, agent_type, pre_model_hook
)
return await _execute_agent_step(state, agent, agent_type)
else:
# Use default tools if no MCP servers are configured
agent = create_agent(agent_type, agent_type, default_tools, agent_type)
llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP[agent_type])
pre_model_hook = partial(ContextManager(llm_token_limit, 3).compress_messages)
agent = create_agent(
agent_type, agent_type, default_tools, agent_type, pre_model_hook
)
return await _execute_agent_step(state, agent, agent_type)