Files
deer-flow/backend/src/agents/middlewares/memory_middleware.py

70 lines
2.3 KiB
Python
Raw Normal View History

"""Middleware for memory mechanism."""
import logging
from collections.abc import Mapping
from typing import Any, override

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime

from src.agents.memory.queue import get_memory_queue
from src.config.memory_config import get_memory_config
class MemoryMiddlewareState(AgentState):
    """State schema declared by ``MemoryMiddleware``.

    It adds no fields of its own on top of ``AgentState``; it exists so the
    middleware's declared state remains compatible with the ``ThreadState``
    schema.
    """
logger = logging.getLogger(__name__)


class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
    """Middleware that queues the conversation for a memory update after agent execution.

    After each agent run, if memory is enabled and the conversation contains at
    least one user message and one assistant response, the messages are handed
    to the queue returned by ``get_memory_queue``. Per the original design, that
    queue debounces/batches updates and performs the actual memory update
    asynchronously (via LLM summarization); that logic lives in the queue
    module, not here.
    """

    state_schema = MemoryMiddlewareState

    @override
    def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime) -> dict | None:
        """Queue the conversation for a memory update once the agent completes.

        The conversation is queued only when memory is enabled in the memory
        config, a ``thread_id`` can be found in the runtime context, and the
        state's messages include at least one ``"human"`` and one ``"ai"``
        message.

        Args:
            state: The current agent state; only its ``messages`` entry is read.
            runtime: The runtime whose ``context`` supplies the ``thread_id``.

        Returns:
            Always ``None``: this middleware never changes the agent state.
        """
        config = get_memory_config()
        if not config.enabled:
            return None

        thread_id = self._get_thread_id(runtime)
        if not thread_id:
            logger.debug("MemoryMiddleware: No thread_id in context, skipping memory update")
            return None

        messages = state.get("messages", [])
        if not messages:
            logger.debug("MemoryMiddleware: No messages in state, skipping memory update")
            return None

        # Only queue a meaningful exchange: at least one user message and one
        # assistant response. any() short-circuits instead of building lists.
        has_user_message = any(getattr(m, "type", None) == "human" for m in messages)
        has_assistant_message = any(getattr(m, "type", None) == "ai" for m in messages)
        if not (has_user_message and has_assistant_message):
            return None

        # Queue a shallow copy so later mutation of the state's list cannot
        # change what was handed to the queue.
        get_memory_queue().add(thread_id=thread_id, messages=list(messages))
        return None

    @staticmethod
    def _get_thread_id(runtime: Runtime) -> Any:
        """Return ``thread_id`` from ``runtime.context``, or ``None`` if unavailable.

        The context may be ``None`` (no context supplied to the run), a mapping,
        or an attribute-style object; all three are handled without raising.
        """
        context = runtime.context
        if context is None:
            return None
        if isinstance(context, Mapping):
            return context.get("thread_id")
        return getattr(context, "thread_id", None)