mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-18 03:54:46 +08:00
feat: add thread data middleware (#2)
This commit is contained in:
@@ -1,12 +1,11 @@
|
||||
from src.sandbox.local.local_sandbox import LocalSandbox
|
||||
from src.sandbox.sandbox import Sandbox
|
||||
from src.sandbox.sandbox_provider import SandboxProvider
|
||||
|
||||
_singleton: LocalSandbox | None = None
|
||||
|
||||
|
||||
class LocalSandboxProvider(SandboxProvider):
|
||||
def acquire(self) -> Sandbox:
|
||||
def acquire(self, thread_id: str | None = None) -> str:
|
||||
global _singleton
|
||||
if _singleton is None:
|
||||
_singleton = LocalSandbox("local")
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from src.agents.thread_state import SandboxState
|
||||
from src.agents.thread_state import SandboxState, ThreadDataState
|
||||
from src.sandbox import get_sandbox_provider
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ class SandboxMiddlewareState(AgentState):
|
||||
"""Compatible with the `ThreadState` schema."""
|
||||
|
||||
sandbox: NotRequired[SandboxState | None]
|
||||
thread_data: NotRequired[ThreadDataState | None]
|
||||
|
||||
|
||||
class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
|
||||
@@ -19,15 +20,17 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
|
||||
|
||||
state_schema = SandboxMiddlewareState
|
||||
|
||||
def _acquire_sandbox(self) -> str:
|
||||
def _acquire_sandbox(self, thread_id: str) -> str:
|
||||
provider = get_sandbox_provider()
|
||||
sandbox_id = provider.acquire()
|
||||
sandbox_id = provider.acquire(thread_id)
|
||||
print(f"Acquiring sandbox {sandbox_id}")
|
||||
return sandbox_id
|
||||
|
||||
@override
|
||||
def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
if "sandbox" not in state or state["sandbox"] is None:
|
||||
sandbox_id = self._acquire_sandbox()
|
||||
thread_id = runtime.context["thread_id"]
|
||||
print(f"Thread ID: {thread_id}")
|
||||
sandbox_id = self._acquire_sandbox(thread_id)
|
||||
return {"sandbox": {"sandbox_id": sandbox_id}}
|
||||
return super().before_agent(state, runtime)
|
||||
|
||||
@@ -9,7 +9,7 @@ class SandboxProvider(ABC):
|
||||
"""Abstract base class for sandbox providers"""
|
||||
|
||||
@abstractmethod
|
||||
def acquire(self) -> str:
|
||||
def acquire(self, thread_id: str | None = None) -> str:
|
||||
"""Acquire a sandbox environment and return its ID.
|
||||
|
||||
Returns:
|
||||
@@ -39,7 +39,7 @@ class SandboxProvider(ABC):
|
||||
_default_sandbox_provider: SandboxProvider | None = None
|
||||
|
||||
|
||||
def get_sandbox_provider() -> SandboxProvider:
|
||||
def get_sandbox_provider(**kwargs) -> SandboxProvider:
|
||||
"""Get the sandbox provider.
|
||||
|
||||
Returns:
|
||||
@@ -49,5 +49,5 @@ def get_sandbox_provider() -> SandboxProvider:
|
||||
if _default_sandbox_provider is None:
|
||||
config = get_app_config()
|
||||
cls = resolve_class(config.sandbox.use, SandboxProvider)
|
||||
_default_sandbox_provider = cls()
|
||||
_default_sandbox_provider = cls(**kwargs)
|
||||
return _default_sandbox_provider
|
||||
|
||||
Reference in New Issue
Block a user