mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-14 02:34:46 +08:00
feat: add context compress (#590)
* feat:Add context compress * feat: Add unit test * feat: add unit test for context manager * feat: add postprocessor param && code format * feat: add configuration guide * fix: fix the configuration_guide * fix: fix the unit test * fix: fix the default value * feat: add test and log for context_manager
This commit is contained in:
@@ -6,16 +6,17 @@ import logging
|
||||
import os
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.tools import tool
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from langgraph.types import Command, interrupt
|
||||
from functools import partial
|
||||
|
||||
from src.agents import create_agent
|
||||
from src.config.agents import AGENT_LLM_MAP
|
||||
from src.config.configuration import Configuration
|
||||
from src.llms.llm import get_llm_by_type
|
||||
from src.llms.llm import get_llm_by_type, get_llm_token_limit_by_type
|
||||
from src.prompts.planner_model import Plan
|
||||
from src.prompts.template import apply_prompt_template
|
||||
from src.tools import (
|
||||
@@ -26,6 +27,7 @@ from src.tools import (
|
||||
)
|
||||
from src.tools.search import LoggedTavilySearch
|
||||
from src.utils.json_utils import repair_json_output
|
||||
from src.utils.context_manager import ContextManager
|
||||
|
||||
from ..config import SELECTED_SEARCH_ENGINE, SearchEngine
|
||||
from .types import State
|
||||
@@ -283,13 +285,22 @@ def reporter_node(state: State, config: RunnableConfig):
|
||||
)
|
||||
)
|
||||
|
||||
observation_messages = []
|
||||
for observation in observations:
|
||||
invoke_messages.append(
|
||||
observation_messages.append(
|
||||
HumanMessage(
|
||||
content=f"Below are some observations for the research task:\n\n{observation}",
|
||||
name="observation",
|
||||
)
|
||||
)
|
||||
|
||||
# Context compression
|
||||
llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP["reporter"])
|
||||
compressed_state = ContextManager(llm_token_limit).compress_messages(
|
||||
{"messages": observation_messages}
|
||||
)
|
||||
invoke_messages += compressed_state.get("messages", [])
|
||||
|
||||
logger.debug(f"Current invoke messages: {invoke_messages}")
|
||||
response = get_llm_by_type(AGENT_LLM_MAP["reporter"]).invoke(invoke_messages)
|
||||
response_content = response.content
|
||||
@@ -469,11 +480,20 @@ async def _setup_and_execute_agent_step(
|
||||
f"Powered by '{enabled_tools[tool.name]}'.\n{tool.description}"
|
||||
)
|
||||
loaded_tools.append(tool)
|
||||
agent = create_agent(agent_type, agent_type, loaded_tools, agent_type)
|
||||
|
||||
llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP[agent_type])
|
||||
pre_model_hook = partial(ContextManager(llm_token_limit, 3).compress_messages)
|
||||
agent = create_agent(
|
||||
agent_type, agent_type, loaded_tools, agent_type, pre_model_hook
|
||||
)
|
||||
return await _execute_agent_step(state, agent, agent_type)
|
||||
else:
|
||||
# Use default tools if no MCP servers are configured
|
||||
agent = create_agent(agent_type, agent_type, default_tools, agent_type)
|
||||
llm_token_limit = get_llm_token_limit_by_type(AGENT_LLM_MAP[agent_type])
|
||||
pre_model_hook = partial(ContextManager(llm_token_limit, 3).compress_messages)
|
||||
agent = create_agent(
|
||||
agent_type, agent_type, default_tools, agent_type, pre_model_hook
|
||||
)
|
||||
return await _execute_agent_step(state, agent, agent_type)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user