feat: add thread data middleware (#2)

This commit is contained in:
DanielWalnut
2026-01-15 13:22:30 +08:00
committed by GitHub
parent ab427731dc
commit c92eedc572
8 changed files with 181 additions and 14 deletions

View File

@@ -1,12 +1,11 @@
from src.sandbox.local.local_sandbox import LocalSandbox
from src.sandbox.sandbox import Sandbox
from src.sandbox.sandbox_provider import SandboxProvider
_singleton: LocalSandbox | None = None
class LocalSandboxProvider(SandboxProvider):
def acquire(self) -> Sandbox:
def acquire(self, thread_id: str | None = None) -> str:
global _singleton
if _singleton is None:
_singleton = LocalSandbox("local")

View File

@@ -4,7 +4,7 @@ from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from src.agents.thread_state import SandboxState
from src.agents.thread_state import SandboxState, ThreadDataState
from src.sandbox import get_sandbox_provider
@@ -12,6 +12,7 @@ class SandboxMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
sandbox: NotRequired[SandboxState | None]
thread_data: NotRequired[ThreadDataState | None]
class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
@@ -19,15 +20,17 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
state_schema = SandboxMiddlewareState
def _acquire_sandbox(self) -> str:
def _acquire_sandbox(self, thread_id: str) -> str:
provider = get_sandbox_provider()
sandbox_id = provider.acquire()
sandbox_id = provider.acquire(thread_id)
print(f"Acquiring sandbox {sandbox_id}")
return sandbox_id
@override
def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
if "sandbox" not in state or state["sandbox"] is None:
sandbox_id = self._acquire_sandbox()
thread_id = runtime.context["thread_id"]
print(f"Thread ID: {thread_id}")
sandbox_id = self._acquire_sandbox(thread_id)
return {"sandbox": {"sandbox_id": sandbox_id}}
return super().before_agent(state, runtime)

View File

@@ -9,7 +9,7 @@ class SandboxProvider(ABC):
"""Abstract base class for sandbox providers"""
@abstractmethod
def acquire(self) -> str:
def acquire(self, thread_id: str | None = None) -> str:
"""Acquire a sandbox environment and return its ID.
Returns:
@@ -39,7 +39,7 @@ class SandboxProvider(ABC):
_default_sandbox_provider: SandboxProvider | None = None
def get_sandbox_provider() -> SandboxProvider:
def get_sandbox_provider(**kwargs) -> SandboxProvider:
"""Get the sandbox provider.
Returns:
@@ -49,5 +49,5 @@ def get_sandbox_provider() -> SandboxProvider:
if _default_sandbox_provider is None:
config = get_app_config()
cls = resolve_class(config.sandbox.use, SandboxProvider)
_default_sandbox_provider = cls()
_default_sandbox_provider = cls(**kwargs)
return _default_sandbox_provider