"""Middleware for memory mechanism.""" from typing import override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware from langgraph.runtime import Runtime from src.agents.memory.queue import get_memory_queue from src.config.memory_config import get_memory_config class MemoryMiddlewareState(AgentState): """Compatible with the `ThreadState` schema.""" pass class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]): """Middleware that queues conversation for memory update after agent execution. This middleware: 1. After each agent execution, queues the conversation for memory update 2. The queue uses debouncing to batch multiple updates together 3. Memory is updated asynchronously via LLM summarization """ state_schema = MemoryMiddlewareState @override def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime) -> dict | None: """Queue conversation for memory update after agent completes. Args: state: The current agent state. runtime: The runtime context. Returns: None (no state changes needed from this middleware). """ config = get_memory_config() if not config.enabled: return None # Get thread ID from runtime context thread_id = runtime.context.get("thread_id") if not thread_id: print("MemoryMiddleware: No thread_id in context, skipping memory update") return None # Get messages from state messages = state.get("messages", []) if not messages: print("MemoryMiddleware: No messages in state, skipping memory update") return None # Only queue if there's meaningful conversation # At minimum need one user message and one assistant response user_messages = [m for m in messages if getattr(m, "type", None) == "human"] assistant_messages = [m for m in messages if getattr(m, "type", None) == "ai"] if not user_messages or not assistant_messages: return None # Queue the conversation for memory update queue = get_memory_queue() queue.add(thread_id=thread_id, messages=list(messages)) return None