fix: fix local path for local sandbox (#3)

This commit is contained in:
DanielWalnut
2026-01-15 14:37:00 +08:00
committed by GitHub
parent 41442ccc2f
commit 3b879e277e
3 changed files with 123 additions and 10 deletions

View File

@@ -1,10 +1,109 @@
import re
from langchain.tools import ToolRuntime, tool
from langgraph.typing import ContextT
from src.agents.thread_state import ThreadState
from src.agents.thread_state import ThreadDataState, ThreadState
from src.sandbox.sandbox import Sandbox
from src.sandbox.sandbox_provider import get_sandbox_provider
# Virtual path prefix used in sandbox environments
VIRTUAL_PATH_PREFIX = "/mnt/user-data"
def replace_virtual_path(path: str, thread_data: ThreadDataState | None) -> str:
"""Replace virtual /mnt/user-data paths with actual thread data paths.
Mapping:
/mnt/user-data/workspace/* -> thread_data['workspace_path']/*
/mnt/user-data/uploads/* -> thread_data['uploads_path']/*
/mnt/user-data/outputs/* -> thread_data['outputs_path']/*
Args:
path: The path that may contain virtual path prefix.
thread_data: The thread data containing actual paths.
Returns:
The path with virtual prefix replaced by actual path.
"""
if not path.startswith(VIRTUAL_PATH_PREFIX):
return path
if thread_data is None:
return path
# Map virtual subdirectories to thread_data keys
path_mapping = {
"workspace": thread_data.get("workspace_path"),
"uploads": thread_data.get("uploads_path"),
"outputs": thread_data.get("outputs_path"),
}
# Extract the subdirectory after /mnt/user-data/
relative_path = path[len(VIRTUAL_PATH_PREFIX) :].lstrip("/")
if not relative_path:
return path
# Find which subdirectory this path belongs to
parts = relative_path.split("/", 1)
subdir = parts[0]
rest = parts[1] if len(parts) > 1 else ""
actual_base = path_mapping.get(subdir)
if actual_base is None:
return path
if rest:
return f"{actual_base}/{rest}"
return actual_base
def replace_virtual_paths_in_command(command: str, thread_data: ThreadDataState | None) -> str:
"""Replace all virtual /mnt/user-data paths in a command string.
Args:
command: The command string that may contain virtual paths.
thread_data: The thread data containing actual paths.
Returns:
The command with all virtual paths replaced.
"""
if VIRTUAL_PATH_PREFIX not in command:
return command
if thread_data is None:
return command
# Pattern to match /mnt/user-data followed by path characters
pattern = re.compile(rf"{re.escape(VIRTUAL_PATH_PREFIX)}(/[^\s\"';&|<>()]*)?")
def replace_match(match: re.Match) -> str:
full_path = match.group(0)
return replace_virtual_path(full_path, thread_data)
return pattern.sub(replace_match, command)
def get_thread_data(runtime: ToolRuntime[ContextT, ThreadState] | None) -> ThreadDataState | None:
"""Extract thread_data from runtime state."""
if runtime is None:
return None
return runtime.state.get("thread_data")
def is_local_sandbox(runtime: ToolRuntime[ContextT, ThreadState] | None) -> bool:
"""Check if the current sandbox is a local sandbox.
Path replacement is only needed for local sandbox since aio sandbox
already has /mnt/user-data mounted in the container.
"""
if runtime is None:
return False
sandbox_state = runtime.state.get("sandbox")
if sandbox_state is None:
return False
return sandbox_state.get("sandbox_id") == "local"
def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = None) -> Sandbox:
if runtime is None:
@@ -35,6 +134,9 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
"""
try:
sandbox = sandbox_from_runtime(runtime)
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
command = replace_virtual_paths_in_command(command, thread_data)
return sandbox.execute_command(command)
except Exception as e:
return f"Error: {e}"
@@ -50,6 +152,9 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
"""
try:
sandbox = sandbox_from_runtime(runtime)
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
path = replace_virtual_path(path, thread_data)
children = sandbox.list_dir(path)
if not children:
return "(empty)"
@@ -74,6 +179,9 @@ def read_file_tool(
"""
try:
sandbox = sandbox_from_runtime(runtime)
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
path = replace_virtual_path(path, thread_data)
content = sandbox.read_file(path)
if not content:
return "(empty)"
@@ -102,6 +210,9 @@ def write_file_tool(
"""
try:
sandbox = sandbox_from_runtime(runtime)
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
path = replace_virtual_path(path, thread_data)
sandbox.write_file(path, content, append)
return "OK"
except Exception as e:
@@ -129,6 +240,9 @@ def str_replace_tool(
"""
try:
sandbox = sandbox_from_runtime(runtime)
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
path = replace_virtual_path(path, thread_data)
content = sandbox.read_file(path)
if not content:
return "OK"