feat: add thread data middleware (#2)

This commit is contained in:
DanielWalnut
2026-01-15 13:22:30 +08:00
committed by GitHub
parent ab427731dc
commit c92eedc572
8 changed files with 181 additions and 14 deletions

View File

@@ -1,13 +1,15 @@
from langchain.agents import create_agent
from src.agents.lead_agent.prompt import apply_prompt_template
from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
from src.agents.middlewares.title_middleware import TitleMiddleware
from src.agents.thread_state import ThreadState
from src.models import create_chat_model
from src.sandbox.middleware import SandboxMiddleware
from src.tools import get_available_tools
middlewares = [SandboxMiddleware(), TitleMiddleware()]
# ThreadDataMiddleware must be before SandboxMiddleware to ensure thread_id is available
middlewares = [ThreadDataMiddleware(), SandboxMiddleware(), TitleMiddleware()]
lead_agent = create_agent(
model=create_chat_model(thinking_enabled=True),

View File

@@ -0,0 +1,83 @@
import os
from pathlib import Path
from typing import NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from src.agents.thread_state import ThreadDataState
# Base directory for thread data (relative to backend/)
THREAD_DATA_BASE_DIR = ".deer-flow/threads"
class ThreadDataMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
thread_data: NotRequired[ThreadDataState | None]
class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
"""Create thread data directories for each thread execution.
Creates the following directory structure:
- backend/.deer-flow/threads/{thread_id}/user-data/workspace
- backend/.deer-flow/threads/{thread_id}/user-data/uploads
- backend/.deer-flow/threads/{thread_id}/user-data/outputs
"""
state_schema = ThreadDataMiddlewareState
def __init__(self, base_dir: str | None = None):
"""Initialize the middleware.
Args:
base_dir: Base directory for thread data. Defaults to the current working directory.
"""
super().__init__()
self._base_dir = base_dir or os.getcwd()
def _get_thread_paths(self, thread_id: str) -> dict[str, str]:
"""Get the paths for a thread's data directories.
Args:
thread_id: The thread ID.
Returns:
Dictionary with workspace_path, uploads_path, and outputs_path.
"""
thread_dir = Path(self._base_dir) / THREAD_DATA_BASE_DIR / thread_id / "user-data"
return {
"workspace_path": str(thread_dir / "workspace"),
"uploads_path": str(thread_dir / "uploads"),
"outputs_path": str(thread_dir / "outputs"),
}
def _create_thread_directories(self, thread_id: str) -> dict[str, str]:
"""Create the thread data directories.
Args:
thread_id: The thread ID.
Returns:
Dictionary with the created directory paths.
"""
paths = self._get_thread_paths(thread_id)
for path in paths.values():
os.makedirs(path, exist_ok=True)
return paths
@override
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
# Generate new thread ID and create directories
print(runtime.context)
thread_id = runtime.context["thread_id"]
paths = self._create_thread_directories(thread_id)
print(f"Created thread data directories for thread {thread_id}")
return {
"thread_data": {
**paths,
}
}

View File

@@ -7,6 +7,13 @@ class SandboxState(TypedDict):
sandbox_id: NotRequired[str | None]
class ThreadDataState(TypedDict):
workspace_path: NotRequired[str | None]
uploads_path: NotRequired[str | None]
outputs_path: NotRequired[str | None]
class ThreadState(AgentState):
sandbox: NotRequired[SandboxState | None]
thread_data: NotRequired[ThreadDataState | None]
title: NotRequired[str | None]