import logging
import os
from pathlib import Path
from typing import NotRequired, override

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime

from src.agents.thread_state import ThreadDataState

logger = logging.getLogger(__name__)

# Directory (relative to the middleware's ``base_dir``) that holds per-thread data.
# NOTE(review): this ends up under ``backend/`` only when ``base_dir`` (default: the
# working directory at construction time) is the backend folder — confirm deployment cwd.
THREAD_DATA_BASE_DIR = ".deer-flow/threads"


def _validate_thread_id(thread_id: str) -> None:
    """Ensure ``thread_id`` maps to exactly one directory directly under the threads root.

    Args:
        thread_id: The thread ID that will be used as a path component.

    Raises:
        ValueError: If ``thread_id`` is empty, ``.``/``..``, or is not a single path
            component (contains a separator or drive prefix). Such values would let a
            thread's data directory alias another directory or escape
            ``THREAD_DATA_BASE_DIR``.
    """
    # Path(x).name == x only when x is a single, non-empty path component.
    if thread_id in ("", ".", "..") or Path(thread_id).name != thread_id:
        raise ValueError(f"Invalid thread ID for a data directory: {thread_id!r}")


class ThreadDataMiddlewareState(AgentState):
    """Agent state extended with the per-thread data directory paths.

    Compatible with the `ThreadState` schema.
    """

    # Set by ThreadDataMiddleware.before_agent; may be absent before it runs.
    thread_data: NotRequired[ThreadDataState | None]


class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
    """Create thread data directories for each thread execution.

    For a given ``thread_id`` the following directories are created (if they do not
    already exist) under ``base_dir``:

    - ``.deer-flow/threads/{thread_id}/user-data/workspace``
    - ``.deer-flow/threads/{thread_id}/user-data/uploads``
    - ``.deer-flow/threads/{thread_id}/user-data/outputs``

    Their paths, as strings, are written into the agent state under ``thread_data``.
    """

    state_schema = ThreadDataMiddlewareState

    def __init__(self, base_dir: str | None = None):
        """Initialize the middleware.

        Args:
            base_dir: Base directory for thread data. Defaults to the current
                working directory at construction time.
        """
        super().__init__()
        # Resolved once here; later changes to the working directory do not affect it.
        self._base_dir = base_dir or os.getcwd()

    def _get_thread_paths(self, thread_id: str) -> dict[str, str]:
        """Get the paths for a thread's data directories (without creating them).

        Args:
            thread_id: The thread ID; must be a single, non-empty path component.

        Returns:
            Dictionary with ``workspace_path``, ``uploads_path`` and ``outputs_path``.

        Raises:
            ValueError: If ``thread_id`` is not a safe single path component.
        """
        _validate_thread_id(thread_id)
        thread_dir = Path(self._base_dir) / THREAD_DATA_BASE_DIR / thread_id / "user-data"
        return {
            "workspace_path": str(thread_dir / "workspace"),
            "uploads_path": str(thread_dir / "uploads"),
            "outputs_path": str(thread_dir / "outputs"),
        }

    def _create_thread_directories(self, thread_id: str) -> dict[str, str]:
        """Create the thread data directories, including missing parents.

        Existing directories are left untouched, so calling this repeatedly for the
        same thread is harmless.

        Args:
            thread_id: The thread ID.

        Returns:
            Dictionary with the created directory paths (same keys as
            ``_get_thread_paths``).

        Raises:
            ValueError: If ``thread_id`` is not a safe single path component.
            OSError: If a directory cannot be created.
        """
        paths = self._get_thread_paths(thread_id)
        for path in paths.values():
            Path(path).mkdir(parents=True, exist_ok=True)
        return paths

    @override
    def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
        """Ensure the thread's data directories exist and publish their paths to state.

        Args:
            state: Current agent state (not read by this middleware).
            runtime: Runtime whose ``context`` must provide a ``"thread_id"``.

        Returns:
            State update ``{"thread_data": {...paths...}}``.

        Raises:
            ValueError: If the context is missing, has no/empty ``thread_id``, or the
                thread ID is not a safe single path component.
        """
        # The thread ID is supplied by the caller through the runtime context; it is
        # not generated here. A missing context is reported the same way as a missing ID.
        context = runtime.context or {}
        thread_id = context.get("thread_id")
        if not thread_id:
            raise ValueError("Thread ID is required in the context")

        paths = self._create_thread_directories(thread_id)
        logger.info("Created thread data directories for thread %s", thread_id)
        return {"thread_data": paths}