mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-20 04:44:46 +08:00
85 lines
2.7 KiB
Python
85 lines
2.7 KiB
Python
import os
|
|
from pathlib import Path
|
|
from typing import NotRequired, override
|
|
|
|
from langchain.agents import AgentState
|
|
from langchain.agents.middleware import AgentMiddleware
|
|
from langgraph.runtime import Runtime
|
|
|
|
from src.agents.thread_state import ThreadDataState
|
|
|
|
# Base directory for thread data (relative to backend/)
|
|
THREAD_DATA_BASE_DIR = ".deer-flow/threads"
|
|
|
|
|
|
class ThreadDataMiddlewareState(AgentState):
|
|
"""Compatible with the `ThreadState` schema."""
|
|
|
|
thread_data: NotRequired[ThreadDataState | None]
|
|
|
|
|
|
class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
|
|
"""Create thread data directories for each thread execution.
|
|
|
|
Creates the following directory structure:
|
|
- backend/.deer-flow/threads/{thread_id}/user-data/workspace
|
|
- backend/.deer-flow/threads/{thread_id}/user-data/uploads
|
|
- backend/.deer-flow/threads/{thread_id}/user-data/outputs
|
|
"""
|
|
|
|
state_schema = ThreadDataMiddlewareState
|
|
|
|
def __init__(self, base_dir: str | None = None):
|
|
"""Initialize the middleware.
|
|
|
|
Args:
|
|
base_dir: Base directory for thread data. Defaults to the current working directory.
|
|
"""
|
|
super().__init__()
|
|
self._base_dir = base_dir or os.getcwd()
|
|
|
|
def _get_thread_paths(self, thread_id: str) -> dict[str, str]:
|
|
"""Get the paths for a thread's data directories.
|
|
|
|
Args:
|
|
thread_id: The thread ID.
|
|
|
|
Returns:
|
|
Dictionary with workspace_path, uploads_path, and outputs_path.
|
|
"""
|
|
thread_dir = Path(self._base_dir) / THREAD_DATA_BASE_DIR / thread_id / "user-data"
|
|
return {
|
|
"workspace_path": str(thread_dir / "workspace"),
|
|
"uploads_path": str(thread_dir / "uploads"),
|
|
"outputs_path": str(thread_dir / "outputs"),
|
|
}
|
|
|
|
def _create_thread_directories(self, thread_id: str) -> dict[str, str]:
|
|
"""Create the thread data directories.
|
|
|
|
Args:
|
|
thread_id: The thread ID.
|
|
|
|
Returns:
|
|
Dictionary with the created directory paths.
|
|
"""
|
|
paths = self._get_thread_paths(thread_id)
|
|
for path in paths.values():
|
|
os.makedirs(path, exist_ok=True)
|
|
return paths
|
|
|
|
@override
|
|
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
|
|
# Generate new thread ID and create directories
|
|
thread_id = runtime.context.get("thread_id")
|
|
if thread_id is None:
|
|
raise ValueError("Thread ID is required in the context")
|
|
paths = self._create_thread_directories(thread_id)
|
|
print(f"Created thread data directories for thread {thread_id}")
|
|
|
|
return {
|
|
"thread_data": {
|
|
**paths,
|
|
}
|
|
}
|