From de2d18561adcb2bbab94246815bdca4dd151ebf8 Mon Sep 17 00:00:00 2001 From: Henry Li Date: Wed, 14 Jan 2026 12:32:34 +0800 Subject: [PATCH] feat: integrated with sandbox --- backend/src/agents/__init__.py | 3 +- backend/src/agents/lead_agent/agent.py | 4 ++ backend/src/agents/thread_state.py | 11 ++++ backend/src/config/app_config.py | 11 ++-- .../sandbox/local/local_sandbox_provider.py | 8 ++- backend/src/sandbox/middleware.py | 33 ++++++++++ backend/src/sandbox/sandbox_provider.py | 4 +- backend/src/sandbox/tools.py | 63 ++++++++++++------- 8 files changed, 103 insertions(+), 34 deletions(-) create mode 100644 backend/src/agents/thread_state.py create mode 100644 backend/src/sandbox/middleware.py diff --git a/backend/src/agents/__init__.py b/backend/src/agents/__init__.py index 63e383d..cffde06 100644 --- a/backend/src/agents/__init__.py +++ b/backend/src/agents/__init__.py @@ -1,3 +1,4 @@ from .lead_agent import lead_agent +from .thread_state import SandboxState, ThreadState -__all__ = ["lead_agent"] +__all__ = ["lead_agent", "SandboxState", "ThreadState"] diff --git a/backend/src/agents/lead_agent/agent.py b/backend/src/agents/lead_agent/agent.py index 7ea415b..9c57cef 100644 --- a/backend/src/agents/lead_agent/agent.py +++ b/backend/src/agents/lead_agent/agent.py @@ -1,11 +1,15 @@ from langchain.agents import create_agent from src.agents.lead_agent.prompt import apply_prompt_template +from src.agents.thread_state import ThreadState from src.models import create_chat_model +from src.sandbox.middleware import SandboxMiddleware from src.tools import get_available_tools lead_agent = create_agent( model=create_chat_model(thinking_enabled=True), tools=get_available_tools(), + middleware=[SandboxMiddleware()], system_prompt=apply_prompt_template(), + state_schema=ThreadState, ) diff --git a/backend/src/agents/thread_state.py b/backend/src/agents/thread_state.py new file mode 100644 index 0000000..081df2d --- /dev/null +++ b/backend/src/agents/thread_state.py @@ -0,0 +1,11 @@ +from typing import TypedDict + +from langchain.agents import AgentState + + +class SandboxState(TypedDict): + sandbox_id: str | None = None + + +class ThreadState(AgentState): + sandbox: SandboxState | None = None diff --git a/backend/src/config/app_config.py b/backend/src/config/app_config.py index a23d9fa..bfb85c3 100644 --- a/backend/src/config/app_config.py +++ b/backend/src/config/app_config.py @@ -26,7 +26,7 @@ class AppConfig(BaseModel): Priority: 1. If provided `config_path` argument, use it. 2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it. - 3. Otherwise, first check the `config.yaml` in the current directory, then use `config.yaml` in the parent directory. + 3. Otherwise, first check the `config.yaml` in the current directory, then fallback to `config.yaml` in the parent directory. """ if config_path: path = Path(config_path) @@ -39,10 +39,13 @@ class AppConfig(BaseModel): raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}") return path else: - # Check if the config.yaml is in the parent directory of CWD - path = Path(os.getcwd()).parent / "config.yaml" + # Check if the config.yaml is in the current directory + path = Path(os.getcwd()) / "config.yaml" if not path.exists(): - raise FileNotFoundError(f"Config file not found at {path}") + # Check if the config.yaml is in the parent directory of CWD + path = Path(os.getcwd()).parent / "config.yaml" + if not path.exists(): + raise FileNotFoundError("`config.yaml` file not found at the current directory nor its parent directory") return path @classmethod diff --git a/backend/src/sandbox/local/local_sandbox_provider.py b/backend/src/sandbox/local/local_sandbox_provider.py index 70296fc..467e8f0 100644 --- a/backend/src/sandbox/local/local_sandbox_provider.py +++ b/backend/src/sandbox/local/local_sandbox_provider.py @@ -13,9 +13,11 @@ class LocalSandboxProvider(SandboxProvider): return _singleton.id def get(self, sandbox_id: str) -> None: - if _singleton is None: - self.acquire() - return _singleton + if sandbox_id == "local": + if _singleton is None: + self.acquire() + return _singleton + return None def release(self, sandbox_id: str) -> None: pass diff --git a/backend/src/sandbox/middleware.py b/backend/src/sandbox/middleware.py new file mode 100644 index 0000000..f536031 --- /dev/null +++ b/backend/src/sandbox/middleware.py @@ -0,0 +1,33 @@ +from typing import override + +from langchain.agents import AgentState +from langchain.agents.middleware import AgentMiddleware +from langgraph.runtime import Runtime + +from src.agents.thread_state import SandboxState +from src.sandbox import get_sandbox_provider + + +class SandboxMiddlewareState(AgentState): + """Compatible with the `ThreadState` schema.""" + + sandbox: SandboxState | None = None + + +class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]): + """Create a sandbox environment and assign it to an agent.""" + + state_schema = SandboxMiddlewareState + + def _acquire_sandbox(self) -> str: + provider = get_sandbox_provider() + sandbox_id = provider.acquire() + print(f"Acquiring sandbox {sandbox_id}") + return sandbox_id + + @override + def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None: + if "sandbox" not in state or state["sandbox"] is None: + sandbox_id = self._acquire_sandbox() + return {"sandbox": {"sandbox_id": sandbox_id}} + return super().before_agent(state, runtime) diff --git a/backend/src/sandbox/sandbox_provider.py b/backend/src/sandbox/sandbox_provider.py index 2d03c6d..5ca8c81 100644 --- a/backend/src/sandbox/sandbox_provider.py +++ b/backend/src/sandbox/sandbox_provider.py @@ -10,7 +10,7 @@ class SandboxProvider(ABC): @abstractmethod def acquire(self) -> str: - """Acquire a sandbox environment. + """Acquire a sandbox environment and return its ID. Returns: The ID of the acquired sandbox environment. @@ -18,7 +18,7 @@ class SandboxProvider(ABC): pass @abstractmethod - def get(self, sandbox_id: str) -> Sandbox: + def get(self, sandbox_id: str) -> Sandbox | None: """Get a sandbox environment by ID. Args: diff --git a/backend/src/sandbox/tools.py b/backend/src/sandbox/tools.py index 50c1ac3..fa77f2b 100644 --- a/backend/src/sandbox/tools.py +++ b/backend/src/sandbox/tools.py @@ -1,10 +1,28 @@ -from langchain.tools import tool +from langchain.tools import ToolRuntime, tool +from langgraph.typing import ContextT +from src.agents.thread_state import ThreadState +from src.sandbox.sandbox import Sandbox from src.sandbox.sandbox_provider import get_sandbox_provider +def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = None) -> Sandbox: + if runtime is None: + raise ValueError("No sandbox found: No runtime found") + sandbox_state = runtime.state.get("sandbox") + if sandbox_state is None: + raise ValueError("No sandbox found: No sandbox state found in runtime") + sandbox_id = sandbox_state.get("sandbox_id") + if sandbox_id is None: + raise ValueError("No sandbox ID found: No sandbox ID found in sandbox state") + sandbox = get_sandbox_provider().get(sandbox_id) + if sandbox is None: + raise ValueError(f"No sandbox found: sandbox with ID {sandbox_id} not found") + return sandbox + + @tool("bash", parse_docstring=True) -def bash_tool(description: str, command: str) -> str: +def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, command: str) -> str: """Execute a bash command in a Linux environment. @@ -12,29 +30,26 @@ def bash_tool(description: str, command: str) -> str: - Use `pip install` to install Python packages. Args: - description: Explain why you are running this command in short words. + description: Explain why you are running this command in short words. ALWAYS PROVIDE THIS PARAMETER FIRST. command: The bash command to execute. Always use absolute paths for files and directories. """ - # TODO: get sandbox ID from LangGraph's context - sandbox_id = "local" - sandbox = get_sandbox_provider().get(sandbox_id) try: + sandbox = sandbox_from_runtime(runtime) return sandbox.execute_command(command) except Exception as e: return f"Error: {e}" @tool("ls", parse_docstring=True) -def ls_tool(description: str, path: str) -> str: +def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path: str) -> str: """List the contents of a directory up to 2 levels deep in tree format. Args: - description: Explain why you are listing this directory in short words. + description: Explain why you are listing this directory in short words. ALWAYS PROVIDE THIS PARAMETER FIRST. path: The **absolute** path to the directory to list. """ try: - # TODO: get sandbox ID from LangGraph's context - sandbox = get_sandbox_provider().get("local") + sandbox = sandbox_from_runtime(runtime) children = sandbox.list_dir(path) if not children: return "(empty)" @@ -45,6 +60,7 @@ def ls_tool(description: str, path: str) -> str: @tool("read_file", parse_docstring=True) def read_file_tool( + runtime: ToolRuntime[ContextT, ThreadState], description: str, path: str, view_range: tuple[int, int] | None = None, @@ -52,13 +68,12 @@ def read_file_tool( """Read the contents of a text file. Args: - description: Explain why you are viewing this file in short words. + description: Explain why you are viewing this file in short words. ALWAYS PROVIDE THIS PARAMETER FIRST. path: The **absolute** path to the file to read. view_range: The range of lines to view. The range is inclusive and starts at 1. For example, (1, 10) will view the first 10 lines of the file. """ try: - # TODO: get sandbox ID from LangGraph's context - sandbox = get_sandbox_provider().get("local") + sandbox = sandbox_from_runtime(runtime) content = sandbox.read_file(path) if not content: return "(empty)" @@ -72,6 +87,7 @@ def read_file_tool( @tool("write_file", parse_docstring=True) def write_file_tool( + runtime: ToolRuntime[ContextT, ThreadState], description: str, path: str, content: str, @@ -80,13 +96,12 @@ def write_file_tool( """Write text content to a file. Args: - description: Explain why you are writing to this file in short words. - path: The **absolute** path to the file to write to. - content: The content to write to the file. + description: Explain why you are writing to this file in short words. ALWAYS PROVIDE THIS PARAMETER FIRST. + path: The **absolute** path to the file to write to. ALWAYS PROVIDE THIS PARAMETER SECOND. + content: The content to write to the file. ALWAYS PROVIDE THIS PARAMETER THIRD. """ try: - # TODO: get sandbox ID from LangGraph's context - sandbox = get_sandbox_provider().get("local") + sandbox = sandbox_from_runtime(runtime) sandbox.write_file(path, content, append) return "OK" except Exception as e: @@ -95,6 +110,7 @@ def write_file_tool( @tool("str_replace", parse_docstring=True) def str_replace_tool( + runtime: ToolRuntime[ContextT, ThreadState], description: str, path: str, old_str: str, @@ -105,15 +121,14 @@ def str_replace_tool( If `replace_all` is False (default), the substring to replace must appear **exactly once** in the file. Args: - description: Explain why you are replacing the substring in short words. - path: The **absolute** path to the file to replace the substring in. - old_str: The substring to replace. - new_str: The new substring. + description: Explain why you are replacing the substring in short words. ALWAYS PROVIDE THIS PARAMETER FIRST. + path: The **absolute** path to the file to replace the substring in. ALWAYS PROVIDE THIS PARAMETER SECOND. + old_str: The substring to replace. ALWAYS PROVIDE THIS PARAMETER THIRD. + new_str: The new substring. ALWAYS PROVIDE THIS PARAMETER FOURTH. replace_all: Whether to replace all occurrences of the substring. If False, only the first occurrence will be replaced. Default is False. """ try: - # TODO: get sandbox ID from LangGraph's context - sandbox = get_sandbox_provider().get("local") + sandbox = sandbox_from_runtime(runtime) content = sandbox.read_file(path) if not content: return "OK"