mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-20 04:44:46 +08:00
fix: Add runtime parameter to compress_messages method (#803)
* fix: Add runtime parameter to compress_messages method (#803) The compress_messages method was being called by PreModelHookMiddleware with both the state and runtime parameters, but it only accepted the state parameter. This caused a TypeError when the middleware executed the pre_model_hook. Added an optional runtime parameter to the compress_messages signature to match the expected interface while maintaining backward compatibility. * Update the code to address the review comments
This commit is contained in:
@@ -3,6 +3,8 @@ import copy
|
|||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
@@ -144,12 +146,13 @@ class ContextManager:
|
|||||||
"""
|
"""
|
||||||
return self.count_tokens(messages) > self.token_limit
|
return self.count_tokens(messages) > self.token_limit
|
||||||
|
|
||||||
def compress_messages(self, state: dict) -> List[BaseMessage]:
|
def compress_messages(self, state: dict, runtime: Runtime | None = None) -> dict:
|
||||||
"""
|
"""
|
||||||
Compress messages to fit within token limit
|
Compress messages to fit within token limit
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: state with original messages
|
state: state with original messages
|
||||||
|
runtime: Optional runtime parameter (not used but required for middleware compatibility)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Compressed state with compressed messages
|
Compressed state with compressed messages
|
||||||
|
|||||||
@@ -181,3 +181,29 @@ class TestContextManager:
|
|||||||
text = "Hello world 这是一些中文"
|
text = "Hello world 这是一些中文"
|
||||||
token_count = context_manager._count_text_tokens(text)
|
token_count = context_manager._count_text_tokens(text)
|
||||||
assert token_count > 6
|
assert token_count > 6
|
||||||
|
|
||||||
|
def test_compress_messages_with_runtime_when_not_over_limit(self):
    """Check that a ``runtime`` argument is accepted for a short history.

    The result must be a dict with a ``"messages"`` entry whose length
    equals the number of input messages.
    """
    # Two short messages against a 1000-token limit.
    history = [
        HumanMessage(content="Short message"),
        AIMessage(content="OK"),
    ]
    manager = ContextManager(token_limit=1000)

    # A plain object() stands in for the runtime argument.
    result = manager.compress_messages({"messages": history}, runtime=object())

    assert isinstance(result, dict)
    assert "messages" in result
    assert len(result["messages"]) == len(history)
|
||||||
|
|
||||||
|
def test_compress_messages_with_runtime_when_over_limit(self):
    """Check that a ``runtime`` argument is accepted for an oversized history.

    The result must be a dict with a ``"messages"`` entry, and with a
    200-token limit this input is expected to shrink to a single message.
    """
    # One sentence repeated 100 times to push the history over the limit.
    long_prompt = (
        "Can you tell me a very long story that would exceed token limits? " * 100
    )
    history = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Hello"),
        AIMessage(content="Hi there!"),
        HumanMessage(content=long_prompt),
    ]
    manager = ContextManager(token_limit=200)

    # A plain object() stands in for the runtime argument.
    result = manager.compress_messages({"messages": history}, runtime=object())

    assert isinstance(result, dict)
    assert "messages" in result
    # Should preserve only what fits; with this setup we expect heavy compression
    assert len(result["messages"]) == 1
|
||||||
|
|||||||
Reference in New Issue
Block a user