mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-02 22:02:13 +08:00
feat: add thread data middleware (#2)
This commit is contained in:
@@ -1,13 +1,15 @@
|
||||
from langchain.agents import create_agent
|
||||
|
||||
from src.agents.lead_agent.prompt import apply_prompt_template
|
||||
from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||
from src.agents.middlewares.title_middleware import TitleMiddleware
|
||||
from src.agents.thread_state import ThreadState
|
||||
from src.models import create_chat_model
|
||||
from src.sandbox.middleware import SandboxMiddleware
|
||||
from src.tools import get_available_tools
|
||||
|
||||
middlewares = [SandboxMiddleware(), TitleMiddleware()]
|
||||
# ThreadDataMiddleware must be before SandboxMiddleware to ensure thread_id is available
|
||||
middlewares = [ThreadDataMiddleware(), SandboxMiddleware(), TitleMiddleware()]
|
||||
|
||||
lead_agent = create_agent(
|
||||
model=create_chat_model(thinking_enabled=True),
|
||||
|
||||
83
backend/src/agents/middlewares/thread_data_middleware.py
Normal file
83
backend/src/agents/middlewares/thread_data_middleware.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import NotRequired, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from src.agents.thread_state import ThreadDataState
|
||||
|
||||
# Base directory for thread data (relative to backend/)
|
||||
THREAD_DATA_BASE_DIR = ".deer-flow/threads"
|
||||
|
||||
|
||||
class ThreadDataMiddlewareState(AgentState):
|
||||
"""Compatible with the `ThreadState` schema."""
|
||||
|
||||
thread_data: NotRequired[ThreadDataState | None]
|
||||
|
||||
|
||||
class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
|
||||
"""Create thread data directories for each thread execution.
|
||||
|
||||
Creates the following directory structure:
|
||||
- backend/.deer-flow/threads/{thread_id}/user-data/workspace
|
||||
- backend/.deer-flow/threads/{thread_id}/user-data/uploads
|
||||
- backend/.deer-flow/threads/{thread_id}/user-data/outputs
|
||||
"""
|
||||
|
||||
state_schema = ThreadDataMiddlewareState
|
||||
|
||||
def __init__(self, base_dir: str | None = None):
|
||||
"""Initialize the middleware.
|
||||
|
||||
Args:
|
||||
base_dir: Base directory for thread data. Defaults to the current working directory.
|
||||
"""
|
||||
super().__init__()
|
||||
self._base_dir = base_dir or os.getcwd()
|
||||
|
||||
def _get_thread_paths(self, thread_id: str) -> dict[str, str]:
|
||||
"""Get the paths for a thread's data directories.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
|
||||
Returns:
|
||||
Dictionary with workspace_path, uploads_path, and outputs_path.
|
||||
"""
|
||||
thread_dir = Path(self._base_dir) / THREAD_DATA_BASE_DIR / thread_id / "user-data"
|
||||
return {
|
||||
"workspace_path": str(thread_dir / "workspace"),
|
||||
"uploads_path": str(thread_dir / "uploads"),
|
||||
"outputs_path": str(thread_dir / "outputs"),
|
||||
}
|
||||
|
||||
def _create_thread_directories(self, thread_id: str) -> dict[str, str]:
|
||||
"""Create the thread data directories.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
|
||||
Returns:
|
||||
Dictionary with the created directory paths.
|
||||
"""
|
||||
paths = self._get_thread_paths(thread_id)
|
||||
for path in paths.values():
|
||||
os.makedirs(path, exist_ok=True)
|
||||
return paths
|
||||
|
||||
@override
|
||||
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
# Generate new thread ID and create directories
|
||||
print(runtime.context)
|
||||
thread_id = runtime.context["thread_id"]
|
||||
paths = self._create_thread_directories(thread_id)
|
||||
print(f"Created thread data directories for thread {thread_id}")
|
||||
|
||||
return {
|
||||
"thread_data": {
|
||||
**paths,
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,13 @@ class SandboxState(TypedDict):
|
||||
sandbox_id: NotRequired[str | None]
|
||||
|
||||
|
||||
class ThreadDataState(TypedDict):
|
||||
workspace_path: NotRequired[str | None]
|
||||
uploads_path: NotRequired[str | None]
|
||||
outputs_path: NotRequired[str | None]
|
||||
|
||||
|
||||
class ThreadState(AgentState):
|
||||
sandbox: NotRequired[SandboxState | None]
|
||||
thread_data: NotRequired[ThreadDataState | None]
|
||||
title: NotRequired[str | None]
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
@@ -13,6 +15,10 @@ from .aio_sandbox import AioSandbox
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Thread data directory structure
|
||||
THREAD_DATA_BASE_DIR = ".deer-flow/threads"
|
||||
CONTAINER_USER_DATA_DIR = "/mnt/user-data"
|
||||
|
||||
# Default configuration
|
||||
DEFAULT_IMAGE = "enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest"
|
||||
DEFAULT_PORT = 8080
|
||||
@@ -76,12 +82,31 @@ class AioSandboxProvider(SandboxProvider):
|
||||
time.sleep(1)
|
||||
return False
|
||||
|
||||
def _start_container(self, sandbox_id: str, port: int) -> str:
|
||||
def _get_thread_mounts(self, thread_id: str) -> list[tuple[str, str, bool]]:
|
||||
"""Get the volume mounts for a thread's data directories.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
|
||||
Returns:
|
||||
List of (host_path, container_path, read_only) tuples.
|
||||
"""
|
||||
base_dir = os.getcwd()
|
||||
thread_dir = Path(base_dir) / THREAD_DATA_BASE_DIR / thread_id / "user-data"
|
||||
|
||||
return [
|
||||
(str(thread_dir / "workspace"), f"{CONTAINER_USER_DATA_DIR}/workspace", False),
|
||||
(str(thread_dir / "uploads"), f"{CONTAINER_USER_DATA_DIR}/uploads", False),
|
||||
(str(thread_dir / "outputs"), f"{CONTAINER_USER_DATA_DIR}/outputs", False),
|
||||
]
|
||||
|
||||
def _start_container(self, sandbox_id: str, port: int, extra_mounts: list[tuple[str, str, bool]] | None = None) -> str:
|
||||
"""Start a new Docker container for the sandbox.
|
||||
|
||||
Args:
|
||||
sandbox_id: Unique identifier for the sandbox.
|
||||
port: Port to expose the sandbox API on.
|
||||
extra_mounts: Additional volume mounts as (host_path, container_path, read_only) tuples.
|
||||
|
||||
Returns:
|
||||
The container ID.
|
||||
@@ -102,7 +127,7 @@ class AioSandboxProvider(SandboxProvider):
|
||||
container_name,
|
||||
]
|
||||
|
||||
# Add volume mounts
|
||||
# Add configured volume mounts
|
||||
for mount in self._config["mounts"]:
|
||||
host_path = mount.host_path
|
||||
container_path = mount.container_path
|
||||
@@ -112,6 +137,14 @@ class AioSandboxProvider(SandboxProvider):
|
||||
mount_spec += ":ro"
|
||||
cmd.extend(["-v", mount_spec])
|
||||
|
||||
# Add extra mounts (e.g., thread-specific directories)
|
||||
if extra_mounts:
|
||||
for host_path, container_path, read_only in extra_mounts:
|
||||
mount_spec = f"{host_path}:{container_path}"
|
||||
if read_only:
|
||||
mount_spec += ":ro"
|
||||
cmd.extend(["-v", mount_spec])
|
||||
|
||||
cmd.append(image)
|
||||
|
||||
logger.info(f"Starting sandbox container: {' '.join(cmd)}")
|
||||
@@ -158,17 +191,28 @@ class AioSandboxProvider(SandboxProvider):
|
||||
port += 1
|
||||
raise RuntimeError(f"No available port found in range {start_port}-{start_port + 100}")
|
||||
|
||||
def acquire(self) -> str:
|
||||
def acquire(self, thread_id: str | None = None) -> str:
|
||||
"""Acquire a sandbox environment and return its ID.
|
||||
|
||||
If base_url is configured, uses the existing sandbox.
|
||||
Otherwise, starts a new Docker container.
|
||||
|
||||
Args:
|
||||
thread_id: Optional thread ID for thread-specific configurations.
|
||||
If provided, the sandbox will be configured with thread-specific
|
||||
mounts for workspace, uploads, and outputs directories.
|
||||
|
||||
Returns:
|
||||
The ID of the acquired sandbox environment.
|
||||
"""
|
||||
sandbox_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# Get thread-specific mounts if thread_id is provided
|
||||
extra_mounts = None
|
||||
if thread_id:
|
||||
extra_mounts = self._get_thread_mounts(thread_id)
|
||||
logger.info(f"Adding thread mounts for thread {thread_id}: {extra_mounts}")
|
||||
|
||||
# If base_url is configured, use existing sandbox
|
||||
if self._config.get("base_url"):
|
||||
base_url = self._config["base_url"]
|
||||
@@ -186,7 +230,7 @@ class AioSandboxProvider(SandboxProvider):
|
||||
raise RuntimeError("auto_start is disabled and no base_url is configured")
|
||||
|
||||
port = self._find_available_port(self._config["port"])
|
||||
container_id = self._start_container(sandbox_id, port)
|
||||
container_id = self._start_container(sandbox_id, port, extra_mounts=extra_mounts)
|
||||
self._containers[sandbox_id] = container_id
|
||||
|
||||
base_url = f"http://localhost:{port}"
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from src.sandbox.local.local_sandbox import LocalSandbox
|
||||
from src.sandbox.sandbox import Sandbox
|
||||
from src.sandbox.sandbox_provider import SandboxProvider
|
||||
|
||||
_singleton: LocalSandbox | None = None
|
||||
|
||||
|
||||
class LocalSandboxProvider(SandboxProvider):
|
||||
def acquire(self) -> Sandbox:
|
||||
def acquire(self, thread_id: str | None = None) -> str:
|
||||
global _singleton
|
||||
if _singleton is None:
|
||||
_singleton = LocalSandbox("local")
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from src.agents.thread_state import SandboxState
|
||||
from src.agents.thread_state import SandboxState, ThreadDataState
|
||||
from src.sandbox import get_sandbox_provider
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ class SandboxMiddlewareState(AgentState):
|
||||
"""Compatible with the `ThreadState` schema."""
|
||||
|
||||
sandbox: NotRequired[SandboxState | None]
|
||||
thread_data: NotRequired[ThreadDataState | None]
|
||||
|
||||
|
||||
class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
|
||||
@@ -19,15 +20,17 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
|
||||
|
||||
state_schema = SandboxMiddlewareState
|
||||
|
||||
def _acquire_sandbox(self) -> str:
|
||||
def _acquire_sandbox(self, thread_id: str) -> str:
|
||||
provider = get_sandbox_provider()
|
||||
sandbox_id = provider.acquire()
|
||||
sandbox_id = provider.acquire(thread_id)
|
||||
print(f"Acquiring sandbox {sandbox_id}")
|
||||
return sandbox_id
|
||||
|
||||
@override
|
||||
def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
if "sandbox" not in state or state["sandbox"] is None:
|
||||
sandbox_id = self._acquire_sandbox()
|
||||
thread_id = runtime.context["thread_id"]
|
||||
print(f"Thread ID: {thread_id}")
|
||||
sandbox_id = self._acquire_sandbox(thread_id)
|
||||
return {"sandbox": {"sandbox_id": sandbox_id}}
|
||||
return super().before_agent(state, runtime)
|
||||
|
||||
@@ -9,7 +9,7 @@ class SandboxProvider(ABC):
|
||||
"""Abstract base class for sandbox providers"""
|
||||
|
||||
@abstractmethod
|
||||
def acquire(self) -> str:
|
||||
def acquire(self, thread_id: str | None = None) -> str:
|
||||
"""Acquire a sandbox environment and return its ID.
|
||||
|
||||
Returns:
|
||||
@@ -39,7 +39,7 @@ class SandboxProvider(ABC):
|
||||
_default_sandbox_provider: SandboxProvider | None = None
|
||||
|
||||
|
||||
def get_sandbox_provider() -> SandboxProvider:
|
||||
def get_sandbox_provider(**kwargs) -> SandboxProvider:
|
||||
"""Get the sandbox provider.
|
||||
|
||||
Returns:
|
||||
@@ -49,5 +49,5 @@ def get_sandbox_provider() -> SandboxProvider:
|
||||
if _default_sandbox_provider is None:
|
||||
config = get_app_config()
|
||||
cls = resolve_class(config.sandbox.use, SandboxProvider)
|
||||
_default_sandbox_provider = cls()
|
||||
_default_sandbox_provider = cls(**kwargs)
|
||||
return _default_sandbox_provider
|
||||
|
||||
@@ -64,3 +64,32 @@ tools:
|
||||
use: src.sandbox.tools:bash_tool
|
||||
sandbox:
|
||||
use: src.sandbox.local:LocalSandboxProvider
|
||||
|
||||
# To use Docker-based AIO sandbox instead, uncomment the following:
|
||||
# sandbox:
|
||||
# use: src.community.aio_sandbox:AioSandboxProvider
|
||||
# # Optional: Use existing sandbox at this URL (no Docker container will be started)
|
||||
# # base_url: http://localhost:8080
|
||||
# # Optional: Docker image to use (default: enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest)
|
||||
# # image: enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest
|
||||
# # Optional: Base port for sandbox containers (default: 8080)
|
||||
# # port: 8080
|
||||
# # Optional: Whether to automatically start Docker container (default: true)
|
||||
# # auto_start: true
|
||||
# # Optional: Prefix for container names (default: deer-flow-sandbox)
|
||||
# # container_prefix: deer-flow-sandbox
|
||||
# # Optional: Mount directories from host to container
|
||||
# # mounts:
|
||||
# # - host_path: /path/on/host
|
||||
# # container_path: /home/user/shared
|
||||
# # read_only: false
|
||||
# # - host_path: /another/path
|
||||
# # container_path: /data
|
||||
# # read_only: true
|
||||
|
||||
# Automatic thread title generation
|
||||
title:
|
||||
enabled: true
|
||||
max_words: 6
|
||||
max_chars: 60
|
||||
model_name: null # Use default model
|
||||
Reference in New Issue
Block a user