mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
fix: Add runtime parameter to compress_messages method(#803)
* fix: Add runtime parameter to compress_messages method(#803) The compress_messages method was being called by PreModelHookMiddleware with both state and runtime parameters, but only accepted state parameter. This caused a TypeError when the middleware executed the pre_model_hook. Added optional runtime parameter to compress_messages signature to match the expected interface while maintaining backward compatibility. * Update the code with the review comments
This commit is contained in:
@@ -3,6 +3,8 @@ import copy
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
@@ -144,12 +146,13 @@ class ContextManager:
|
||||
"""
|
||||
return self.count_tokens(messages) > self.token_limit
|
||||
|
||||
def compress_messages(self, state: dict) -> List[BaseMessage]:
|
||||
def compress_messages(self, state: dict, runtime: Runtime | None = None) -> dict:
|
||||
"""
|
||||
Compress messages to fit within token limit
|
||||
|
||||
Args:
|
||||
state: state with original messages
|
||||
runtime: Optional runtime parameter (not used but required for middleware compatibility)
|
||||
|
||||
Returns:
|
||||
Compressed state with compressed messages
|
||||
|
||||
@@ -181,3 +181,29 @@ class TestContextManager:
|
||||
text = "Hello world 这是一些中文"
|
||||
token_count = context_manager._count_text_tokens(text)
|
||||
assert token_count > 6
|
||||
|
||||
def test_compress_messages_with_runtime_when_not_over_limit(self):
|
||||
"""compress_messages accepts runtime param when under limit"""
|
||||
context_manager = ContextManager(token_limit=1000)
|
||||
messages = [HumanMessage(content="Short message"), AIMessage(content="OK")]
|
||||
compressed = context_manager.compress_messages({"messages": messages}, runtime=object())
|
||||
assert isinstance(compressed, dict)
|
||||
assert "messages" in compressed
|
||||
assert len(compressed["messages"]) == len(messages)
|
||||
|
||||
def test_compress_messages_with_runtime_when_over_limit(self):
|
||||
"""compress_messages accepts runtime param and still compresses"""
|
||||
limited_cm = ContextManager(token_limit=200)
|
||||
messages = [
|
||||
SystemMessage(content="You are a helpful assistant."),
|
||||
HumanMessage(content="Hello"),
|
||||
AIMessage(content="Hi there!"),
|
||||
HumanMessage(
|
||||
content="Can you tell me a very long story that would exceed token limits? " * 100
|
||||
),
|
||||
]
|
||||
compressed = limited_cm.compress_messages({"messages": messages}, runtime=object())
|
||||
assert isinstance(compressed, dict)
|
||||
assert "messages" in compressed
|
||||
# Should preserve only what fits; with this setup we expect heavy compression
|
||||
assert len(compressed["messages"]) == 1
|
||||
|
||||
Reference in New Issue
Block a user