mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-19 12:24:46 +08:00
feat: add thread data middleware (#2)
This commit is contained in:
@@ -1,13 +1,15 @@
|
|||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
|
|
||||||
from src.agents.lead_agent.prompt import apply_prompt_template
|
from src.agents.lead_agent.prompt import apply_prompt_template
|
||||||
|
from src.agents.middlewares.thread_data_middleware import ThreadDataMiddleware
|
||||||
from src.agents.middlewares.title_middleware import TitleMiddleware
|
from src.agents.middlewares.title_middleware import TitleMiddleware
|
||||||
from src.agents.thread_state import ThreadState
|
from src.agents.thread_state import ThreadState
|
||||||
from src.models import create_chat_model
|
from src.models import create_chat_model
|
||||||
from src.sandbox.middleware import SandboxMiddleware
|
from src.sandbox.middleware import SandboxMiddleware
|
||||||
from src.tools import get_available_tools
|
from src.tools import get_available_tools
|
||||||
|
|
||||||
middlewares = [SandboxMiddleware(), TitleMiddleware()]
|
# ThreadDataMiddleware must run before SandboxMiddleware so the thread's data directories exist before the sandbox mounts them
|
||||||
|
middlewares = [ThreadDataMiddleware(), SandboxMiddleware(), TitleMiddleware()]
|
||||||
|
|
||||||
lead_agent = create_agent(
|
lead_agent = create_agent(
|
||||||
model=create_chat_model(thinking_enabled=True),
|
model=create_chat_model(thinking_enabled=True),
|
||||||
|
|||||||
83
backend/src/agents/middlewares/thread_data_middleware.py
Normal file
83
backend/src/agents/middlewares/thread_data_middleware.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import NotRequired, override
|
||||||
|
|
||||||
|
from langchain.agents import AgentState
|
||||||
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
|
from src.agents.thread_state import ThreadDataState
|
||||||
|
|
||||||
|
# Base directory for thread data (relative to backend/)
|
||||||
|
THREAD_DATA_BASE_DIR = ".deer-flow/threads"
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadDataMiddlewareState(AgentState):
|
||||||
|
"""Compatible with the `ThreadState` schema."""
|
||||||
|
|
||||||
|
thread_data: NotRequired[ThreadDataState | None]
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
    """Create the per-thread data directories before the agent runs.

    For each thread the following directories are created (relative to the
    configured base directory, which defaults to the current working
    directory):

    - .deer-flow/threads/{thread_id}/user-data/workspace
    - .deer-flow/threads/{thread_id}/user-data/uploads
    - .deer-flow/threads/{thread_id}/user-data/outputs

    The resulting paths are published under the ``thread_data`` state key.
    """

    state_schema = ThreadDataMiddlewareState

    def __init__(self, base_dir: str | None = None):
        """Initialize the middleware.

        Args:
            base_dir: Base directory for thread data. Defaults to the current
                working directory at construction time.
        """
        super().__init__()
        self._base_dir = base_dir or os.getcwd()

    def _get_thread_paths(self, thread_id: str) -> dict[str, str]:
        """Get the paths for a thread's data directories.

        Args:
            thread_id: The thread ID. Must be a single, non-empty path
                component.

        Returns:
            Dictionary with workspace_path, uploads_path, and outputs_path.

        Raises:
            ValueError: If ``thread_id`` is not a string, is empty, or is not
                a single path component (e.g. ``".."`` or ``"a/b"``).
        """
        # The ID becomes a directory name. An empty ID would collapse every
        # thread into one shared directory, and ".." or an ID containing
        # separators would place the data outside the thread root.
        if (
            not isinstance(thread_id, str)
            or not thread_id
            or thread_id in {".", ".."}
            or Path(thread_id).name != thread_id
        ):
            raise ValueError(f"Invalid thread_id for thread data directory: {thread_id!r}")

        thread_dir = Path(self._base_dir) / THREAD_DATA_BASE_DIR / thread_id / "user-data"
        return {
            "workspace_path": str(thread_dir / "workspace"),
            "uploads_path": str(thread_dir / "uploads"),
            "outputs_path": str(thread_dir / "outputs"),
        }

    def _create_thread_directories(self, thread_id: str) -> dict[str, str]:
        """Create the thread data directories.

        Args:
            thread_id: The thread ID.

        Returns:
            Dictionary with the created directory paths.
        """
        paths = self._get_thread_paths(thread_id)
        for path in paths.values():
            # exist_ok: the hook runs on every agent invocation of a thread.
            os.makedirs(path, exist_ok=True)
        return paths

    @override
    def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
        """Ensure the thread's directories exist and expose their paths in state.

        Args:
            state: Current agent state (not read by this hook).
            runtime: Runtime whose ``context`` must contain ``"thread_id"``.

        Returns:
            A state update of the form ``{"thread_data": {...paths...}}``.
        """
        # Local import: this module does not import logging at file level.
        import logging

        # The thread ID is supplied by the caller through the runtime context;
        # it is not generated here.
        thread_id = runtime.context["thread_id"]
        paths = self._create_thread_directories(thread_id)
        logging.getLogger(__name__).debug("Created thread data directories for thread %s", thread_id)

        return {"thread_data": paths}
|
||||||
@@ -7,6 +7,13 @@ class SandboxState(TypedDict):
|
|||||||
sandbox_id: NotRequired[str | None]
|
sandbox_id: NotRequired[str | None]
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadDataState(TypedDict):
|
||||||
|
workspace_path: NotRequired[str | None]
|
||||||
|
uploads_path: NotRequired[str | None]
|
||||||
|
outputs_path: NotRequired[str | None]
|
||||||
|
|
||||||
|
|
||||||
class ThreadState(AgentState):
|
class ThreadState(AgentState):
|
||||||
sandbox: NotRequired[SandboxState | None]
|
sandbox: NotRequired[SandboxState | None]
|
||||||
|
thread_data: NotRequired[ThreadDataState | None]
|
||||||
title: NotRequired[str | None]
|
title: NotRequired[str | None]
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@@ -13,6 +15,10 @@ from .aio_sandbox import AioSandbox
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Thread data directory structure
|
||||||
|
THREAD_DATA_BASE_DIR = ".deer-flow/threads"
|
||||||
|
CONTAINER_USER_DATA_DIR = "/mnt/user-data"
|
||||||
|
|
||||||
# Default configuration
|
# Default configuration
|
||||||
DEFAULT_IMAGE = "enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest"
|
DEFAULT_IMAGE = "enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest"
|
||||||
DEFAULT_PORT = 8080
|
DEFAULT_PORT = 8080
|
||||||
@@ -76,12 +82,31 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _start_container(self, sandbox_id: str, port: int) -> str:
|
def _get_thread_mounts(self, thread_id: str) -> list[tuple[str, str, bool]]:
    """Build the volume mounts for one thread's data directories.

    Args:
        thread_id: The thread ID.

    Returns:
        List of (host_path, container_path, read_only) tuples, one each for
        the workspace, uploads, and outputs directories, all read-write.
    """
    # Host side lives under the current working directory; container side
    # mirrors the same three sub-directories under CONTAINER_USER_DATA_DIR.
    host_root = Path(os.getcwd()) / THREAD_DATA_BASE_DIR / thread_id / "user-data"
    return [
        (str(host_root / subdir), f"{CONTAINER_USER_DATA_DIR}/{subdir}", False)
        for subdir in ("workspace", "uploads", "outputs")
    ]
|
||||||
|
|
||||||
|
def _start_container(self, sandbox_id: str, port: int, extra_mounts: list[tuple[str, str, bool]] | None = None) -> str:
|
||||||
"""Start a new Docker container for the sandbox.
|
"""Start a new Docker container for the sandbox.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sandbox_id: Unique identifier for the sandbox.
|
sandbox_id: Unique identifier for the sandbox.
|
||||||
port: Port to expose the sandbox API on.
|
port: Port to expose the sandbox API on.
|
||||||
|
extra_mounts: Additional volume mounts as (host_path, container_path, read_only) tuples.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The container ID.
|
The container ID.
|
||||||
@@ -102,7 +127,7 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
container_name,
|
container_name,
|
||||||
]
|
]
|
||||||
|
|
||||||
# Add volume mounts
|
# Add configured volume mounts
|
||||||
for mount in self._config["mounts"]:
|
for mount in self._config["mounts"]:
|
||||||
host_path = mount.host_path
|
host_path = mount.host_path
|
||||||
container_path = mount.container_path
|
container_path = mount.container_path
|
||||||
@@ -112,6 +137,14 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
mount_spec += ":ro"
|
mount_spec += ":ro"
|
||||||
cmd.extend(["-v", mount_spec])
|
cmd.extend(["-v", mount_spec])
|
||||||
|
|
||||||
|
# Add extra mounts (e.g., thread-specific directories)
|
||||||
|
if extra_mounts:
|
||||||
|
for host_path, container_path, read_only in extra_mounts:
|
||||||
|
mount_spec = f"{host_path}:{container_path}"
|
||||||
|
if read_only:
|
||||||
|
mount_spec += ":ro"
|
||||||
|
cmd.extend(["-v", mount_spec])
|
||||||
|
|
||||||
cmd.append(image)
|
cmd.append(image)
|
||||||
|
|
||||||
logger.info(f"Starting sandbox container: {' '.join(cmd)}")
|
logger.info(f"Starting sandbox container: {' '.join(cmd)}")
|
||||||
@@ -158,17 +191,28 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
port += 1
|
port += 1
|
||||||
raise RuntimeError(f"No available port found in range {start_port}-{start_port + 100}")
|
raise RuntimeError(f"No available port found in range {start_port}-{start_port + 100}")
|
||||||
|
|
||||||
def acquire(self) -> str:
|
def acquire(self, thread_id: str | None = None) -> str:
|
||||||
"""Acquire a sandbox environment and return its ID.
|
"""Acquire a sandbox environment and return its ID.
|
||||||
|
|
||||||
If base_url is configured, uses the existing sandbox.
|
If base_url is configured, uses the existing sandbox.
|
||||||
Otherwise, starts a new Docker container.
|
Otherwise, starts a new Docker container.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thread_id: Optional thread ID for thread-specific configurations.
|
||||||
|
If provided, the sandbox will be configured with thread-specific
|
||||||
|
mounts for workspace, uploads, and outputs directories.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The ID of the acquired sandbox environment.
|
The ID of the acquired sandbox environment.
|
||||||
"""
|
"""
|
||||||
sandbox_id = str(uuid.uuid4())[:8]
|
sandbox_id = str(uuid.uuid4())[:8]
|
||||||
|
|
||||||
|
# Get thread-specific mounts if thread_id is provided
|
||||||
|
extra_mounts = None
|
||||||
|
if thread_id:
|
||||||
|
extra_mounts = self._get_thread_mounts(thread_id)
|
||||||
|
logger.info(f"Adding thread mounts for thread {thread_id}: {extra_mounts}")
|
||||||
|
|
||||||
# If base_url is configured, use existing sandbox
|
# If base_url is configured, use existing sandbox
|
||||||
if self._config.get("base_url"):
|
if self._config.get("base_url"):
|
||||||
base_url = self._config["base_url"]
|
base_url = self._config["base_url"]
|
||||||
@@ -186,7 +230,7 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
raise RuntimeError("auto_start is disabled and no base_url is configured")
|
raise RuntimeError("auto_start is disabled and no base_url is configured")
|
||||||
|
|
||||||
port = self._find_available_port(self._config["port"])
|
port = self._find_available_port(self._config["port"])
|
||||||
container_id = self._start_container(sandbox_id, port)
|
container_id = self._start_container(sandbox_id, port, extra_mounts=extra_mounts)
|
||||||
self._containers[sandbox_id] = container_id
|
self._containers[sandbox_id] = container_id
|
||||||
|
|
||||||
base_url = f"http://localhost:{port}"
|
base_url = f"http://localhost:{port}"
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
from src.sandbox.local.local_sandbox import LocalSandbox
|
from src.sandbox.local.local_sandbox import LocalSandbox
|
||||||
from src.sandbox.sandbox import Sandbox
|
|
||||||
from src.sandbox.sandbox_provider import SandboxProvider
|
from src.sandbox.sandbox_provider import SandboxProvider
|
||||||
|
|
||||||
_singleton: LocalSandbox | None = None
|
_singleton: LocalSandbox | None = None
|
||||||
|
|
||||||
|
|
||||||
class LocalSandboxProvider(SandboxProvider):
|
class LocalSandboxProvider(SandboxProvider):
|
||||||
def acquire(self) -> Sandbox:
|
def acquire(self, thread_id: str | None = None) -> str:
|
||||||
global _singleton
|
global _singleton
|
||||||
if _singleton is None:
|
if _singleton is None:
|
||||||
_singleton = LocalSandbox("local")
|
_singleton = LocalSandbox("local")
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from langchain.agents import AgentState
|
|||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
from langgraph.runtime import Runtime
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
from src.agents.thread_state import SandboxState
|
from src.agents.thread_state import SandboxState, ThreadDataState
|
||||||
from src.sandbox import get_sandbox_provider
|
from src.sandbox import get_sandbox_provider
|
||||||
|
|
||||||
|
|
||||||
@@ -12,6 +12,7 @@ class SandboxMiddlewareState(AgentState):
|
|||||||
"""Compatible with the `ThreadState` schema."""
|
"""Compatible with the `ThreadState` schema."""
|
||||||
|
|
||||||
sandbox: NotRequired[SandboxState | None]
|
sandbox: NotRequired[SandboxState | None]
|
||||||
|
thread_data: NotRequired[ThreadDataState | None]
|
||||||
|
|
||||||
|
|
||||||
class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
|
class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
|
||||||
@@ -19,15 +20,17 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
|
|||||||
|
|
||||||
state_schema = SandboxMiddlewareState
|
state_schema = SandboxMiddlewareState
|
||||||
|
|
||||||
def _acquire_sandbox(self) -> str:
|
def _acquire_sandbox(self, thread_id: str) -> str:
|
||||||
provider = get_sandbox_provider()
|
provider = get_sandbox_provider()
|
||||||
sandbox_id = provider.acquire()
|
sandbox_id = provider.acquire(thread_id)
|
||||||
print(f"Acquiring sandbox {sandbox_id}")
|
print(f"Acquiring sandbox {sandbox_id}")
|
||||||
return sandbox_id
|
return sandbox_id
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
|
def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
|
||||||
if "sandbox" not in state or state["sandbox"] is None:
|
if "sandbox" not in state or state["sandbox"] is None:
|
||||||
sandbox_id = self._acquire_sandbox()
|
thread_id = runtime.context["thread_id"]
|
||||||
|
print(f"Thread ID: {thread_id}")
|
||||||
|
sandbox_id = self._acquire_sandbox(thread_id)
|
||||||
return {"sandbox": {"sandbox_id": sandbox_id}}
|
return {"sandbox": {"sandbox_id": sandbox_id}}
|
||||||
return super().before_agent(state, runtime)
|
return super().before_agent(state, runtime)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ class SandboxProvider(ABC):
|
|||||||
"""Abstract base class for sandbox providers"""
|
"""Abstract base class for sandbox providers"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def acquire(self) -> str:
|
def acquire(self, thread_id: str | None = None) -> str:
|
||||||
"""Acquire a sandbox environment and return its ID.
|
"""Acquire a sandbox environment and return its ID.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -39,7 +39,7 @@ class SandboxProvider(ABC):
|
|||||||
_default_sandbox_provider: SandboxProvider | None = None
|
_default_sandbox_provider: SandboxProvider | None = None
|
||||||
|
|
||||||
|
|
||||||
def get_sandbox_provider() -> SandboxProvider:
|
def get_sandbox_provider(**kwargs) -> SandboxProvider:
|
||||||
"""Get the sandbox provider.
|
"""Get the sandbox provider.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -49,5 +49,5 @@ def get_sandbox_provider() -> SandboxProvider:
|
|||||||
if _default_sandbox_provider is None:
|
if _default_sandbox_provider is None:
|
||||||
config = get_app_config()
|
config = get_app_config()
|
||||||
cls = resolve_class(config.sandbox.use, SandboxProvider)
|
cls = resolve_class(config.sandbox.use, SandboxProvider)
|
||||||
_default_sandbox_provider = cls()
|
_default_sandbox_provider = cls(**kwargs)
|
||||||
return _default_sandbox_provider
|
return _default_sandbox_provider
|
||||||
|
|||||||
@@ -64,3 +64,32 @@ tools:
|
|||||||
use: src.sandbox.tools:bash_tool
|
use: src.sandbox.tools:bash_tool
|
||||||
sandbox:
|
sandbox:
|
||||||
use: src.sandbox.local:LocalSandboxProvider
|
use: src.sandbox.local:LocalSandboxProvider
|
||||||
|
|
||||||
|
# To use Docker-based AIO sandbox instead, uncomment the following:
|
||||||
|
# sandbox:
|
||||||
|
# use: src.community.aio_sandbox:AioSandboxProvider
|
||||||
|
# # Optional: Use existing sandbox at this URL (no Docker container will be started)
|
||||||
|
# # base_url: http://localhost:8080
|
||||||
|
# # Optional: Docker image to use (default: enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest)
|
||||||
|
# # image: enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest
|
||||||
|
# # Optional: Base port for sandbox containers (default: 8080)
|
||||||
|
# # port: 8080
|
||||||
|
# # Optional: Whether to automatically start Docker container (default: true)
|
||||||
|
# # auto_start: true
|
||||||
|
# # Optional: Prefix for container names (default: deer-flow-sandbox)
|
||||||
|
# # container_prefix: deer-flow-sandbox
|
||||||
|
# # Optional: Mount directories from host to container
|
||||||
|
# # mounts:
|
||||||
|
# # - host_path: /path/on/host
|
||||||
|
# # container_path: /home/user/shared
|
||||||
|
# # read_only: false
|
||||||
|
# # - host_path: /another/path
|
||||||
|
# # container_path: /data
|
||||||
|
# # read_only: true
|
||||||
|
|
||||||
|
# Automatic thread title generation
|
||||||
|
title:
|
||||||
|
enabled: true
|
||||||
|
max_words: 6
|
||||||
|
max_chars: 60
|
||||||
|
model_name: null # Use default model
|
||||||
Reference in New Issue
Block a user