mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 14:22:13 +08:00
feat: add thread-safety and graceful shutdown to AioSandboxProvider (#7)
Add thread-safe port allocation and proper cleanup on process exit to prevent port conflicts in concurrent environments and ensure containers are stopped when the application terminates. Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
import atexit
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
@@ -10,6 +13,7 @@ import requests
|
||||
from src.config import get_app_config
|
||||
from src.sandbox.sandbox import Sandbox
|
||||
from src.sandbox.sandbox_provider import SandboxProvider
|
||||
from src.utils.network import get_free_port, release_port
|
||||
|
||||
from .aio_sandbox import AioSandbox
|
||||
|
||||
@@ -42,9 +46,39 @@ class AioSandboxProvider(SandboxProvider):
|
||||
"""
|
||||
|
||||
def __init__(self):
    """Initialise provider state and register process-exit cleanup.

    Sets up the registries that track sandboxes, their container IDs and
    their allocated ports, loads the sandbox configuration, and arranges
    for ``shutdown()`` to run at interpreter exit (``atexit``) and on
    SIGTERM/SIGINT (via ``_register_signal_handlers``).
    """
    # Reentrant on purpose: the signal handler installed below calls
    # shutdown(), which takes this lock. Python runs signal handlers on the
    # main thread, so with a plain Lock a signal delivered while the main
    # thread is already inside a locked section would block on itself forever.
    self._lock = threading.RLock()
    self._sandboxes: dict[str, AioSandbox] = {}
    self._containers: dict[str, str] = {}  # sandbox_id -> container_id
    self._ports: dict[str, int] = {}  # sandbox_id -> port
    self._config = self._load_config()
    # Set by shutdown() so that repeated calls return immediately.
    self._shutdown_called = False

    # Register shutdown handler to clean up containers on exit
    atexit.register(self.shutdown)
    self._register_signal_handlers()
|
||||
|
||||
def _register_signal_handlers(self) -> None:
    """Install SIGTERM/SIGINT handlers that shut the provider down first.

    The handlers that were active before this call are remembered on the
    instance. When a signal arrives, ``shutdown()`` runs and control is then
    handed back to the remembered handler:

    * a callable handler is invoked with the same ``(signum, frame)``;
    * ``signal.SIG_DFL`` is restored and the signal is re-raised so the
      default action takes place;
    * anything else (e.g. an ignored signal) gets no further action.

    If installing the handlers raises ``ValueError`` the failure is only
    logged at debug level and the provider continues without them.
    """
    self._original_sigterm = signal.getsignal(signal.SIGTERM)
    self._original_sigint = signal.getsignal(signal.SIGINT)

    def signal_handler(signum, frame):
        # Clean up our containers before anything else reacts to the signal.
        self.shutdown()

        if signum == signal.SIGTERM:
            previous = self._original_sigterm
        else:
            previous = self._original_sigint

        if callable(previous):
            previous(signum, frame)
            return

        if previous == signal.SIG_DFL:
            # Put the default disposition back and deliver the signal again
            # so the process reacts the way it would have without us.
            signal.signal(signum, signal.SIG_DFL)
            signal.raise_signal(signum)

    try:
        for handled in (signal.SIGTERM, signal.SIGINT):
            signal.signal(handled, signal_handler)
    except ValueError:
        # Signal handling can only be set from the main thread
        logger.debug("Could not register signal handlers (not main thread)")
|
||||
|
||||
def _load_config(self) -> dict:
|
||||
"""Load sandbox configuration from app config."""
|
||||
@@ -190,33 +224,14 @@ class AioSandboxProvider(SandboxProvider):
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.warning(f"Failed to stop sandbox container {container_id}: {e.stderr}")
|
||||
|
||||
def _find_available_port(self, start_port: int) -> int:
|
||||
"""Find an available port starting from start_port.
|
||||
|
||||
Args:
|
||||
start_port: Port to start searching from.
|
||||
|
||||
Returns:
|
||||
An available port number.
|
||||
"""
|
||||
import socket
|
||||
|
||||
port = start_port
|
||||
while port < start_port + 100: # Search up to 100 ports
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
try:
|
||||
s.bind(("localhost", port))
|
||||
return port
|
||||
except OSError:
|
||||
port += 1
|
||||
raise RuntimeError(f"No available port found in range {start_port}-{start_port + 100}")
|
||||
|
||||
def acquire(self, thread_id: str | None = None) -> str:
|
||||
"""Acquire a sandbox environment and return its ID.
|
||||
|
||||
If base_url is configured, uses the existing sandbox.
|
||||
Otherwise, starts a new Docker container.
|
||||
|
||||
This method is thread-safe.
|
||||
|
||||
Args:
|
||||
thread_id: Optional thread ID for thread-specific configurations.
|
||||
If provided, the sandbox will be configured with thread-specific
|
||||
@@ -244,60 +259,110 @@ class AioSandboxProvider(SandboxProvider):
|
||||
base_url = self._config["base_url"]
|
||||
logger.info(f"Using existing sandbox at {base_url}")
|
||||
|
||||
if not self._is_sandbox_ready(base_url, timeout=5):
|
||||
if not self._is_sandbox_ready(base_url, timeout=60):
|
||||
raise RuntimeError(f"Sandbox at {base_url} is not ready")
|
||||
|
||||
sandbox = AioSandbox(id=sandbox_id, base_url=base_url)
|
||||
self._sandboxes[sandbox_id] = sandbox
|
||||
with self._lock:
|
||||
self._sandboxes[sandbox_id] = sandbox
|
||||
return sandbox_id
|
||||
|
||||
# Otherwise, start a new container
|
||||
if not self._config.get("auto_start", True):
|
||||
raise RuntimeError("auto_start is disabled and no base_url is configured")
|
||||
|
||||
port = self._find_available_port(self._config["port"])
|
||||
container_id = self._start_container(sandbox_id, port, extra_mounts=extra_mounts if extra_mounts else None)
|
||||
self._containers[sandbox_id] = container_id
|
||||
# Allocate port using thread-safe utility
|
||||
port = get_free_port(start_port=self._config["port"])
|
||||
try:
|
||||
container_id = self._start_container(sandbox_id, port, extra_mounts=extra_mounts if extra_mounts else None)
|
||||
except Exception:
|
||||
# Release port if container failed to start
|
||||
release_port(port)
|
||||
raise
|
||||
|
||||
base_url = f"http://localhost:{port}"
|
||||
|
||||
# Wait for sandbox to be ready
|
||||
if not self._is_sandbox_ready(base_url, timeout=60):
|
||||
# Clean up container if it didn't start properly
|
||||
# Clean up container and release port if it didn't start properly
|
||||
self._stop_container(container_id)
|
||||
del self._containers[sandbox_id]
|
||||
release_port(port)
|
||||
raise RuntimeError("Sandbox container failed to start within timeout")
|
||||
|
||||
sandbox = AioSandbox(id=sandbox_id, base_url=base_url)
|
||||
self._sandboxes[sandbox_id] = sandbox
|
||||
with self._lock:
|
||||
self._sandboxes[sandbox_id] = sandbox
|
||||
self._containers[sandbox_id] = container_id
|
||||
self._ports[sandbox_id] = port
|
||||
logger.info(f"Acquired sandbox {sandbox_id} at {base_url}")
|
||||
return sandbox_id
|
||||
|
||||
def get(self, sandbox_id: str) -> Sandbox | None:
    """Get a sandbox environment by ID.

    This method is thread-safe: the lookup is performed while holding the
    provider lock.

    Args:
        sandbox_id: The ID of the sandbox environment.

    Returns:
        The sandbox instance if found, None otherwise.
    """
    # Only the locked lookup remains; an unlocked early return placed before
    # it would make the lock unreachable and the thread-safety claim false.
    with self._lock:
        return self._sandboxes.get(sandbox_id)
|
||||
|
||||
def release(self, sandbox_id: str) -> None:
    """Release a sandbox environment.

    If the sandbox was started by this provider, stops the container
    and releases the allocated port. Calling this for an unknown or
    already-released ID does nothing.

    This method is thread-safe: all registry updates happen under the
    provider lock.

    Args:
        sandbox_id: The ID of the sandbox environment to release.
    """
    with self._lock:
        if sandbox_id in self._sandboxes:
            del self._sandboxes[sandbox_id]
            logger.info(f"Released sandbox {sandbox_id}")

        # Take ownership of the container and port while holding the lock so
        # that a concurrent release() of the same ID cannot act on them too.
        container_id = self._containers.pop(sandbox_id, None)
        port = self._ports.pop(sandbox_id, None)

    # Stop container and release port outside the lock to avoid blocking
    # other threads while the container is being stopped.
    try:
        if container_id:
            self._stop_container(container_id)
    finally:
        # Hand the port back even if stopping the container raised, so a
        # failed stop does not leave the port reserved.
        if port is not None:
            release_port(port)
|
||||
|
||||
def shutdown(self) -> None:
    """Release every sandbox this provider currently tracks.

    Intended to run when the application is terminating so that containers
    are stopped and ports are handed back. The first call marks the provider
    as shut down under the lock; any later call sees the mark and returns
    immediately, which makes the method safe to call more than once.

    A failure while releasing one sandbox is logged and does not prevent the
    remaining sandboxes from being released.
    """
    with self._lock:
        already_done = self._shutdown_called
        if not already_done:
            self._shutdown_called = True
            # Snapshot the IDs so the loop below does not iterate a dict
            # that release() is mutating.
            pending = list(self._sandboxes)

    if already_done:
        return

    logger.info(f"Shutting down {len(pending)} sandbox container(s)")

    for sid in pending:
        try:
            self.release(sid)
        except Exception as e:
            logger.error(f"Failed to release sandbox {sid} during shutdown: {e}")
|
||||
|
||||
135
backend/src/utils/network.py
Normal file
135
backend/src/utils/network.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Thread-safe network utilities."""
|
||||
|
||||
import socket
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
class PortAllocator:
    """Thread-safe port allocator that prevents port conflicts in concurrent environments.

    This class maintains a set of reserved ports and uses a lock to ensure that
    port allocation is atomic. Once a port is allocated, it remains reserved until
    explicitly released.

    Usage:
        allocator = PortAllocator()

        # Option 1: Manual allocation and release
        port = allocator.allocate(start_port=8080)
        try:
            # Use the port...
        finally:
            allocator.release(port)

        # Option 2: Context manager (recommended)
        with allocator.allocate_context(start_port=8080) as port:
            # Use the port...
        # Port is automatically released when exiting the context
    """

    def __init__(self):
        """Create an allocator with no reserved ports."""
        self._lock = threading.Lock()
        self._reserved_ports: set[int] = set()

    def _is_port_available(self, port: int) -> bool:
        """Check if a port is available for binding.

        A port counts as unavailable when it is already reserved by this
        allocator, when binding a TCP socket to it on ``localhost`` fails,
        or when the number is not a valid port at all.

        Callers are expected to hold ``self._lock``; this method reads
        ``_reserved_ports`` without taking it.

        Args:
            port: The port number to check.

        Returns:
            True if the port is available, False otherwise.
        """
        if port in self._reserved_ports:
            return False

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(("localhost", port))
            except (OSError, OverflowError):
                # OSError: the port is in use or cannot be bound.
                # OverflowError: bind() rejects numbers outside 0-65535; treat
                # those as "not available" so allocate() reports exhaustion
                # with its documented RuntimeError instead of leaking this.
                return False
            return True

    def allocate(self, start_port: int = 8080, max_range: int = 100) -> int:
        """Allocate an available port in a thread-safe manner.

        This method is thread-safe. It finds an available port, marks it as reserved,
        and returns it. The port remains reserved until release() is called.

        Args:
            start_port: The port number to start searching from.
            max_range: Maximum number of ports to search.

        Returns:
            An available port number.

        Raises:
            RuntimeError: If no available port is found in the specified range.
        """
        # The probe and the reservation happen under one lock acquisition so
        # two threads cannot both pick the same port.
        with self._lock:
            for port in range(start_port, start_port + max_range):
                if self._is_port_available(port):
                    self._reserved_ports.add(port)
                    return port

            raise RuntimeError(f"No available port found in range {start_port}-{start_port + max_range}")

    def release(self, port: int) -> None:
        """Release a previously allocated port.

        Releasing a port that is not currently reserved is a no-op.

        Args:
            port: The port number to release.
        """
        with self._lock:
            self._reserved_ports.discard(port)

    @contextmanager
    def allocate_context(self, start_port: int = 8080, max_range: int = 100):
        """Context manager for port allocation with automatic release.

        Args:
            start_port: The port number to start searching from.
            max_range: Maximum number of ports to search.

        Yields:
            An available port number.

        Raises:
            RuntimeError: If no available port is found in the specified range.
        """
        port = self.allocate(start_port, max_range)
        try:
            yield port
        finally:
            self.release(port)
|
||||
|
||||
|
||||
# Global port allocator instance for shared use across the application.
# get_free_port() and release_port() in this module both delegate to it, so
# every caller going through those functions shares one reserved-port set.
|
||||
_global_port_allocator = PortAllocator()
|
||||
|
||||
|
||||
def get_free_port(start_port: int = 8080, max_range: int = 100) -> int:
    """Reserve and return a free port from the module-wide allocator.

    Delegates to the shared ``PortAllocator`` instance, so concurrent callers
    never receive the same port. The returned port stays reserved until it
    is handed back with ``release_port()``.

    Args:
        start_port: First port number to consider.
        max_range: How many consecutive ports to search at most.

    Returns:
        The port number that was reserved.

    Raises:
        RuntimeError: If no port in the searched range is available.
    """
    reserved = _global_port_allocator.allocate(start_port, max_range)
    return reserved
|
||||
|
||||
|
||||
def release_port(port: int) -> None:
    """Hand a port back to the module-wide allocator.

    Delegates to the shared ``PortAllocator`` instance; a port that is not
    currently reserved there is silently ignored.

    Args:
        port: The port number to release.
    """
    shared_allocator = _global_port_allocator
    shared_allocator.release(port)
|
||||
Reference in New Issue
Block a user