mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-18 12:04:45 +08:00
feat: add thread data middleware (#2)
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
@@ -13,6 +15,10 @@ from .aio_sandbox import AioSandbox
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Thread data directory structure
|
||||
THREAD_DATA_BASE_DIR = ".deer-flow/threads"
|
||||
CONTAINER_USER_DATA_DIR = "/mnt/user-data"
|
||||
|
||||
# Default configuration
|
||||
DEFAULT_IMAGE = "enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest"
|
||||
DEFAULT_PORT = 8080
|
||||
@@ -76,12 +82,31 @@ class AioSandboxProvider(SandboxProvider):
|
||||
time.sleep(1)
|
||||
return False
|
||||
|
||||
def _start_container(self, sandbox_id: str, port: int) -> str:
|
||||
def _get_thread_mounts(self, thread_id: str) -> list[tuple[str, str, bool]]:
|
||||
"""Get the volume mounts for a thread's data directories.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
|
||||
Returns:
|
||||
List of (host_path, container_path, read_only) tuples.
|
||||
"""
|
||||
base_dir = os.getcwd()
|
||||
thread_dir = Path(base_dir) / THREAD_DATA_BASE_DIR / thread_id / "user-data"
|
||||
|
||||
return [
|
||||
(str(thread_dir / "workspace"), f"{CONTAINER_USER_DATA_DIR}/workspace", False),
|
||||
(str(thread_dir / "uploads"), f"{CONTAINER_USER_DATA_DIR}/uploads", False),
|
||||
(str(thread_dir / "outputs"), f"{CONTAINER_USER_DATA_DIR}/outputs", False),
|
||||
]
|
||||
|
||||
def _start_container(self, sandbox_id: str, port: int, extra_mounts: list[tuple[str, str, bool]] | None = None) -> str:
|
||||
"""Start a new Docker container for the sandbox.
|
||||
|
||||
Args:
|
||||
sandbox_id: Unique identifier for the sandbox.
|
||||
port: Port to expose the sandbox API on.
|
||||
extra_mounts: Additional volume mounts as (host_path, container_path, read_only) tuples.
|
||||
|
||||
Returns:
|
||||
The container ID.
|
||||
@@ -102,7 +127,7 @@ class AioSandboxProvider(SandboxProvider):
|
||||
container_name,
|
||||
]
|
||||
|
||||
# Add volume mounts
|
||||
# Add configured volume mounts
|
||||
for mount in self._config["mounts"]:
|
||||
host_path = mount.host_path
|
||||
container_path = mount.container_path
|
||||
@@ -112,6 +137,14 @@ class AioSandboxProvider(SandboxProvider):
|
||||
mount_spec += ":ro"
|
||||
cmd.extend(["-v", mount_spec])
|
||||
|
||||
# Add extra mounts (e.g., thread-specific directories)
|
||||
if extra_mounts:
|
||||
for host_path, container_path, read_only in extra_mounts:
|
||||
mount_spec = f"{host_path}:{container_path}"
|
||||
if read_only:
|
||||
mount_spec += ":ro"
|
||||
cmd.extend(["-v", mount_spec])
|
||||
|
||||
cmd.append(image)
|
||||
|
||||
logger.info(f"Starting sandbox container: {' '.join(cmd)}")
|
||||
@@ -158,17 +191,28 @@ class AioSandboxProvider(SandboxProvider):
|
||||
port += 1
|
||||
raise RuntimeError(f"No available port found in range {start_port}-{start_port + 100}")
|
||||
|
||||
def acquire(self) -> str:
|
||||
def acquire(self, thread_id: str | None = None) -> str:
|
||||
"""Acquire a sandbox environment and return its ID.
|
||||
|
||||
If base_url is configured, uses the existing sandbox.
|
||||
Otherwise, starts a new Docker container.
|
||||
|
||||
Args:
|
||||
thread_id: Optional thread ID for thread-specific configurations.
|
||||
If provided, the sandbox will be configured with thread-specific
|
||||
mounts for workspace, uploads, and outputs directories.
|
||||
|
||||
Returns:
|
||||
The ID of the acquired sandbox environment.
|
||||
"""
|
||||
sandbox_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# Get thread-specific mounts if thread_id is provided
|
||||
extra_mounts = None
|
||||
if thread_id:
|
||||
extra_mounts = self._get_thread_mounts(thread_id)
|
||||
logger.info(f"Adding thread mounts for thread {thread_id}: {extra_mounts}")
|
||||
|
||||
# If base_url is configured, use existing sandbox
|
||||
if self._config.get("base_url"):
|
||||
base_url = self._config["base_url"]
|
||||
@@ -186,7 +230,7 @@ class AioSandboxProvider(SandboxProvider):
|
||||
raise RuntimeError("auto_start is disabled and no base_url is configured")
|
||||
|
||||
port = self._find_available_port(self._config["port"])
|
||||
container_id = self._start_container(sandbox_id, port)
|
||||
container_id = self._start_container(sandbox_id, port, extra_mounts=extra_mounts)
|
||||
self._containers[sandbox_id] = container_id
|
||||
|
||||
base_url = f"http://localhost:{port}"
|
||||
|
||||
Reference in New Issue
Block a user