feat: add thread-safety and graceful shutdown to AioSandboxProvider (#7)

Add thread-safe port allocation and proper cleanup on process exit to
prevent port conflicts in concurrent environments and ensure containers
are stopped when the application terminates.

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
DanielWalnut
2026-01-16 22:28:19 +08:00
committed by GitHub
parent c0e63c5308
commit 4b69aed47b
2 changed files with 238 additions and 38 deletions

View File

@@ -1,6 +1,9 @@
import atexit
import logging
import os
import signal
import subprocess
import threading
import time
import uuid
from pathlib import Path
@@ -10,6 +13,7 @@ import requests
from src.config import get_app_config
from src.sandbox.sandbox import Sandbox
from src.sandbox.sandbox_provider import SandboxProvider
from src.utils.network import get_free_port, release_port
from .aio_sandbox import AioSandbox
@@ -42,9 +46,39 @@ class AioSandboxProvider(SandboxProvider):
"""
def __init__(self):
self._lock = threading.Lock()
self._sandboxes: dict[str, AioSandbox] = {}
self._containers: dict[str, str] = {} # sandbox_id -> container_id
self._ports: dict[str, int] = {} # sandbox_id -> port
self._config = self._load_config()
self._shutdown_called = False
# Register shutdown handler to clean up containers on exit
atexit.register(self.shutdown)
self._register_signal_handlers()
def _register_signal_handlers(self) -> None:
"""Register signal handlers for graceful shutdown."""
self._original_sigterm = signal.getsignal(signal.SIGTERM)
self._original_sigint = signal.getsignal(signal.SIGINT)
def signal_handler(signum, frame):
self.shutdown()
# Call original handler
original = self._original_sigterm if signum == signal.SIGTERM else self._original_sigint
if callable(original):
original(signum, frame)
elif original == signal.SIG_DFL:
# Re-raise the signal with default handler
signal.signal(signum, signal.SIG_DFL)
signal.raise_signal(signum)
try:
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
except ValueError:
# Signal handling can only be set from the main thread
logger.debug("Could not register signal handlers (not main thread)")
def _load_config(self) -> dict:
"""Load sandbox configuration from app config."""
@@ -190,33 +224,14 @@ class AioSandboxProvider(SandboxProvider):
except subprocess.CalledProcessError as e:
logger.warning(f"Failed to stop sandbox container {container_id}: {e.stderr}")
def _find_available_port(self, start_port: int) -> int:
"""Find an available port starting from start_port.
Args:
start_port: Port to start searching from.
Returns:
An available port number.
"""
import socket
port = start_port
while port < start_port + 100: # Search up to 100 ports
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind(("localhost", port))
return port
except OSError:
port += 1
raise RuntimeError(f"No available port found in range {start_port}-{start_port + 100}")
def acquire(self, thread_id: str | None = None) -> str:
"""Acquire a sandbox environment and return its ID.
If base_url is configured, uses the existing sandbox.
Otherwise, starts a new Docker container.
This method is thread-safe.
Args:
thread_id: Optional thread ID for thread-specific configurations.
If provided, the sandbox will be configured with thread-specific
@@ -244,60 +259,110 @@ class AioSandboxProvider(SandboxProvider):
base_url = self._config["base_url"]
logger.info(f"Using existing sandbox at {base_url}")
if not self._is_sandbox_ready(base_url, timeout=5):
if not self._is_sandbox_ready(base_url, timeout=60):
raise RuntimeError(f"Sandbox at {base_url} is not ready")
sandbox = AioSandbox(id=sandbox_id, base_url=base_url)
self._sandboxes[sandbox_id] = sandbox
with self._lock:
self._sandboxes[sandbox_id] = sandbox
return sandbox_id
# Otherwise, start a new container
if not self._config.get("auto_start", True):
raise RuntimeError("auto_start is disabled and no base_url is configured")
port = self._find_available_port(self._config["port"])
container_id = self._start_container(sandbox_id, port, extra_mounts=extra_mounts if extra_mounts else None)
self._containers[sandbox_id] = container_id
# Allocate port using thread-safe utility
port = get_free_port(start_port=self._config["port"])
try:
container_id = self._start_container(sandbox_id, port, extra_mounts=extra_mounts if extra_mounts else None)
except Exception:
# Release port if container failed to start
release_port(port)
raise
base_url = f"http://localhost:{port}"
# Wait for sandbox to be ready
if not self._is_sandbox_ready(base_url, timeout=60):
# Clean up container if it didn't start properly
# Clean up container and release port if it didn't start properly
self._stop_container(container_id)
del self._containers[sandbox_id]
release_port(port)
raise RuntimeError("Sandbox container failed to start within timeout")
sandbox = AioSandbox(id=sandbox_id, base_url=base_url)
self._sandboxes[sandbox_id] = sandbox
with self._lock:
self._sandboxes[sandbox_id] = sandbox
self._containers[sandbox_id] = container_id
self._ports[sandbox_id] = port
logger.info(f"Acquired sandbox {sandbox_id} at {base_url}")
return sandbox_id
def get(self, sandbox_id: str) -> Sandbox | None:
"""Get a sandbox environment by ID.
This method is thread-safe.
Args:
sandbox_id: The ID of the sandbox environment.
Returns:
The sandbox instance if found, None otherwise.
"""
return self._sandboxes.get(sandbox_id)
with self._lock:
return self._sandboxes.get(sandbox_id)
def release(self, sandbox_id: str) -> None:
"""Release a sandbox environment.
If the sandbox was started by this provider, stops the container.
If the sandbox was started by this provider, stops the container
and releases the allocated port.
This method is thread-safe.
Args:
sandbox_id: The ID of the sandbox environment to release.
"""
if sandbox_id in self._sandboxes:
del self._sandboxes[sandbox_id]
logger.info(f"Released sandbox {sandbox_id}")
container_id = None
port = None
# Stop container if we started it
if sandbox_id in self._containers:
container_id = self._containers[sandbox_id]
with self._lock:
if sandbox_id in self._sandboxes:
del self._sandboxes[sandbox_id]
logger.info(f"Released sandbox {sandbox_id}")
# Get container and port info while holding the lock
if sandbox_id in self._containers:
container_id = self._containers.pop(sandbox_id)
if sandbox_id in self._ports:
port = self._ports.pop(sandbox_id)
# Stop container and release port outside the lock to avoid blocking
if container_id:
self._stop_container(container_id)
del self._containers[sandbox_id]
if port:
release_port(port)
def shutdown(self) -> None:
"""Shutdown all sandbox containers managed by this provider.
This method should be called when the application is shutting down
to ensure all containers are properly stopped and ports are released.
This method is thread-safe and idempotent (safe to call multiple times).
"""
# Prevent multiple shutdown calls
with self._lock:
if self._shutdown_called:
return
self._shutdown_called = True
sandbox_ids = list(self._sandboxes.keys())
logger.info(f"Shutting down {len(sandbox_ids)} sandbox container(s)")
for sandbox_id in sandbox_ids:
try:
self.release(sandbox_id)
except Exception as e:
logger.error(f"Failed to release sandbox {sandbox_id} during shutdown: {e}")

View File

@@ -0,0 +1,135 @@
"""Thread-safe network utilities."""
import socket
import threading
from contextlib import contextmanager
class PortAllocator:
"""Thread-safe port allocator that prevents port conflicts in concurrent environments.
This class maintains a set of reserved ports and uses a lock to ensure that
port allocation is atomic. Once a port is allocated, it remains reserved until
explicitly released.
Usage:
allocator = PortAllocator()
# Option 1: Manual allocation and release
port = allocator.allocate(start_port=8080)
try:
# Use the port...
finally:
allocator.release(port)
# Option 2: Context manager (recommended)
with allocator.allocate_context(start_port=8080) as port:
# Use the port...
# Port is automatically released when exiting the context
"""
def __init__(self):
self._lock = threading.Lock()
self._reserved_ports: set[int] = set()
def _is_port_available(self, port: int) -> bool:
"""Check if a port is available for binding.
Args:
port: The port number to check.
Returns:
True if the port is available, False otherwise.
"""
if port in self._reserved_ports:
return False
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind(("localhost", port))
return True
except OSError:
return False
def allocate(self, start_port: int = 8080, max_range: int = 100) -> int:
"""Allocate an available port in a thread-safe manner.
This method is thread-safe. It finds an available port, marks it as reserved,
and returns it. The port remains reserved until release() is called.
Args:
start_port: The port number to start searching from.
max_range: Maximum number of ports to search.
Returns:
An available port number.
Raises:
RuntimeError: If no available port is found in the specified range.
"""
with self._lock:
for port in range(start_port, start_port + max_range):
if self._is_port_available(port):
self._reserved_ports.add(port)
return port
raise RuntimeError(f"No available port found in range {start_port}-{start_port + max_range}")
def release(self, port: int) -> None:
"""Release a previously allocated port.
Args:
port: The port number to release.
"""
with self._lock:
self._reserved_ports.discard(port)
@contextmanager
def allocate_context(self, start_port: int = 8080, max_range: int = 100):
"""Context manager for port allocation with automatic release.
Args:
start_port: The port number to start searching from.
max_range: Maximum number of ports to search.
Yields:
An available port number.
"""
port = self.allocate(start_port, max_range)
try:
yield port
finally:
self.release(port)
# Global port allocator instance for shared use across the application
_global_port_allocator = PortAllocator()
def get_free_port(start_port: int = 8080, max_range: int = 100) -> int:
"""Get a free port in a thread-safe manner.
This function uses a global port allocator to ensure that concurrent calls
don't return the same port. The port is marked as reserved until release_port()
is called.
Args:
start_port: The port number to start searching from.
max_range: Maximum number of ports to search.
Returns:
An available port number.
Raises:
RuntimeError: If no available port is found in the specified range.
"""
return _global_port_allocator.allocate(start_port, max_range)
def release_port(port: int) -> None:
"""Release a previously allocated port.
Args:
port: The port number to release.
"""
_global_port_allocator.release(port)