mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
feat: implement lazy sandbox and thread data initialization (#11)
Defer sandbox acquisition and thread directory creation until first use to improve performance and reduce resource usage.

Changes:
- Add lazy_init parameter to SandboxMiddleware (default: true)
- Add ensure_sandbox_initialized() helper for lazy sandbox acquisition
- Update all sandbox tools to use lazy initialization
- Add lazy_init parameter to ThreadDataMiddleware (default: true)
- Create thread directories on-demand in AioSandboxProvider
- LocalSandbox already creates directories on write (no changes needed)

Benefits:
- Saves 1-2s Docker container startup for conversations without tools
- Reduces unnecessary directory creation and file system operations
- Backward compatible with lazy_init=false option

Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -25,18 +25,26 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
|
||||
- backend/.deer-flow/threads/{thread_id}/user-data/workspace
|
||||
- backend/.deer-flow/threads/{thread_id}/user-data/uploads
|
||||
- backend/.deer-flow/threads/{thread_id}/user-data/outputs
|
||||
|
||||
Lifecycle Management:
|
||||
- With lazy_init=True (default): Only compute paths, directories created on-demand
|
||||
- With lazy_init=False: Eagerly create directories in before_agent()
|
||||
"""
|
||||
|
||||
state_schema = ThreadDataMiddlewareState
|
||||
|
||||
def __init__(self, base_dir: str | None = None):
|
||||
def __init__(self, base_dir: str | None = None, lazy_init: bool = True):
|
||||
"""Initialize the middleware.
|
||||
|
||||
Args:
|
||||
base_dir: Base directory for thread data. Defaults to the current working directory.
|
||||
lazy_init: If True, defer directory creation until needed.
|
||||
If False, create directories eagerly in before_agent().
|
||||
Default is True for optimal performance.
|
||||
"""
|
||||
super().__init__()
|
||||
self._base_dir = base_dir or os.getcwd()
|
||||
self._lazy_init = lazy_init
|
||||
|
||||
def _get_thread_paths(self, thread_id: str) -> dict[str, str]:
|
||||
"""Get the paths for a thread's data directories.
|
||||
@@ -70,12 +78,17 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
|
||||
|
||||
@override
|
||||
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
# Generate new thread ID and create directories
|
||||
thread_id = runtime.context.get("thread_id")
|
||||
if thread_id is None:
|
||||
raise ValueError("Thread ID is required in the context")
|
||||
paths = self._create_thread_directories(thread_id)
|
||||
print(f"Created thread data directories for thread {thread_id}")
|
||||
|
||||
if self._lazy_init:
|
||||
# Lazy initialization: only compute paths, don't create directories
|
||||
paths = self._get_thread_paths(thread_id)
|
||||
else:
|
||||
# Eager initialization: create directories immediately
|
||||
paths = self._create_thread_directories(thread_id)
|
||||
print(f"Created thread data directories for thread {thread_id}")
|
||||
|
||||
return {
|
||||
"thread_data": {
|
||||
|
||||
@@ -120,6 +120,8 @@ class AioSandboxProvider(SandboxProvider):
|
||||
def _get_thread_mounts(self, thread_id: str) -> list[tuple[str, str, bool]]:
|
||||
"""Get the volume mounts for a thread's data directories.
|
||||
|
||||
Creates the directories if they don't exist (lazy initialization).
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
|
||||
@@ -129,12 +131,19 @@ class AioSandboxProvider(SandboxProvider):
|
||||
base_dir = os.getcwd()
|
||||
thread_dir = Path(base_dir) / THREAD_DATA_BASE_DIR / thread_id / "user-data"
|
||||
|
||||
return [
|
||||
# Create directories for Docker volume mounts (required before container starts)
|
||||
mounts = [
|
||||
(str(thread_dir / "workspace"), f"{CONTAINER_USER_DATA_DIR}/workspace", False),
|
||||
(str(thread_dir / "uploads"), f"{CONTAINER_USER_DATA_DIR}/uploads", False),
|
||||
(str(thread_dir / "outputs"), f"{CONTAINER_USER_DATA_DIR}/outputs", False),
|
||||
]
|
||||
|
||||
# Ensure directories exist before mounting
|
||||
for host_path, _, _ in mounts:
|
||||
os.makedirs(host_path, exist_ok=True)
|
||||
|
||||
return mounts
|
||||
|
||||
def _get_skills_mount(self) -> tuple[str, str, bool] | None:
|
||||
"""Get the skills directory mount configuration.
|
||||
|
||||
|
||||
@@ -19,7 +19,8 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
|
||||
"""Create a sandbox environment and assign it to an agent.
|
||||
|
||||
Lifecycle Management:
|
||||
- Sandbox is acquired on first agent invocation for a thread (before_agent)
|
||||
- With lazy_init=True (default): Sandbox is acquired on first tool call
|
||||
- With lazy_init=False: Sandbox is acquired on first agent invocation (before_agent)
|
||||
- Sandbox is reused across multiple turns within the same thread
|
||||
- Sandbox is NOT released after each agent call to avoid wasteful recreation
|
||||
- Cleanup happens at application shutdown via SandboxProvider.shutdown()
|
||||
@@ -27,6 +28,17 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
|
||||
|
||||
state_schema = SandboxMiddlewareState
|
||||
|
||||
def __init__(self, lazy_init: bool = True):
    """Set up the middleware and record the initialization mode.

    Args:
        lazy_init: Selects when the sandbox is acquired. When True (the
            default), before_agent() does not acquire a sandbox and
            acquisition is left to the first tool call. When False,
            before_agent() acquires the sandbox eagerly.
    """
    super().__init__()
    # Consulted by before_agent() to decide whether to skip acquisition.
    self._lazy_init = lazy_init
|
||||
|
||||
def _acquire_sandbox(self, thread_id: str) -> str:
|
||||
provider = get_sandbox_provider()
|
||||
sandbox_id = provider.acquire(thread_id)
|
||||
@@ -35,6 +47,11 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
|
||||
|
||||
@override
|
||||
def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
# Skip acquisition if lazy_init is enabled
|
||||
if self._lazy_init:
|
||||
return super().before_agent(state, runtime)
|
||||
|
||||
# Eager initialization (original behavior)
|
||||
if "sandbox" not in state or state["sandbox"] is None:
|
||||
thread_id = runtime.context["thread_id"]
|
||||
print(f"Thread ID: {thread_id}")
|
||||
|
||||
@@ -6,8 +6,6 @@ from langgraph.typing import ContextT
|
||||
from src.agents.thread_state import ThreadDataState, ThreadState
|
||||
from src.sandbox.exceptions import (
|
||||
SandboxError,
|
||||
SandboxFileError,
|
||||
SandboxFileNotFoundError,
|
||||
SandboxNotFoundError,
|
||||
SandboxRuntimeError,
|
||||
)
|
||||
@@ -115,6 +113,9 @@ def is_local_sandbox(runtime: ToolRuntime[ContextT, ThreadState] | None) -> bool
|
||||
def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = None) -> Sandbox:
|
||||
"""Extract sandbox instance from tool runtime.
|
||||
|
||||
DEPRECATED: Use ensure_sandbox_initialized() for lazy initialization support.
|
||||
This function assumes sandbox is already initialized and will raise error if not.
|
||||
|
||||
Raises:
|
||||
SandboxRuntimeError: If runtime is not available or sandbox state is missing.
|
||||
SandboxNotFoundError: If sandbox with the given ID cannot be found.
|
||||
@@ -133,6 +134,57 @@ def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = No
|
||||
return sandbox
|
||||
|
||||
|
||||
def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | None = None) -> Sandbox:
    """Return the sandbox for this tool call, acquiring one lazily if needed.

    If ``runtime.state["sandbox"]`` already holds a ``sandbox_id`` that the
    provider can still resolve, that sandbox is returned. Otherwise a sandbox
    is acquired from the provider for the thread named in ``runtime.context``
    and its ID is written to ``runtime.state["sandbox"]``.

    NOTE(review): thread-safety presumably relies on the provider's own
    locking inside ``acquire()``; that is not visible from this function.

    Args:
        runtime: Tool runtime containing state and context.

    Returns:
        Initialized sandbox instance.

    Raises:
        SandboxRuntimeError: If runtime is not available, or the runtime
            context is missing or carries no ``thread_id``.
        SandboxNotFoundError: If the provider cannot return the sandbox it
            has just acquired.
    """
    if runtime is None:
        raise SandboxRuntimeError("Tool runtime not available")

    # Reuse the sandbox recorded in state when the provider still knows it.
    sandbox_state = runtime.state.get("sandbox")
    if sandbox_state is not None:
        sandbox_id = sandbox_state.get("sandbox_id")
        if sandbox_id is not None:
            sandbox = get_sandbox_provider().get(sandbox_id)
            if sandbox is not None:
                return sandbox
            # The recorded sandbox is gone (e.g. released); acquire a new one.

    # Lazy acquisition. The context may be None when the caller supplied no
    # context at all; report that as the documented SandboxRuntimeError
    # instead of letting an AttributeError escape from `.get`.
    context = runtime.context
    thread_id = context.get("thread_id") if context is not None else None
    if thread_id is None:
        raise SandboxRuntimeError("Thread ID not available in runtime context")

    provider = get_sandbox_provider()
    print(f"Lazy acquiring sandbox for thread {thread_id}")
    sandbox_id = provider.acquire(thread_id)

    # Record the ID so later calls take the reuse path above.
    # NOTE(review): this assumes runtime.state is the live state mapping that
    # subsequent tool calls also see — confirm against the agent framework.
    runtime.state["sandbox"] = {"sandbox_id": sandbox_id}

    sandbox = provider.get(sandbox_id)
    if sandbox is None:
        raise SandboxNotFoundError("Sandbox not found after acquisition", sandbox_id=sandbox_id)

    return sandbox
|
||||
|
||||
|
||||
@tool("bash", parse_docstring=True)
|
||||
def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, command: str) -> str:
|
||||
"""Execute a bash command in a Linux environment.
|
||||
@@ -146,7 +198,7 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
|
||||
command: The bash command to execute. Always use absolute paths for files and directories.
|
||||
"""
|
||||
try:
|
||||
sandbox = sandbox_from_runtime(runtime)
|
||||
sandbox = ensure_sandbox_initialized(runtime)
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
command = replace_virtual_paths_in_command(command, thread_data)
|
||||
@@ -166,7 +218,7 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
|
||||
path: The **absolute** path to the directory to list.
|
||||
"""
|
||||
try:
|
||||
sandbox = sandbox_from_runtime(runtime)
|
||||
sandbox = ensure_sandbox_initialized(runtime)
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
path = replace_virtual_path(path, thread_data)
|
||||
@@ -201,7 +253,7 @@ def read_file_tool(
|
||||
end_line: Optional ending line number (1-indexed, inclusive). Use with start_line to read a specific range.
|
||||
"""
|
||||
try:
|
||||
sandbox = sandbox_from_runtime(runtime)
|
||||
sandbox = ensure_sandbox_initialized(runtime)
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
path = replace_virtual_path(path, thread_data)
|
||||
@@ -239,7 +291,7 @@ def write_file_tool(
|
||||
content: The content to write to the file. ALWAYS PROVIDE THIS PARAMETER THIRD.
|
||||
"""
|
||||
try:
|
||||
sandbox = sandbox_from_runtime(runtime)
|
||||
sandbox = ensure_sandbox_initialized(runtime)
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
path = replace_virtual_path(path, thread_data)
|
||||
@@ -277,7 +329,7 @@ def str_replace_tool(
|
||||
replace_all: Whether to replace all occurrences of the substring. If False, only the first occurrence will be replaced. Default is False.
|
||||
"""
|
||||
try:
|
||||
sandbox = sandbox_from_runtime(runtime)
|
||||
sandbox = ensure_sandbox_initialized(runtime)
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
path = replace_virtual_path(path, thread_data)
|
||||
|
||||
Reference in New Issue
Block a user