From 1c542ab7f1524f6248412e6ba004d985a549c10b Mon Sep 17 00:00:00 2001 From: knukn Date: Fri, 27 Mar 2026 07:41:06 +0800 Subject: [PATCH] feat(memory): Introduce configurable memory storage abstraction (#1353) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(内存存储): 添加可配置的内存存储提供者支持 实现内存存储的抽象基类 MemoryStorage 和文件存储实现 FileMemoryStorage 重构内存数据加载和保存逻辑到存储提供者中 添加 storage_class 配置项以支持自定义存储提供者 * refactor(memory): 重构内存存储模块并更新相关测试 将内存存储逻辑从updater模块移动到独立的storage模块 使用存储接口模式替代直接文件操作 更新所有相关测试以使用新的存储接口 * Update backend/packages/harness/deerflow/agents/memory/storage.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update backend/packages/harness/deerflow/agents/memory/storage.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix(内存存储): 添加线程安全锁并增加测试用例 添加线程锁确保内存存储单例初始化的线程安全 增加对无效代理名称的验证测试 补充单例线程安全性和异常处理的测试用例 * Update backend/tests/test_memory_storage.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix(agents): 使用统一模式验证代理名称 修改代理名称验证逻辑以使用仓库中定义的AGENT_NAME_PATTERN模式,确保代码库一致性并防止路径遍历等安全问题。同时更新测试用例以覆盖更多无效名称情况。 --------- Co-authored-by: Willem Jiang Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../deerflow/agents/memory/__init__.py | 9 + .../harness/deerflow/agents/memory/storage.py | 205 ++++++++++++++++++ .../harness/deerflow/agents/memory/updater.py | 167 +------------- .../harness/deerflow/config/memory_config.py | 4 + backend/tests/test_custom_agent.py | 31 +-- backend/tests/test_memory_storage.py | 199 +++++++++++++++++ backend/tests/test_memory_updater.py | 4 +- 7 files changed, 442 insertions(+), 177 deletions(-) create mode 100644 backend/packages/harness/deerflow/agents/memory/storage.py create mode 100644 backend/tests/test_memory_storage.py diff --git a/backend/packages/harness/deerflow/agents/memory/__init__.py b/backend/packages/harness/deerflow/agents/memory/__init__.py index 6199964..218c0be 100644 --- a/backend/packages/harness/deerflow/agents/memory/__init__.py +++ b/backend/packages/harness/deerflow/agents/memory/__init__.py @@ -18,6 +18,11 @@ from deerflow.agents.memory.queue import ( get_memory_queue, reset_memory_queue, ) +from deerflow.agents.memory.storage import ( + FileMemoryStorage, + MemoryStorage, + get_memory_storage, +) from deerflow.agents.memory.updater import ( MemoryUpdater, get_memory_data, @@ -36,6 +41,10 @@ __all__ = [ "MemoryUpdateQueue", "get_memory_queue", "reset_memory_queue", + # Storage + "MemoryStorage", + "FileMemoryStorage", + "get_memory_storage", # Updater "MemoryUpdater", "get_memory_data", diff --git a/backend/packages/harness/deerflow/agents/memory/storage.py b/backend/packages/harness/deerflow/agents/memory/storage.py new file mode 100644 index 0000000..fcd1298 --- /dev/null +++ b/backend/packages/harness/deerflow/agents/memory/storage.py @@ -0,0 +1,205 @@ +"""Memory storage providers.""" + +import abc +import json +import logging +import threading +from datetime import datetime +from pathlib import Path +from typing import Any + +from deerflow.config.agents_config import AGENT_NAME_PATTERN +from deerflow.config.memory_config import get_memory_config +from deerflow.config.paths import get_paths + +logger = logging.getLogger(__name__) + + +def create_empty_memory() -> dict[str, Any]: + """Create an empty memory structure.""" + return { + "version": "1.0", + "lastUpdated": datetime.utcnow().isoformat() + "Z", + "user": { + "workContext": {"summary": "", "updatedAt": ""}, + "personalContext": {"summary": "", "updatedAt": ""}, + "topOfMind": {"summary": "", "updatedAt": ""}, + }, + "history": { + "recentMonths": {"summary": "", "updatedAt": ""}, + "earlierContext": {"summary": "", "updatedAt": ""}, + "longTermBackground": {"summary": "", "updatedAt": ""}, + }, + "facts": [], + } + + +class MemoryStorage(abc.ABC): + """Abstract base class for memory storage providers.""" + + @abc.abstractmethod + def load(self, agent_name: str | None = None) -> dict[str, Any]: + """Load memory data for the given agent.""" + pass + + @abc.abstractmethod + def reload(self, agent_name: str | None = None) -> dict[str, Any]: + """Force reload memory data for the given agent.""" + pass + + @abc.abstractmethod + def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool: + """Save memory data for the given agent.""" + pass + + +class FileMemoryStorage(MemoryStorage): + """File-based memory storage provider.""" + + def __init__(self): + """Initialize the file memory storage.""" + # Per-agent memory cache: keyed by agent_name (None = global) + # Value: (memory_data, file_mtime) + self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {} + + def _validate_agent_name(self, agent_name: str) -> None: + """Validate that the agent name is safe to use in filesystem paths. + + Uses the repository's established AGENT_NAME_PATTERN to ensure consistency + across the codebase and prevent path traversal or other problematic characters. + """ + if not agent_name: + raise ValueError("Agent name must be a non-empty string.") + if not AGENT_NAME_PATTERN.match(agent_name): + raise ValueError( + f"Invalid agent name {agent_name!r}: names must match {AGENT_NAME_PATTERN.pattern}" + ) + + def _get_memory_file_path(self, agent_name: str | None = None) -> Path: + """Get the path to the memory file.""" + if agent_name is not None: + self._validate_agent_name(agent_name) + return get_paths().agent_memory_file(agent_name) + + config = get_memory_config() + if config.storage_path: + p = Path(config.storage_path) + return p if p.is_absolute() else get_paths().base_dir / p + return get_paths().memory_file + + def _load_memory_from_file(self, agent_name: str | None = None) -> dict[str, Any]: + """Load memory data from file.""" + file_path = self._get_memory_file_path(agent_name) + + if not file_path.exists(): + return create_empty_memory() + + try: + with open(file_path, encoding="utf-8") as f: + data = json.load(f) + return data + except (json.JSONDecodeError, OSError) as e: + logger.warning("Failed to load memory file: %s", e) + return create_empty_memory() + + def load(self, agent_name: str | None = None) -> dict[str, Any]: + """Load memory data (cached with file modification time check).""" + file_path = self._get_memory_file_path(agent_name) + + try: + current_mtime = file_path.stat().st_mtime if file_path.exists() else None + except OSError: + current_mtime = None + + cached = self._memory_cache.get(agent_name) + + if cached is None or cached[1] != current_mtime: + memory_data = self._load_memory_from_file(agent_name) + self._memory_cache[agent_name] = (memory_data, current_mtime) + return memory_data + + return cached[0] + + def reload(self, agent_name: str | None = None) -> dict[str, Any]: + """Reload memory data from file, forcing cache invalidation.""" + file_path = self._get_memory_file_path(agent_name) + memory_data = self._load_memory_from_file(agent_name) + + try: + mtime = file_path.stat().st_mtime if file_path.exists() else None + except OSError: + mtime = None + + self._memory_cache[agent_name] = (memory_data, mtime) + return memory_data + + def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool: + """Save memory data to file and update cache.""" + file_path = self._get_memory_file_path(agent_name) + + try: + file_path.parent.mkdir(parents=True, exist_ok=True) + memory_data["lastUpdated"] = datetime.utcnow().isoformat() + "Z" + + temp_path = file_path.with_suffix(".tmp") + with open(temp_path, "w", encoding="utf-8") as f: + json.dump(memory_data, f, indent=2, ensure_ascii=False) + + temp_path.replace(file_path) + + try: + mtime = file_path.stat().st_mtime + except OSError: + mtime = None + + self._memory_cache[agent_name] = (memory_data, mtime) + logger.info("Memory saved to %s", file_path) + return True + except OSError as e: + logger.error("Failed to save memory file: %s", e) + return False + + +_storage_instance: MemoryStorage | None = None +_storage_lock = threading.Lock() + + +def get_memory_storage() -> MemoryStorage: + """Get the configured memory storage instance.""" + global _storage_instance + if _storage_instance is not None: + return _storage_instance + + with _storage_lock: + if _storage_instance is not None: + return _storage_instance + + config = get_memory_config() + storage_class_path = config.storage_class + + try: + module_path, class_name = storage_class_path.rsplit(".", 1) + import importlib + module = importlib.import_module(module_path) + storage_class = getattr(module, class_name) + + # Validate that the configured storage is a MemoryStorage implementation + if not isinstance(storage_class, type): + raise TypeError( + f"Configured memory storage '{storage_class_path}' is not a class: {storage_class!r}" + ) + if not issubclass(storage_class, MemoryStorage): + raise TypeError( + f"Configured memory storage '{storage_class_path}' is not a subclass of MemoryStorage" + ) + + _storage_instance = storage_class() + except Exception as e: + logger.error( + "Failed to load memory storage %s, falling back to FileMemoryStorage: %s", + storage_class_path, + e, + ) + _storage_instance = FileMemoryStorage() + + return _storage_instance diff --git a/backend/packages/harness/deerflow/agents/memory/updater.py b/backend/packages/harness/deerflow/agents/memory/updater.py index e5d37d0..e90163b 100644 --- a/backend/packages/harness/deerflow/agents/memory/updater.py +++ b/backend/packages/harness/deerflow/agents/memory/updater.py @@ -5,115 +5,25 @@ import logging import re import uuid from datetime import datetime -from pathlib import Path from typing import Any from deerflow.agents.memory.prompt import ( MEMORY_UPDATE_PROMPT, format_conversation_for_update, ) +from deerflow.agents.memory.storage import get_memory_storage from deerflow.config.memory_config import get_memory_config -from deerflow.config.paths import get_paths from deerflow.models import create_chat_model logger = logging.getLogger(__name__) - -def _get_memory_file_path(agent_name: str | None = None) -> Path: - """Get the path to the memory file. - - Args: - agent_name: If provided, returns the per-agent memory file path. - If None, returns the global memory file path. - - Returns: - Path to the memory file. - """ - if agent_name is not None: - return get_paths().agent_memory_file(agent_name) - - config = get_memory_config() - if config.storage_path: - p = Path(config.storage_path) - # Absolute path: use as-is; relative path: resolve against base_dir - return p if p.is_absolute() else get_paths().base_dir / p - return get_paths().memory_file - - -def _create_empty_memory() -> dict[str, Any]: - """Create an empty memory structure.""" - return { - "version": "1.0", - "lastUpdated": datetime.utcnow().isoformat() + "Z", - "user": { - "workContext": {"summary": "", "updatedAt": ""}, - "personalContext": {"summary": "", "updatedAt": ""}, - "topOfMind": {"summary": "", "updatedAt": ""}, - }, - "history": { - "recentMonths": {"summary": "", "updatedAt": ""}, - "earlierContext": {"summary": "", "updatedAt": ""}, - "longTermBackground": {"summary": "", "updatedAt": ""}, - }, - "facts": [], - } - - -# Per-agent memory cache: keyed by agent_name (None = global) -# Value: (memory_data, file_mtime) -_memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {} - - def get_memory_data(agent_name: str | None = None) -> dict[str, Any]: - """Get the current memory data (cached with file modification time check). - - The cache is automatically invalidated if the memory file has been modified - since the last load, ensuring fresh data is always returned. - - Args: - agent_name: If provided, loads per-agent memory. If None, loads global memory. - - Returns: - The memory data dictionary. - """ - file_path = _get_memory_file_path(agent_name) - - # Get current file modification time - try: - current_mtime = file_path.stat().st_mtime if file_path.exists() else None - except OSError: - current_mtime = None - - cached = _memory_cache.get(agent_name) - - # Invalidate cache if file has been modified or doesn't exist - if cached is None or cached[1] != current_mtime: - memory_data = _load_memory_from_file(agent_name) - _memory_cache[agent_name] = (memory_data, current_mtime) - return memory_data - - return cached[0] - + """Get the current memory data via storage provider.""" + return get_memory_storage().load(agent_name) def reload_memory_data(agent_name: str | None = None) -> dict[str, Any]: - """Reload memory data from file, forcing cache invalidation. - - Args: - agent_name: If provided, reloads per-agent memory. If None, reloads global memory. - - Returns: - The reloaded memory data dictionary. - """ - file_path = _get_memory_file_path(agent_name) - memory_data = _load_memory_from_file(agent_name) - - try: - mtime = file_path.stat().st_mtime if file_path.exists() else None - except OSError: - mtime = None - - _memory_cache[agent_name] = (memory_data, mtime) - return memory_data + """Reload memory data via storage provider.""" + return get_memory_storage().reload(agent_name) def _extract_text(content: Any) -> str: @@ -153,29 +63,6 @@ def _extract_text(content: Any) -> str: return str(content) -def _load_memory_from_file(agent_name: str | None = None) -> dict[str, Any]: - """Load memory data from file. - - Args: - agent_name: If provided, loads per-agent memory file. If None, loads global. - - Returns: - The memory data dictionary. - """ - file_path = _get_memory_file_path(agent_name) - - if not file_path.exists(): - return _create_empty_memory() - - try: - with open(file_path, encoding="utf-8") as f: - data = json.load(f) - return data - except (json.JSONDecodeError, OSError) as e: - logger.warning("Failed to load memory file: %s", e) - return _create_empty_memory() - - # Matches sentences that describe a file-upload *event* rather than general # file-related work. Deliberately narrow to avoid removing legitimate facts # such as "User works with CSV files" or "prefers PDF export". @@ -222,48 +109,6 @@ def _fact_content_key(content: Any) -> str | None: return stripped -def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None) -> bool: - """Save memory data to file and update cache. - - Args: - memory_data: The memory data to save. - agent_name: If provided, saves to per-agent memory file. If None, saves to global. - - Returns: - True if successful, False otherwise. - """ - file_path = _get_memory_file_path(agent_name) - - try: - # Ensure directory exists - file_path.parent.mkdir(parents=True, exist_ok=True) - - # Update lastUpdated timestamp - memory_data["lastUpdated"] = datetime.utcnow().isoformat() + "Z" - - # Write atomically using temp file - temp_path = file_path.with_suffix(".tmp") - with open(temp_path, "w", encoding="utf-8") as f: - json.dump(memory_data, f, indent=2, ensure_ascii=False) - - # Rename temp file to actual file (atomic on most systems) - temp_path.replace(file_path) - - # Update cache and file modification time - try: - mtime = file_path.stat().st_mtime - except OSError: - mtime = None - - _memory_cache[agent_name] = (memory_data, mtime) - - logger.info("Memory saved to %s", file_path) - return True - except OSError as e: - logger.error("Failed to save memory file: %s", e) - return False - - class MemoryUpdater: """Updates memory using LLM based on conversation context.""" @@ -338,7 +183,7 @@ class MemoryUpdater: updated_memory = _strip_upload_mentions_from_memory(updated_memory) # Save - return _save_memory_to_file(updated_memory, agent_name) + return get_memory_storage().save(updated_memory, agent_name) except json.JSONDecodeError as e: logger.warning("Failed to parse LLM response for memory update: %s", e) diff --git a/backend/packages/harness/deerflow/config/memory_config.py b/backend/packages/harness/deerflow/config/memory_config.py index 824717d..8565aa2 100644 --- a/backend/packages/harness/deerflow/config/memory_config.py +++ b/backend/packages/harness/deerflow/config/memory_config.py @@ -23,6 +23,10 @@ class MemoryConfig(BaseModel): "migrate existing data or use an absolute path to preserve the old location." ), ) + storage_class: str = Field( + default="deerflow.agents.memory.storage.FileMemoryStorage", + description="The class path for memory storage provider", + ) debounce_seconds: int = Field( default=30, ge=1, diff --git a/backend/tests/test_custom_agent.py b/backend/tests/test_custom_agent.py index 6dfbffc..e2b4b63 100644 --- a/backend/tests/test_custom_agent.py +++ b/backend/tests/test_custom_agent.py @@ -304,39 +304,42 @@ class TestListCustomAgents: class TestMemoryFilePath: def test_global_memory_path(self, tmp_path): """None agent_name should return global memory file.""" - import deerflow.agents.memory.updater as updater_mod + from deerflow.agents.memory.storage import FileMemoryStorage from deerflow.config.memory_config import MemoryConfig with ( - patch("deerflow.agents.memory.updater.get_paths", return_value=_make_paths(tmp_path)), - patch("deerflow.agents.memory.updater.get_memory_config", return_value=MemoryConfig(storage_path="")), + patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)), + patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")), ): - path = updater_mod._get_memory_file_path(None) + storage = FileMemoryStorage() + path = storage._get_memory_file_path(None) assert path == tmp_path / "memory.json" def test_agent_memory_path(self, tmp_path): """Providing agent_name should return per-agent memory file.""" - import deerflow.agents.memory.updater as updater_mod + from deerflow.agents.memory.storage import FileMemoryStorage from deerflow.config.memory_config import MemoryConfig with ( - patch("deerflow.agents.memory.updater.get_paths", return_value=_make_paths(tmp_path)), - patch("deerflow.agents.memory.updater.get_memory_config", return_value=MemoryConfig(storage_path="")), + patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)), + patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")), ): - path = updater_mod._get_memory_file_path("code-reviewer") + storage = FileMemoryStorage() + path = storage._get_memory_file_path("code-reviewer") assert path == tmp_path / "agents" / "code-reviewer" / "memory.json" def test_different_paths_for_different_agents(self, tmp_path): - import deerflow.agents.memory.updater as updater_mod + from deerflow.agents.memory.storage import FileMemoryStorage from deerflow.config.memory_config import MemoryConfig with ( - patch("deerflow.agents.memory.updater.get_paths", return_value=_make_paths(tmp_path)), - patch("deerflow.agents.memory.updater.get_memory_config", return_value=MemoryConfig(storage_path="")), + patch("deerflow.agents.memory.storage.get_paths", return_value=_make_paths(tmp_path)), + patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")), ): - path_global = updater_mod._get_memory_file_path(None) - path_a = updater_mod._get_memory_file_path("agent-a") - path_b = updater_mod._get_memory_file_path("agent-b") + storage = FileMemoryStorage() + path_global = storage._get_memory_file_path(None) + path_a = storage._get_memory_file_path("agent-a") + path_b = storage._get_memory_file_path("agent-b") assert path_global != path_a assert path_global != path_b diff --git a/backend/tests/test_memory_storage.py b/backend/tests/test_memory_storage.py new file mode 100644 index 0000000..dd6b9f6 --- /dev/null +++ b/backend/tests/test_memory_storage.py @@ -0,0 +1,199 @@ +"""Tests for memory storage providers.""" + +import threading +from unittest.mock import MagicMock, patch + +import pytest + +from deerflow.agents.memory.storage import ( + FileMemoryStorage, + MemoryStorage, + create_empty_memory, + get_memory_storage, +) +from deerflow.config.memory_config import MemoryConfig + + +class TestCreateEmptyMemory: + """Test create_empty_memory function.""" + + def test_returns_valid_structure(self): + """Should return a valid empty memory structure.""" + memory = create_empty_memory() + assert isinstance(memory, dict) + assert memory["version"] == "1.0" + assert "lastUpdated" in memory + assert isinstance(memory["user"], dict) + assert isinstance(memory["history"], dict) + assert isinstance(memory["facts"], list) + + +class TestMemoryStorageInterface: + """Test MemoryStorage abstract base class.""" + + def test_abstract_methods(self): + """Should raise TypeError when trying to instantiate abstract class.""" + class TestStorage(MemoryStorage): + pass + + with pytest.raises(TypeError): + TestStorage() + + +class TestFileMemoryStorage: + """Test FileMemoryStorage implementation.""" + + def test_get_memory_file_path_global(self, tmp_path): + """Should return global memory file path when agent_name is None.""" + def mock_get_paths(): + mock_paths = MagicMock() + mock_paths.memory_file = tmp_path / "memory.json" + return mock_paths + + with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths): + with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")): + storage = FileMemoryStorage() + path = storage._get_memory_file_path(None) + assert path == tmp_path / "memory.json" + + def test_get_memory_file_path_agent(self, tmp_path): + """Should return per-agent memory file path when agent_name is provided.""" + def mock_get_paths(): + mock_paths = MagicMock() + mock_paths.agent_memory_file.return_value = tmp_path / "agents" / "test-agent" / "memory.json" + return mock_paths + + with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths): + storage = FileMemoryStorage() + path = storage._get_memory_file_path("test-agent") + assert path == tmp_path / "agents" / "test-agent" / "memory.json" + + @pytest.mark.parametrize( + "invalid_name", ["", "../etc/passwd", "agent/name", "agent\\name", "agent name", "agent@123", "agent_name"] + ) + def test_validate_agent_name_invalid(self, invalid_name): + """Should raise ValueError for invalid agent names that don't match the pattern.""" + storage = FileMemoryStorage() + with pytest.raises(ValueError, match="Invalid agent name|Agent name must be a non-empty string"): + storage._validate_agent_name(invalid_name) + + def test_load_creates_empty_memory(self, tmp_path): + """Should create empty memory when file doesn't exist.""" + def mock_get_paths(): + mock_paths = MagicMock() + mock_paths.memory_file = tmp_path / "non_existent_memory.json" + return mock_paths + + with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths): + with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")): + storage = FileMemoryStorage() + memory = storage.load() + assert isinstance(memory, dict) + assert memory["version"] == "1.0" + + def test_save_writes_to_file(self, tmp_path): + """Should save memory data to file.""" + memory_file = tmp_path / "memory.json" + + def mock_get_paths(): + mock_paths = MagicMock() + mock_paths.memory_file = memory_file + return mock_paths + + with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths): + with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")): + storage = FileMemoryStorage() + test_memory = {"version": "1.0", "facts": [{"content": "test fact"}]} + result = storage.save(test_memory) + assert result is True + assert memory_file.exists() + + def test_reload_forces_cache_invalidation(self, tmp_path): + """Should force reload from file and invalidate cache.""" + memory_file = tmp_path / "memory.json" + memory_file.parent.mkdir(parents=True, exist_ok=True) + memory_file.write_text('{"version": "1.0", "facts": [{"content": "initial fact"}]}') + + def mock_get_paths(): + mock_paths = MagicMock() + mock_paths.memory_file = memory_file + return mock_paths + + with patch("deerflow.agents.memory.storage.get_paths", side_effect=mock_get_paths): + with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_path="")): + storage = FileMemoryStorage() + # First load + memory1 = storage.load() + assert memory1["facts"][0]["content"] == "initial fact" + + # Update file directly + memory_file.write_text('{"version": "1.0", "facts": [{"content": "updated fact"}]}') + + # Reload should get updated data + memory2 = storage.reload() + assert memory2["facts"][0]["content"] == "updated fact" + + +class TestGetMemoryStorage: + """Test get_memory_storage function.""" + + @pytest.fixture(autouse=True) + def reset_storage_instance(self): + """Reset the global storage instance before and after each test.""" + import deerflow.agents.memory.storage as storage_mod + storage_mod._storage_instance = None + yield + storage_mod._storage_instance = None + + def test_returns_file_memory_storage_by_default(self): + """Should return FileMemoryStorage by default.""" + with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")): + storage = get_memory_storage() + assert isinstance(storage, FileMemoryStorage) + + def test_falls_back_to_file_memory_storage_on_error(self): + """Should fall back to FileMemoryStorage if configured storage fails to load.""" + with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="non.existent.StorageClass")): + storage = get_memory_storage() + assert isinstance(storage, FileMemoryStorage) + + def test_returns_singleton_instance(self): + """Should return the same instance on subsequent calls.""" + with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")): + storage1 = get_memory_storage() + storage2 = get_memory_storage() + assert storage1 is storage2 + + def test_get_memory_storage_thread_safety(self): + """Should safely initialize the singleton even with concurrent calls.""" + results = [] + def get_storage(): + # get_memory_storage is called concurrently from multiple threads while + # get_memory_config is patched once around thread creation. This verifies + # that the singleton initialization remains thread-safe. + results.append(get_memory_storage()) + + with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="deerflow.agents.memory.storage.FileMemoryStorage")): + threads = [threading.Thread(target=get_storage) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + # All results should be the exact same instance + assert len(results) == 10 + assert all(r is results[0] for r in results) + + def test_get_memory_storage_invalid_class_fallback(self): + """Should fall back to FileMemoryStorage if the configured class is not actually a class.""" + # Using a built-in function instead of a class + with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="os.path.join")): + storage = get_memory_storage() + assert isinstance(storage, FileMemoryStorage) + + def test_get_memory_storage_non_subclass_fallback(self): + """Should fall back to FileMemoryStorage if the configured class is not a subclass of MemoryStorage.""" + # Using 'dict' as a class that is not a MemoryStorage subclass + with patch("deerflow.agents.memory.storage.get_memory_config", return_value=MemoryConfig(storage_class="builtins.dict")): + storage = get_memory_storage() + assert isinstance(storage, FileMemoryStorage) diff --git a/backend/tests/test_memory_updater.py b/backend/tests/test_memory_updater.py index 7ccba65..341b676 100644 --- a/backend/tests/test_memory_updater.py +++ b/backend/tests/test_memory_updater.py @@ -251,7 +251,7 @@ class TestUpdateMemoryStructuredResponse: patch.object(updater, "_get_model", return_value=self._make_mock_model(valid_json)), patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)), patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()), - patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True), + patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))), ): msg = MagicMock() msg.type = "human" @@ -274,7 +274,7 @@ class TestUpdateMemoryStructuredResponse: patch.object(updater, "_get_model", return_value=self._make_mock_model(list_content)), patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)), patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()), - patch("deerflow.agents.memory.updater._save_memory_to_file", return_value=True), + patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))), ): msg = MagicMock() msg.type = "human"