feat: add thread data middleware (#2)

This commit is contained in:
DanielWalnut
2026-01-15 13:22:30 +08:00
committed by GitHub
parent ab427731dc
commit c92eedc572
8 changed files with 181 additions and 14 deletions

View File

@@ -1,13 +1,15 @@
from langchain.agents import create_agent
from src.agents.lead_agent.prompt import apply_prompt_template
from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
from src.agents.middlewares.title_middleware import TitleMiddleware
from src.agents.thread_state import ThreadState
from src.models import create_chat_model
from src.sandbox.middleware import SandboxMiddleware
from src.tools import get_available_tools
middlewares = [SandboxMiddleware(), TitleMiddleware()]
# ThreadDataMiddleware must be before SandboxMiddleware to ensure thread_id is available
middlewares = [ThreadDataMiddleware(), SandboxMiddleware(), TitleMiddleware()]
lead_agent = create_agent(
model=create_chat_model(thinking_enabled=True),

View File

@@ -0,0 +1,83 @@
import os
from pathlib import Path
from typing import NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from src.agents.thread_state import ThreadDataState
# Base directory for thread data (relative to backend/)
THREAD_DATA_BASE_DIR = ".deer-flow/threads"
class ThreadDataMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
thread_data: NotRequired[ThreadDataState | None]
class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
"""Create thread data directories for each thread execution.
Creates the following directory structure:
- backend/.deer-flow/threads/{thread_id}/user-data/workspace
- backend/.deer-flow/threads/{thread_id}/user-data/uploads
- backend/.deer-flow/threads/{thread_id}/user-data/outputs
"""
state_schema = ThreadDataMiddlewareState
def __init__(self, base_dir: str | None = None):
"""Initialize the middleware.
Args:
base_dir: Base directory for thread data. Defaults to the current working directory.
"""
super().__init__()
self._base_dir = base_dir or os.getcwd()
def _get_thread_paths(self, thread_id: str) -> dict[str, str]:
"""Get the paths for a thread's data directories.
Args:
thread_id: The thread ID.
Returns:
Dictionary with workspace_path, uploads_path, and outputs_path.
"""
thread_dir = Path(self._base_dir) / THREAD_DATA_BASE_DIR / thread_id / "user-data"
return {
"workspace_path": str(thread_dir / "workspace"),
"uploads_path": str(thread_dir / "uploads"),
"outputs_path": str(thread_dir / "outputs"),
}
def _create_thread_directories(self, thread_id: str) -> dict[str, str]:
"""Create the thread data directories.
Args:
thread_id: The thread ID.
Returns:
Dictionary with the created directory paths.
"""
paths = self._get_thread_paths(thread_id)
for path in paths.values():
os.makedirs(path, exist_ok=True)
return paths
@override
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
# Generate new thread ID and create directories
print(runtime.context)
thread_id = runtime.context["thread_id"]
paths = self._create_thread_directories(thread_id)
print(f"Created thread data directories for thread {thread_id}")
return {
"thread_data": {
**paths,
}
}

View File

@@ -7,6 +7,13 @@ class SandboxState(TypedDict):
sandbox_id: NotRequired[str | None]
class ThreadDataState(TypedDict):
workspace_path: NotRequired[str | None]
uploads_path: NotRequired[str | None]
outputs_path: NotRequired[str | None]
class ThreadState(AgentState):
sandbox: NotRequired[SandboxState | None]
thread_data: NotRequired[ThreadDataState | None]
title: NotRequired[str | None]

View File

@@ -1,7 +1,9 @@
import logging
import os
import subprocess
import time
import uuid
from pathlib import Path
import requests
@@ -13,6 +15,10 @@ from .aio_sandbox import AioSandbox
logger = logging.getLogger(__name__)
# Thread data directory structure
THREAD_DATA_BASE_DIR = ".deer-flow/threads"
CONTAINER_USER_DATA_DIR = "/mnt/user-data"
# Default configuration
DEFAULT_IMAGE = "enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest"
DEFAULT_PORT = 8080
@@ -76,12 +82,31 @@ class AioSandboxProvider(SandboxProvider):
time.sleep(1)
return False
def _start_container(self, sandbox_id: str, port: int) -> str:
def _get_thread_mounts(self, thread_id: str) -> list[tuple[str, str, bool]]:
"""Get the volume mounts for a thread's data directories.
Args:
thread_id: The thread ID.
Returns:
List of (host_path, container_path, read_only) tuples.
"""
base_dir = os.getcwd()
thread_dir = Path(base_dir) / THREAD_DATA_BASE_DIR / thread_id / "user-data"
return [
(str(thread_dir / "workspace"), f"{CONTAINER_USER_DATA_DIR}/workspace", False),
(str(thread_dir / "uploads"), f"{CONTAINER_USER_DATA_DIR}/uploads", False),
(str(thread_dir / "outputs"), f"{CONTAINER_USER_DATA_DIR}/outputs", False),
]
def _start_container(self, sandbox_id: str, port: int, extra_mounts: list[tuple[str, str, bool]] | None = None) -> str:
"""Start a new Docker container for the sandbox.
Args:
sandbox_id: Unique identifier for the sandbox.
port: Port to expose the sandbox API on.
extra_mounts: Additional volume mounts as (host_path, container_path, read_only) tuples.
Returns:
The container ID.
@@ -102,7 +127,7 @@ class AioSandboxProvider(SandboxProvider):
container_name,
]
# Add volume mounts
# Add configured volume mounts
for mount in self._config["mounts"]:
host_path = mount.host_path
container_path = mount.container_path
@@ -112,6 +137,14 @@ class AioSandboxProvider(SandboxProvider):
mount_spec += ":ro"
cmd.extend(["-v", mount_spec])
# Add extra mounts (e.g., thread-specific directories)
if extra_mounts:
for host_path, container_path, read_only in extra_mounts:
mount_spec = f"{host_path}:{container_path}"
if read_only:
mount_spec += ":ro"
cmd.extend(["-v", mount_spec])
cmd.append(image)
logger.info(f"Starting sandbox container: {' '.join(cmd)}")
@@ -158,17 +191,28 @@ class AioSandboxProvider(SandboxProvider):
port += 1
raise RuntimeError(f"No available port found in range {start_port}-{start_port + 100}")
def acquire(self) -> str:
def acquire(self, thread_id: str | None = None) -> str:
"""Acquire a sandbox environment and return its ID.
If base_url is configured, uses the existing sandbox.
Otherwise, starts a new Docker container.
Args:
thread_id: Optional thread ID for thread-specific configurations.
If provided, the sandbox will be configured with thread-specific
mounts for workspace, uploads, and outputs directories.
Returns:
The ID of the acquired sandbox environment.
"""
sandbox_id = str(uuid.uuid4())[:8]
# Get thread-specific mounts if thread_id is provided
extra_mounts = None
if thread_id:
extra_mounts = self._get_thread_mounts(thread_id)
logger.info(f"Adding thread mounts for thread {thread_id}: {extra_mounts}")
# If base_url is configured, use existing sandbox
if self._config.get("base_url"):
base_url = self._config["base_url"]
@@ -186,7 +230,7 @@ class AioSandboxProvider(SandboxProvider):
raise RuntimeError("auto_start is disabled and no base_url is configured")
port = self._find_available_port(self._config["port"])
container_id = self._start_container(sandbox_id, port)
container_id = self._start_container(sandbox_id, port, extra_mounts=extra_mounts)
self._containers[sandbox_id] = container_id
base_url = f"http://localhost:{port}"

View File

@@ -1,12 +1,11 @@
from src.sandbox.local.local_sandbox import LocalSandbox
from src.sandbox.sandbox import Sandbox
from src.sandbox.sandbox_provider import SandboxProvider
_singleton: LocalSandbox | None = None
class LocalSandboxProvider(SandboxProvider):
def acquire(self) -> Sandbox:
def acquire(self, thread_id: str | None = None) -> str:
global _singleton
if _singleton is None:
_singleton = LocalSandbox("local")

View File

@@ -4,7 +4,7 @@ from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from src.agents.thread_state import SandboxState
from src.agents.thread_state import SandboxState, ThreadDataState
from src.sandbox import get_sandbox_provider
@@ -12,6 +12,7 @@ class SandboxMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
sandbox: NotRequired[SandboxState | None]
thread_data: NotRequired[ThreadDataState | None]
class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
@@ -19,15 +20,17 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
state_schema = SandboxMiddlewareState
def _acquire_sandbox(self) -> str:
def _acquire_sandbox(self, thread_id: str) -> str:
provider = get_sandbox_provider()
sandbox_id = provider.acquire()
sandbox_id = provider.acquire(thread_id)
print(f"Acquiring sandbox {sandbox_id}")
return sandbox_id
@override
def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
if "sandbox" not in state or state["sandbox"] is None:
sandbox_id = self._acquire_sandbox()
thread_id = runtime.context["thread_id"]
print(f"Thread ID: {thread_id}")
sandbox_id = self._acquire_sandbox(thread_id)
return {"sandbox": {"sandbox_id": sandbox_id}}
return super().before_agent(state, runtime)

View File

@@ -9,7 +9,7 @@ class SandboxProvider(ABC):
"""Abstract base class for sandbox providers"""
@abstractmethod
def acquire(self) -> str:
def acquire(self, thread_id: str | None = None) -> str:
"""Acquire a sandbox environment and return its ID.
Returns:
@@ -39,7 +39,7 @@ class SandboxProvider(ABC):
_default_sandbox_provider: SandboxProvider | None = None
def get_sandbox_provider() -> SandboxProvider:
def get_sandbox_provider(**kwargs) -> SandboxProvider:
"""Get the sandbox provider.
Returns:
@@ -49,5 +49,5 @@ def get_sandbox_provider() -> SandboxProvider:
if _default_sandbox_provider is None:
config = get_app_config()
cls = resolve_class(config.sandbox.use, SandboxProvider)
_default_sandbox_provider = cls()
_default_sandbox_provider = cls(**kwargs)
return _default_sandbox_provider

View File

@@ -64,3 +64,32 @@ tools:
use: src.sandbox.tools:bash_tool
sandbox:
use: src.sandbox.local:LocalSandboxProvider
# To use Docker-based AIO sandbox instead, uncomment the following:
# sandbox:
# use: src.community.aio_sandbox:AioSandboxProvider
# # Optional: Use existing sandbox at this URL (no Docker container will be started)
# # base_url: http://localhost:8080
# # Optional: Docker image to use (default: enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest)
# # image: enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest
# # Optional: Base port for sandbox containers (default: 8080)
# # port: 8080
# # Optional: Whether to automatically start Docker container (default: true)
# # auto_start: true
# # Optional: Prefix for container names (default: deer-flow-sandbox)
# # container_prefix: deer-flow-sandbox
# # Optional: Mount directories from host to container
# # mounts:
# # - host_path: /path/on/host
# # container_path: /home/user/shared
# # read_only: false
# # - host_path: /another/path
# # container_path: /data
# # read_only: true
# Automatic thread title generation
title:
enabled: true
max_words: 6
max_chars: 60
model_name: null # Use default model