mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
34 lines
1.1 KiB
Python
34 lines
1.1 KiB
Python
|
|
from typing import override
|
||
|
|
|
||
|
|
from langchain.agents import AgentState
|
||
|
|
from langchain.agents.middleware import AgentMiddleware
|
||
|
|
from langgraph.runtime import Runtime
|
||
|
|
|
||
|
|
from src.agents.thread_state import SandboxState
|
||
|
|
from src.sandbox import get_sandbox_provider
|
||
|
|
|
||
|
|
|
||
|
|
class SandboxMiddlewareState(AgentState):
    """Agent state extended with an optional sandbox entry.

    Kept compatible with the `ThreadState` schema.
    """

    # Sandbox assigned to the agent, or None when none has been assigned.
    # NOTE(review): if AgentState is a TypedDict, this `= None` default is
    # presumably not materialised on state dicts (the key can simply be
    # absent) - confirm against the AgentState definition.
    sandbox: SandboxState | None = None
|
||
|
|
|
||
|
|
|
||
|
|
class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
    """Create a sandbox environment and assign it to an agent.

    Before the agent runs, the middleware checks the state for a ``sandbox``
    entry and, when it is missing or ``None``, acquires a sandbox from the
    provider returned by ``get_sandbox_provider()`` and stores its id in the
    state update.
    """

    state_schema = SandboxMiddlewareState

    def _acquire_sandbox(self) -> str:
        """Acquire a sandbox from the provider and return its identifier.

        Returns:
            The value returned by ``provider.acquire()``.
        """
        # Function-scope import keeps this block self-contained.
        import logging

        provider = get_sandbox_provider()
        sandbox_id = provider.acquire()
        # Use logging instead of print() so the host application decides
        # where (and whether) this message is emitted.
        logging.getLogger(__name__).info("Acquiring sandbox %s", sandbox_id)
        return sandbox_id

    @override
    def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
        """Ensure the state carries a sandbox before the agent starts.

        Args:
            state: Current agent state; may lack the ``sandbox`` key or hold
                ``None`` for it.
            runtime: Runtime object forwarded to the parent hook.

        Returns:
            A state update ``{"sandbox": {"sandbox_id": ...}}`` when a sandbox
            had to be acquired; otherwise whatever the parent class's
            ``before_agent`` returns.
        """
        if "sandbox" not in state or state["sandbox"] is None:
            sandbox_id = self._acquire_sandbox()
            return {"sandbox": {"sandbox_id": sandbox_id}}
        # A sandbox is already assigned: defer to the default hook behaviour.
        return super().before_agent(state, runtime)
|