diff --git a/backend/src/agents/lead_agent/agent.py b/backend/src/agents/lead_agent/agent.py index 3c31073..0773583 100644 --- a/backend/src/agents/lead_agent/agent.py +++ b/backend/src/agents/lead_agent/agent.py @@ -1,13 +1,15 @@ from langchain.agents import create_agent from src.agents.lead_agent.prompt import apply_prompt_template +from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware from src.agents.middlewares.title_middleware import TitleMiddleware from src.agents.thread_state import ThreadState from src.models import create_chat_model from src.sandbox.middleware import SandboxMiddleware from src.tools import get_available_tools -middlewares = [SandboxMiddleware(), TitleMiddleware()] +# ThreadDataMiddleware must run before SandboxMiddleware so the thread's data directories exist on the host before the sandbox mounts them +middlewares = [ThreadDataMiddleware(), SandboxMiddleware(), TitleMiddleware()] lead_agent = create_agent( model=create_chat_model(thinking_enabled=True), diff --git a/backend/src/agents/middlewares/thread_data_middleware.py b/backend/src/agents/middlewares/thread_data_middleware.py new file mode 100644 index 0000000..727e62f --- /dev/null +++ b/backend/src/agents/middlewares/thread_data_middleware.py @@ -0,0 +1,83 @@ +import os +from pathlib import Path +from typing import NotRequired, override + +from langchain.agents import AgentState +from langchain.agents.middleware import AgentMiddleware +from langgraph.runtime import Runtime + +from src.agents.thread_state import ThreadDataState + +# Base directory for thread data (relative to backend/) +THREAD_DATA_BASE_DIR = ".deer-flow/threads" + + +class ThreadDataMiddlewareState(AgentState): + """Compatible with the `ThreadState` schema.""" + + thread_data: NotRequired[ThreadDataState | None] + + +class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]): + """Create thread data directories for each thread execution. 
+ + Creates the following directory structure: + - backend/.deer-flow/threads/{thread_id}/user-data/workspace + - backend/.deer-flow/threads/{thread_id}/user-data/uploads + - backend/.deer-flow/threads/{thread_id}/user-data/outputs + """ + + state_schema = ThreadDataMiddlewareState + + def __init__(self, base_dir: str | None = None): + """Initialize the middleware. + + Args: + base_dir: Base directory for thread data. Defaults to the current working directory. + """ + super().__init__() + self._base_dir = base_dir or os.getcwd() + + def _get_thread_paths(self, thread_id: str) -> dict[str, str]: + """Get the paths for a thread's data directories. + + Args: + thread_id: The thread ID. + + Returns: + Dictionary with workspace_path, uploads_path, and outputs_path. + """ + thread_dir = Path(self._base_dir) / THREAD_DATA_BASE_DIR / thread_id / "user-data" + return { + "workspace_path": str(thread_dir / "workspace"), + "uploads_path": str(thread_dir / "uploads"), + "outputs_path": str(thread_dir / "outputs"), + } + + def _create_thread_directories(self, thread_id: str) -> dict[str, str]: + """Create the thread data directories. + + Args: + thread_id: The thread ID. + + Returns: + Dictionary with the created directory paths. 
+ """ + paths = self._get_thread_paths(thread_id) + for path in paths.values(): + os.makedirs(path, exist_ok=True) + return paths + + @override + def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None: + # Read the thread ID from the runtime context and create its data directories + print(runtime.context) + thread_id = runtime.context["thread_id"] + paths = self._create_thread_directories(thread_id) + print(f"Created thread data directories for thread {thread_id}") + + return { + "thread_data": { + **paths, + } + } diff --git a/backend/src/agents/thread_state.py b/backend/src/agents/thread_state.py index fbafe15..ef4c80d 100644 --- a/backend/src/agents/thread_state.py +++ b/backend/src/agents/thread_state.py @@ -7,6 +7,13 @@ class SandboxState(TypedDict): sandbox_id: NotRequired[str | None] +class ThreadDataState(TypedDict): + workspace_path: NotRequired[str | None] + uploads_path: NotRequired[str | None] + outputs_path: NotRequired[str | None] + + class ThreadState(AgentState): sandbox: NotRequired[SandboxState | None] + thread_data: NotRequired[ThreadDataState | None] title: NotRequired[str | None] diff --git a/backend/src/community/aio_sandbox/aio_sandbox_provider.py b/backend/src/community/aio_sandbox/aio_sandbox_provider.py index c24b8a5..38d6bcc 100644 --- a/backend/src/community/aio_sandbox/aio_sandbox_provider.py +++ b/backend/src/community/aio_sandbox/aio_sandbox_provider.py @@ -1,7 +1,9 @@ import logging +import os import subprocess import time import uuid +from pathlib import Path import requests @@ -13,6 +15,10 @@ from .aio_sandbox import AioSandbox logger = logging.getLogger(__name__) +# Thread data directory structure +THREAD_DATA_BASE_DIR = ".deer-flow/threads" +CONTAINER_USER_DATA_DIR = "/mnt/user-data" + # Default configuration DEFAULT_IMAGE = "enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest" DEFAULT_PORT = 8080 @@ -76,12 +82,31 @@ class AioSandboxProvider(SandboxProvider): time.sleep(1) return False - def 
_start_container(self, sandbox_id: str, port: int) -> str: + def _get_thread_mounts(self, thread_id: str) -> list[tuple[str, str, bool]]: + """Get the volume mounts for a thread's data directories. + + Args: + thread_id: The thread ID. + + Returns: + List of (host_path, container_path, read_only) tuples. + """ + base_dir = os.getcwd() + thread_dir = Path(base_dir) / THREAD_DATA_BASE_DIR / thread_id / "user-data" + + return [ + (str(thread_dir / "workspace"), f"{CONTAINER_USER_DATA_DIR}/workspace", False), + (str(thread_dir / "uploads"), f"{CONTAINER_USER_DATA_DIR}/uploads", False), + (str(thread_dir / "outputs"), f"{CONTAINER_USER_DATA_DIR}/outputs", False), + ] + + def _start_container(self, sandbox_id: str, port: int, extra_mounts: list[tuple[str, str, bool]] | None = None) -> str: """Start a new Docker container for the sandbox. Args: sandbox_id: Unique identifier for the sandbox. port: Port to expose the sandbox API on. + extra_mounts: Additional volume mounts as (host_path, container_path, read_only) tuples. Returns: The container ID. 
@@ -102,7 +127,7 @@ class AioSandboxProvider(SandboxProvider): container_name, ] - # Add volume mounts + # Add configured volume mounts for mount in self._config["mounts"]: host_path = mount.host_path container_path = mount.container_path @@ -112,6 +137,14 @@ class AioSandboxProvider(SandboxProvider): mount_spec += ":ro" cmd.extend(["-v", mount_spec]) + # Add extra mounts (e.g., thread-specific directories) + if extra_mounts: + for host_path, container_path, read_only in extra_mounts: + mount_spec = f"{host_path}:{container_path}" + if read_only: + mount_spec += ":ro" + cmd.extend(["-v", mount_spec]) + cmd.append(image) logger.info(f"Starting sandbox container: {' '.join(cmd)}") @@ -158,17 +191,28 @@ class AioSandboxProvider(SandboxProvider): port += 1 raise RuntimeError(f"No available port found in range {start_port}-{start_port + 100}") - def acquire(self) -> str: + def acquire(self, thread_id: str | None = None) -> str: """Acquire a sandbox environment and return its ID. If base_url is configured, uses the existing sandbox. Otherwise, starts a new Docker container. + Args: + thread_id: Optional thread ID for thread-specific configurations. + If provided, the sandbox will be configured with thread-specific + mounts for workspace, uploads, and outputs directories. + Returns: The ID of the acquired sandbox environment. 
""" sandbox_id = str(uuid.uuid4())[:8] + # Get thread-specific mounts if thread_id is provided + extra_mounts = None + if thread_id: + extra_mounts = self._get_thread_mounts(thread_id) + logger.info(f"Adding thread mounts for thread {thread_id}: {extra_mounts}") + # If base_url is configured, use existing sandbox if self._config.get("base_url"): base_url = self._config["base_url"] @@ -186,7 +230,7 @@ class AioSandboxProvider(SandboxProvider): raise RuntimeError("auto_start is disabled and no base_url is configured") port = self._find_available_port(self._config["port"]) - container_id = self._start_container(sandbox_id, port) + container_id = self._start_container(sandbox_id, port, extra_mounts=extra_mounts) self._containers[sandbox_id] = container_id base_url = f"http://localhost:{port}" diff --git a/backend/src/sandbox/local/local_sandbox_provider.py b/backend/src/sandbox/local/local_sandbox_provider.py index 467e8f0..bb96098 100644 --- a/backend/src/sandbox/local/local_sandbox_provider.py +++ b/backend/src/sandbox/local/local_sandbox_provider.py @@ -1,12 +1,11 @@ from src.sandbox.local.local_sandbox import LocalSandbox -from src.sandbox.sandbox import Sandbox from src.sandbox.sandbox_provider import SandboxProvider _singleton: LocalSandbox | None = None class LocalSandboxProvider(SandboxProvider): - def acquire(self) -> Sandbox: + def acquire(self, thread_id: str | None = None) -> str: global _singleton if _singleton is None: _singleton = LocalSandbox("local") diff --git a/backend/src/sandbox/middleware.py b/backend/src/sandbox/middleware.py index 630d7f2..8c056a7 100644 --- a/backend/src/sandbox/middleware.py +++ b/backend/src/sandbox/middleware.py @@ -4,7 +4,7 @@ from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware from langgraph.runtime import Runtime -from src.agents.thread_state import SandboxState +from src.agents.thread_state import SandboxState, ThreadDataState from src.sandbox import get_sandbox_provider @@ 
-12,6 +12,7 @@ class SandboxMiddlewareState(AgentState): """Compatible with the `ThreadState` schema.""" sandbox: NotRequired[SandboxState | None] + thread_data: NotRequired[ThreadDataState | None] class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]): @@ -19,15 +20,17 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]): state_schema = SandboxMiddlewareState - def _acquire_sandbox(self) -> str: + def _acquire_sandbox(self, thread_id: str) -> str: provider = get_sandbox_provider() - sandbox_id = provider.acquire() + sandbox_id = provider.acquire(thread_id) print(f"Acquiring sandbox {sandbox_id}") return sandbox_id @override def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None: if "sandbox" not in state or state["sandbox"] is None: - sandbox_id = self._acquire_sandbox() + thread_id = runtime.context["thread_id"] + print(f"Thread ID: {thread_id}") + sandbox_id = self._acquire_sandbox(thread_id) return {"sandbox": {"sandbox_id": sandbox_id}} return super().before_agent(state, runtime) diff --git a/backend/src/sandbox/sandbox_provider.py b/backend/src/sandbox/sandbox_provider.py index 5ca8c81..722cc95 100644 --- a/backend/src/sandbox/sandbox_provider.py +++ b/backend/src/sandbox/sandbox_provider.py @@ -9,7 +9,7 @@ class SandboxProvider(ABC): """Abstract base class for sandbox providers""" @abstractmethod - def acquire(self) -> str: + def acquire(self, thread_id: str | None = None) -> str: """Acquire a sandbox environment and return its ID. Returns: @@ -39,7 +39,7 @@ class SandboxProvider(ABC): _default_sandbox_provider: SandboxProvider | None = None -def get_sandbox_provider() -> SandboxProvider: +def get_sandbox_provider(**kwargs) -> SandboxProvider: """Get the sandbox provider. 
Returns: @@ -49,5 +49,5 @@ def get_sandbox_provider() -> SandboxProvider: if _default_sandbox_provider is None: config = get_app_config() cls = resolve_class(config.sandbox.use, SandboxProvider) - _default_sandbox_provider = cls() + _default_sandbox_provider = cls(**kwargs) return _default_sandbox_provider diff --git a/config.example.yaml b/config.example.yaml index 1a38204..20a5d18 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -64,3 +64,32 @@ tools: use: src.sandbox.tools:bash_tool sandbox: use: src.sandbox.local:LocalSandboxProvider + +# To use Docker-based AIO sandbox instead, uncomment the following: +# sandbox: +# use: src.community.aio_sandbox:AioSandboxProvider +# # Optional: Use existing sandbox at this URL (no Docker container will be started) +# # base_url: http://localhost:8080 +# # Optional: Docker image to use (default: enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest) +# # image: enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest +# # Optional: Base port for sandbox containers (default: 8080) +# # port: 8080 +# # Optional: Whether to automatically start Docker container (default: true) +# # auto_start: true +# # Optional: Prefix for container names (default: deer-flow-sandbox) +# # container_prefix: deer-flow-sandbox +# # Optional: Mount directories from host to container +# # mounts: +# # - host_path: /path/on/host +# # container_path: /home/user/shared +# # read_only: false +# # - host_path: /another/path +# # container_path: /data +# # read_only: true + +# Automatic thread title generation +title: + enabled: true + max_words: 6 + max_chars: 60 + model_name: null # Use default model \ No newline at end of file