Files
deer-flow/backend/src/agents/memory/updater.py

320 lines
10 KiB
Python
Raw Normal View History

"""Memory updater for reading, writing, and updating memory data."""
import copy
import json
import math
import os
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from src.agents.memory.prompt import (
    MEMORY_UPDATE_PROMPT,
    format_conversation_for_update,
)
from src.config.memory_config import get_memory_config
from src.config.paths import get_paths
from src.models import create_chat_model

def _get_memory_file_path() -> Path:
    """Resolve the location of the memory JSON file.

    A configured ``storage_path`` takes precedence: an absolute path is
    returned unchanged, a relative one is joined onto ``get_paths().base_dir``.
    With no configured path, ``get_paths().memory_file`` is used.

    Returns:
        The path of the memory file.
    """
    storage_path = get_memory_config().storage_path
    if not storage_path:
        return get_paths().memory_file

    candidate = Path(storage_path)
    if candidate.is_absolute():
        return candidate
    return get_paths().base_dir / candidate
def _create_empty_memory() -> dict[str, Any]:
    """Build a fresh, empty memory document.

    Every summary section starts with empty ``summary`` and ``updatedAt``
    strings, and the fact list is empty. Each call returns new, unshared
    dictionaries.

    Returns:
        The empty memory structure. Its ``lastUpdated`` is the current UTC time
        in ISO-8601 form with a trailing ``"Z"``.
    """
    # datetime.utcnow() is deprecated since Python 3.12. Produce the same naive
    # "...Z" string from an aware UTC timestamp instead.
    now = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"

    def blank_section() -> dict[str, str]:
        # A new dict per section so that editing one never affects another.
        return {"summary": "", "updatedAt": ""}

    return {
        "version": "1.0",
        "lastUpdated": now,
        "user": {name: blank_section() for name in ("workContext", "personalContext", "topOfMind")},
        "history": {name: blank_section() for name in ("recentMonths", "earlierContext", "longTermBackground")},
        "facts": [],
    }
# Global memory data cache
_memory_data: dict[str, Any] | None = None
# Track file modification time for cache invalidation
_memory_file_mtime: float | None = None
def get_memory_data() -> dict[str, Any]:
    """Return the memory document, reusing the in-process cache when possible.

    The memory file's current modification time is compared with the one
    recorded when the cache was filled. The file is re-read when nothing is
    cached yet or when the two differ. A missing or un-stat-able file counts
    as mtime ``None``. Changes that do not alter the mtime are not detected.

    Returns:
        The cached memory data dictionary (the cache object itself, not a copy).
    """
    global _memory_data, _memory_file_mtime

    path = _get_memory_file_path()

    observed_mtime: float | None = None
    try:
        if path.exists():
            observed_mtime = path.stat().st_mtime
    except OSError:
        observed_mtime = None

    if _memory_data is not None and _memory_file_mtime == observed_mtime:
        return _memory_data

    _memory_data = _load_memory_from_file()
    _memory_file_mtime = observed_mtime
    return _memory_data
def reload_memory_data() -> dict[str, Any]:
    """Re-read the memory file unconditionally and refresh the cache.

    Also records the file's current modification time (``None`` if the file
    is missing or cannot be stat'ed) for later cache checks.

    Returns:
        The freshly loaded memory data dictionary.
    """
    global _memory_data, _memory_file_mtime

    path = _get_memory_file_path()
    _memory_data = _load_memory_from_file()

    recorded_mtime: float | None = None
    try:
        if path.exists():
            recorded_mtime = path.stat().st_mtime
    except OSError:
        recorded_mtime = None
    _memory_file_mtime = recorded_mtime

    return _memory_data
def _load_memory_from_file() -> dict[str, Any]:
    """Read the memory document from disk.

    Falls back to a fresh empty document in any of these cases:
    - the file does not exist;
    - the file cannot be read;
    - the file is not valid UTF-8 JSON;
    - the file's top-level value is not a JSON object.

    Returns:
        The memory data dictionary.
    """
    file_path = _get_memory_file_path()
    if not file_path.exists():
        return _create_empty_memory()

    try:
        with open(file_path, encoding="utf-8") as f:
            data = json.load(f)
    except (ValueError, OSError) as e:
        # ValueError covers json.JSONDecodeError and also UnicodeDecodeError,
        # raised when the file contains bytes that are not valid UTF-8.
        print(f"Failed to load memory file: {e}")
        return _create_empty_memory()

    # Callers index the result as a dict; a JSON list/string/null would crash them later.
    if not isinstance(data, dict):
        print(f"Failed to load memory file: expected a JSON object, got {type(data).__name__}")
        return _create_empty_memory()
    return data
def _save_memory_to_file(memory_data: dict[str, Any]) -> bool:
    """Persist memory data to disk atomically and refresh the in-process cache.

    Steps:
    1. Stamp ``memory_data["lastUpdated"]`` with the current UTC time. This
       mutates the caller's dict.
    2. Write the JSON to a sibling ``.tmp`` file and flush + fsync it.
    3. Rename the temp file over the real file with ``Path.replace``.

    On success the cache and recorded mtime are updated.

    Args:
        memory_data: The memory data to save.

    Returns:
        True if the file was written. False if writing or JSON serialization
        failed; in that case the cache is left untouched.
    """
    global _memory_data, _memory_file_mtime

    file_path = _get_memory_file_path()
    temp_path = file_path.with_suffix(".tmp")
    try:
        file_path.parent.mkdir(parents=True, exist_ok=True)

        # Same "...Z" format as before, without the deprecated datetime.utcnow().
        memory_data["lastUpdated"] = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"

        with open(temp_path, "w", encoding="utf-8") as f:
            json.dump(memory_data, f, indent=2, ensure_ascii=False)
            # Force the bytes to disk before the rename publishes them. Without
            # this, a crash could leave an empty or truncated memory file in place.
            f.flush()
            os.fsync(f.fileno())

        # Rename temp file to actual file (atomic on most systems).
        temp_path.replace(file_path)
    except (OSError, TypeError, ValueError) as e:
        # TypeError/ValueError come from json.dump on non-serializable data.
        print(f"Failed to save memory file: {e}")
        # Don't leave a partially written temp file behind.
        try:
            temp_path.unlink(missing_ok=True)
        except OSError:
            pass
        return False

    _memory_data = memory_data
    try:
        _memory_file_mtime = file_path.stat().st_mtime
    except OSError:
        _memory_file_mtime = None
    print(f"Memory saved to {file_path}")
    return True
class MemoryUpdater:
"""Updates memory using LLM based on conversation context."""
def __init__(self, model_name: str | None = None):
"""Initialize the memory updater.
Args:
model_name: Optional model name to use. If None, uses config or default.
"""
self._model_name = model_name
def _get_model(self):
"""Get the model for memory updates."""
config = get_memory_config()
model_name = self._model_name or config.model_name
return create_chat_model(name=model_name, thinking_enabled=False)
def update_memory(self, messages: list[Any], thread_id: str | None = None) -> bool:
"""Update memory based on conversation messages.
Args:
messages: List of conversation messages.
thread_id: Optional thread ID for tracking source.
Returns:
True if update was successful, False otherwise.
"""
config = get_memory_config()
if not config.enabled:
return False
if not messages:
return False
try:
# Get current memory
current_memory = get_memory_data()
# Format conversation for prompt
conversation_text = format_conversation_for_update(messages)
if not conversation_text.strip():
return False
# Build prompt
prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2),
conversation=conversation_text,
)
# Call LLM
model = self._get_model()
response = model.invoke(prompt)
response_text = str(response.content).strip()
# Parse response
# Remove markdown code blocks if present
if response_text.startswith("```"):
lines = response_text.split("\n")
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
update_data = json.loads(response_text)
# Apply updates
updated_memory = self._apply_updates(current_memory, update_data, thread_id)
# Save
return _save_memory_to_file(updated_memory)
except json.JSONDecodeError as e:
print(f"Failed to parse LLM response for memory update: {e}")
return False
except Exception as e:
print(f"Memory update failed: {e}")
return False
def _apply_updates(
self,
current_memory: dict[str, Any],
update_data: dict[str, Any],
thread_id: str | None = None,
) -> dict[str, Any]:
"""Apply LLM-generated updates to memory.
Args:
current_memory: Current memory data.
update_data: Updates from LLM.
thread_id: Optional thread ID for tracking.
Returns:
Updated memory data.
"""
config = get_memory_config()
now = datetime.utcnow().isoformat() + "Z"
# Update user sections
user_updates = update_data.get("user", {})
for section in ["workContext", "personalContext", "topOfMind"]:
section_data = user_updates.get(section, {})
if section_data.get("shouldUpdate") and section_data.get("summary"):
current_memory["user"][section] = {
"summary": section_data["summary"],
"updatedAt": now,
}
# Update history sections
history_updates = update_data.get("history", {})
for section in ["recentMonths", "earlierContext", "longTermBackground"]:
section_data = history_updates.get(section, {})
if section_data.get("shouldUpdate") and section_data.get("summary"):
current_memory["history"][section] = {
"summary": section_data["summary"],
"updatedAt": now,
}
# Remove facts
facts_to_remove = set(update_data.get("factsToRemove", []))
if facts_to_remove:
2026-02-05 19:59:25 +08:00
current_memory["facts"] = [f for f in current_memory.get("facts", []) if f.get("id") not in facts_to_remove]
# Add new facts
new_facts = update_data.get("newFacts", [])
for fact in new_facts:
confidence = fact.get("confidence", 0.5)
if confidence >= config.fact_confidence_threshold:
fact_entry = {
"id": f"fact_{uuid.uuid4().hex[:8]}",
"content": fact.get("content", ""),
"category": fact.get("category", "context"),
"confidence": confidence,
"createdAt": now,
"source": thread_id or "unknown",
}
current_memory["facts"].append(fact_entry)
# Enforce max facts limit
if len(current_memory["facts"]) > config.max_facts:
# Sort by confidence and keep top ones
current_memory["facts"] = sorted(
current_memory["facts"],
key=lambda f: f.get("confidence", 0),
reverse=True,
)[: config.max_facts]
return current_memory
2026-02-05 19:59:25 +08:00
def update_memory_from_conversation(messages: list[Any], thread_id: str | None = None) -> bool:
    """Run a one-off memory update for a conversation.

    Thin wrapper: builds a ``MemoryUpdater`` with no explicit model name (so
    the memory config's model is used) and delegates to
    ``MemoryUpdater.update_memory``.

    Args:
        messages: Conversation messages to fold into memory.
        thread_id: Optional thread identifier, recorded as the source of new facts.

    Returns:
        True if memory was updated and saved, False otherwise.
    """
    return MemoryUpdater().update_memory(messages, thread_id)