import logging
from typing import NotRequired, override

from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime

from src.agents.thread_state import SandboxState, ThreadDataState
from src.sandbox import get_sandbox_provider

# Module-level logger so hosting applications can filter or route these
# diagnostics instead of having them written straight to stdout.
logger = logging.getLogger(__name__)


class SandboxMiddlewareState(AgentState):
    """Compatible with the `ThreadState` schema."""

    # Both keys are optional and may also be explicitly set to None.
    sandbox: NotRequired[SandboxState | None]
    thread_data: NotRequired[ThreadDataState | None]


class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
    """Create a sandbox environment and assign it to an agent.

    Lifecycle Management:
    - Sandbox is acquired on first agent invocation for a thread (before_agent)
    - Sandbox is reused across multiple turns within the same thread
    - Sandbox is NOT released after each agent call to avoid wasteful recreation
    - Cleanup happens at application shutdown via SandboxProvider.shutdown()
    """

    state_schema = SandboxMiddlewareState

    def _acquire_sandbox(self, thread_id: str) -> str:
        """Acquire a sandbox for *thread_id* from the configured provider.

        Args:
            thread_id: Identifier passed through to ``provider.acquire``.

        Returns:
            The sandbox id returned by ``provider.acquire``.
        """
        provider = get_sandbox_provider()
        sandbox_id = provider.acquire(thread_id)
        logger.info("Acquiring sandbox %s", sandbox_id)
        return sandbox_id

    @override
    def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
        """Attach a sandbox to the state when none is present yet.

        Args:
            state: Current agent state; only its ``"sandbox"`` entry is read.
            runtime: Runtime whose ``context["thread_id"]`` selects the sandbox.
                NOTE(review): assumes ``runtime.context`` is a mapping that
                contains ``"thread_id"``; the lookup raises otherwise.

        Returns:
            ``{"sandbox": {"sandbox_id": <id>}}`` when a sandbox was acquired,
            otherwise whatever the parent ``before_agent`` returns.
        """
        # A sandbox is already recorded in the state: nothing to acquire,
        # defer to the base implementation.
        if state.get("sandbox") is not None:
            return super().before_agent(state, runtime)

        thread_id = runtime.context["thread_id"]
        logger.info("Thread ID: %s", thread_id)
        sandbox_id = self._acquire_sandbox(thread_id)
        return {"sandbox": {"sandbox_id": sandbox_id}}