mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-26 15:24:48 +08:00
feat: integrated with sandbox
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
from .lead_agent import lead_agent
|
from .lead_agent import lead_agent
|
||||||
|
from .thread_state import SandboxState, ThreadState
|
||||||
|
|
||||||
__all__ = ["lead_agent"]
|
__all__ = ["lead_agent", "SandboxState", "ThreadState"]
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
|
|
||||||
from src.agents.lead_agent.prompt import apply_prompt_template
|
from src.agents.lead_agent.prompt import apply_prompt_template
|
||||||
|
from src.agents.thread_state import ThreadState
|
||||||
from src.models import create_chat_model
|
from src.models import create_chat_model
|
||||||
|
from src.sandbox.middleware import SandboxMiddleware
|
||||||
from src.tools import get_available_tools
|
from src.tools import get_available_tools
|
||||||
|
|
||||||
lead_agent = create_agent(
|
lead_agent = create_agent(
|
||||||
model=create_chat_model(thinking_enabled=True),
|
model=create_chat_model(thinking_enabled=True),
|
||||||
tools=get_available_tools(),
|
tools=get_available_tools(),
|
||||||
|
middleware=[SandboxMiddleware()],
|
||||||
system_prompt=apply_prompt_template(),
|
system_prompt=apply_prompt_template(),
|
||||||
|
state_schema=ThreadState,
|
||||||
)
|
)
|
||||||
|
|||||||
11
backend/src/agents/thread_state.py
Normal file
11
backend/src/agents/thread_state.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
from typing import TypedDict
|
||||||
|
|
||||||
|
from langchain.agents import AgentState
|
||||||
|
|
||||||
|
|
||||||
|
class SandboxState(TypedDict):
|
||||||
|
sandbox_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadState(AgentState):
|
||||||
|
sandbox: SandboxState | None = None
|
||||||
@@ -26,7 +26,7 @@ class AppConfig(BaseModel):
|
|||||||
Priority:
|
Priority:
|
||||||
1. If provided `config_path` argument, use it.
|
1. If provided `config_path` argument, use it.
|
||||||
2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it.
|
2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it.
|
||||||
3. Otherwise, first check the `config.yaml` in the current directory, then use `config.yaml` in the parent directory.
|
3. Otherwise, first check the `config.yaml` in the current directory, then fallback to `config.yaml` in the parent directory.
|
||||||
"""
|
"""
|
||||||
if config_path:
|
if config_path:
|
||||||
path = Path(config_path)
|
path = Path(config_path)
|
||||||
@@ -39,10 +39,13 @@ class AppConfig(BaseModel):
|
|||||||
raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}")
|
raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}")
|
||||||
return path
|
return path
|
||||||
else:
|
else:
|
||||||
# Check if the config.yaml is in the parent directory of CWD
|
# Check if the config.yaml is in the current directory
|
||||||
path = Path(os.getcwd()).parent / "config.yaml"
|
path = Path(os.getcwd()) / "config.yaml"
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
raise FileNotFoundError(f"Config file not found at {path}")
|
# Check if the config.yaml is in the parent directory of CWD
|
||||||
|
path = Path(os.getcwd()).parent / "config.yaml"
|
||||||
|
if not path.exists():
|
||||||
|
raise FileNotFoundError("`config.yaml` file not found at the current directory nor its parent directory")
|
||||||
return path
|
return path
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -13,9 +13,11 @@ class LocalSandboxProvider(SandboxProvider):
|
|||||||
return _singleton.id
|
return _singleton.id
|
||||||
|
|
||||||
def get(self, sandbox_id: str) -> None:
|
def get(self, sandbox_id: str) -> None:
|
||||||
if _singleton is None:
|
if sandbox_id == "local":
|
||||||
self.acquire()
|
if _singleton is None:
|
||||||
return _singleton
|
self.acquire()
|
||||||
|
return _singleton
|
||||||
|
return None
|
||||||
|
|
||||||
def release(self, sandbox_id: str) -> None:
|
def release(self, sandbox_id: str) -> None:
|
||||||
pass
|
pass
|
||||||
|
|||||||
33
backend/src/sandbox/middleware.py
Normal file
33
backend/src/sandbox/middleware.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
from typing import override
|
||||||
|
|
||||||
|
from langchain.agents import AgentState
|
||||||
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
|
from langgraph.runtime import Runtime
|
||||||
|
|
||||||
|
from src.agents.thread_state import SandboxState
|
||||||
|
from src.sandbox import get_sandbox_provider
|
||||||
|
|
||||||
|
|
||||||
|
class SandboxMiddlewareState(AgentState):
|
||||||
|
"""Compatible with the `ThreadState` schema."""
|
||||||
|
|
||||||
|
sandbox: SandboxState | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
|
||||||
|
"""Create a sandbox environment and assign it to an agent."""
|
||||||
|
|
||||||
|
state_schema = SandboxMiddlewareState
|
||||||
|
|
||||||
|
def _acquire_sandbox(self) -> str:
|
||||||
|
provider = get_sandbox_provider()
|
||||||
|
sandbox_id = provider.acquire()
|
||||||
|
print(f"Acquiring sandbox {sandbox_id}")
|
||||||
|
return sandbox_id
|
||||||
|
|
||||||
|
@override
|
||||||
|
def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
|
||||||
|
if "sandbox" not in state or state["sandbox"] is None:
|
||||||
|
sandbox_id = self._acquire_sandbox()
|
||||||
|
return {"sandbox": {"sandbox_id": sandbox_id}}
|
||||||
|
return super().before_agent(state, runtime)
|
||||||
@@ -10,7 +10,7 @@ class SandboxProvider(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def acquire(self) -> str:
|
def acquire(self) -> str:
|
||||||
"""Acquire a sandbox environment.
|
"""Acquire a sandbox environment and return its ID.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The ID of the acquired sandbox environment.
|
The ID of the acquired sandbox environment.
|
||||||
@@ -18,7 +18,7 @@ class SandboxProvider(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get(self, sandbox_id: str) -> Sandbox:
|
def get(self, sandbox_id: str) -> Sandbox | None:
|
||||||
"""Get a sandbox environment by ID.
|
"""Get a sandbox environment by ID.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -1,10 +1,28 @@
|
|||||||
from langchain.tools import tool
|
from langchain.tools import ToolRuntime, tool
|
||||||
|
from langgraph.typing import ContextT
|
||||||
|
|
||||||
|
from src.agents.thread_state import ThreadState
|
||||||
|
from src.sandbox.sandbox import Sandbox
|
||||||
from src.sandbox.sandbox_provider import get_sandbox_provider
|
from src.sandbox.sandbox_provider import get_sandbox_provider
|
||||||
|
|
||||||
|
|
||||||
|
def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = None) -> Sandbox:
|
||||||
|
if runtime is None:
|
||||||
|
raise ValueError("No sandbox found: No runtime found")
|
||||||
|
sandbox_state = runtime.state.get("sandbox")
|
||||||
|
if sandbox_state is None:
|
||||||
|
raise ValueError("No sandbox found: No sandbox state found in runtime")
|
||||||
|
sandbox_id = sandbox_state.get("sandbox_id")
|
||||||
|
if sandbox_id is None:
|
||||||
|
raise ValueError("No sandbox ID found: No sandbox ID found in sandbox state")
|
||||||
|
sandbox = get_sandbox_provider().get(sandbox_id)
|
||||||
|
if sandbox is None:
|
||||||
|
raise ValueError(f"No sandbox found: sandbox with ID {sandbox_id} not found")
|
||||||
|
return sandbox
|
||||||
|
|
||||||
|
|
||||||
@tool("bash", parse_docstring=True)
|
@tool("bash", parse_docstring=True)
|
||||||
def bash_tool(description: str, command: str) -> str:
|
def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, command: str) -> str:
|
||||||
"""Execute a bash command in a Linux environment.
|
"""Execute a bash command in a Linux environment.
|
||||||
|
|
||||||
|
|
||||||
@@ -12,29 +30,26 @@ def bash_tool(description: str, command: str) -> str:
|
|||||||
- Use `pip install` to install Python packages.
|
- Use `pip install` to install Python packages.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
description: Explain why you are running this command in short words.
|
description: Explain why you are running this command in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
|
||||||
command: The bash command to execute. Always use absolute paths for files and directories.
|
command: The bash command to execute. Always use absolute paths for files and directories.
|
||||||
"""
|
"""
|
||||||
# TODO: get sandbox ID from LangGraph's context
|
|
||||||
sandbox_id = "local"
|
|
||||||
sandbox = get_sandbox_provider().get(sandbox_id)
|
|
||||||
try:
|
try:
|
||||||
|
sandbox = sandbox_from_runtime(runtime)
|
||||||
return sandbox.execute_command(command)
|
return sandbox.execute_command(command)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error: {e}"
|
return f"Error: {e}"
|
||||||
|
|
||||||
|
|
||||||
@tool("ls", parse_docstring=True)
|
@tool("ls", parse_docstring=True)
|
||||||
def ls_tool(description: str, path: str) -> str:
|
def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path: str) -> str:
|
||||||
"""List the contents of a directory up to 2 levels deep in tree format.
|
"""List the contents of a directory up to 2 levels deep in tree format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
description: Explain why you are listing this directory in short words.
|
description: Explain why you are listing this directory in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
|
||||||
path: The **absolute** path to the directory to list.
|
path: The **absolute** path to the directory to list.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# TODO: get sandbox ID from LangGraph's context
|
sandbox = sandbox_from_runtime(runtime)
|
||||||
sandbox = get_sandbox_provider().get("local")
|
|
||||||
children = sandbox.list_dir(path)
|
children = sandbox.list_dir(path)
|
||||||
if not children:
|
if not children:
|
||||||
return "(empty)"
|
return "(empty)"
|
||||||
@@ -45,6 +60,7 @@ def ls_tool(description: str, path: str) -> str:
|
|||||||
|
|
||||||
@tool("read_file", parse_docstring=True)
|
@tool("read_file", parse_docstring=True)
|
||||||
def read_file_tool(
|
def read_file_tool(
|
||||||
|
runtime: ToolRuntime[ContextT, ThreadState],
|
||||||
description: str,
|
description: str,
|
||||||
path: str,
|
path: str,
|
||||||
view_range: tuple[int, int] | None = None,
|
view_range: tuple[int, int] | None = None,
|
||||||
@@ -52,13 +68,12 @@ def read_file_tool(
|
|||||||
"""Read the contents of a text file.
|
"""Read the contents of a text file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
description: Explain why you are viewing this file in short words.
|
description: Explain why you are viewing this file in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
|
||||||
path: The **absolute** path to the file to read.
|
path: The **absolute** path to the file to read.
|
||||||
view_range: The range of lines to view. The range is inclusive and starts at 1. For example, (1, 10) will view the first 10 lines of the file.
|
view_range: The range of lines to view. The range is inclusive and starts at 1. For example, (1, 10) will view the first 10 lines of the file.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# TODO: get sandbox ID from LangGraph's context
|
sandbox = sandbox_from_runtime(runtime)
|
||||||
sandbox = get_sandbox_provider().get("local")
|
|
||||||
content = sandbox.read_file(path)
|
content = sandbox.read_file(path)
|
||||||
if not content:
|
if not content:
|
||||||
return "(empty)"
|
return "(empty)"
|
||||||
@@ -72,6 +87,7 @@ def read_file_tool(
|
|||||||
|
|
||||||
@tool("write_file", parse_docstring=True)
|
@tool("write_file", parse_docstring=True)
|
||||||
def write_file_tool(
|
def write_file_tool(
|
||||||
|
runtime: ToolRuntime[ContextT, ThreadState],
|
||||||
description: str,
|
description: str,
|
||||||
path: str,
|
path: str,
|
||||||
content: str,
|
content: str,
|
||||||
@@ -80,13 +96,12 @@ def write_file_tool(
|
|||||||
"""Write text content to a file.
|
"""Write text content to a file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
description: Explain why you are writing to this file in short words.
|
description: Explain why you are writing to this file in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
|
||||||
path: The **absolute** path to the file to write to.
|
path: The **absolute** path to the file to write to. ALWAYS PROVIDE THIS PARAMETER SECOND.
|
||||||
content: The content to write to the file.
|
content: The content to write to the file. ALWAYS PROVIDE THIS PARAMETER THIRD.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# TODO: get sandbox ID from LangGraph's context
|
sandbox = sandbox_from_runtime(runtime)
|
||||||
sandbox = get_sandbox_provider().get("local")
|
|
||||||
sandbox.write_file(path, content, append)
|
sandbox.write_file(path, content, append)
|
||||||
return "OK"
|
return "OK"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -95,6 +110,7 @@ def write_file_tool(
|
|||||||
|
|
||||||
@tool("str_replace", parse_docstring=True)
|
@tool("str_replace", parse_docstring=True)
|
||||||
def str_replace_tool(
|
def str_replace_tool(
|
||||||
|
runtime: ToolRuntime[ContextT, ThreadState],
|
||||||
description: str,
|
description: str,
|
||||||
path: str,
|
path: str,
|
||||||
old_str: str,
|
old_str: str,
|
||||||
@@ -105,15 +121,14 @@ def str_replace_tool(
|
|||||||
If `replace_all` is False (default), the substring to replace must appear **exactly once** in the file.
|
If `replace_all` is False (default), the substring to replace must appear **exactly once** in the file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
description: Explain why you are replacing the substring in short words.
|
description: Explain why you are replacing the substring in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
|
||||||
path: The **absolute** path to the file to replace the substring in.
|
path: The **absolute** path to the file to replace the substring in. ALWAYS PROVIDE THIS PARAMETER SECOND.
|
||||||
old_str: The substring to replace.
|
old_str: The substring to replace. ALWAYS PROVIDE THIS PARAMETER THIRD.
|
||||||
new_str: The new substring.
|
new_str: The new substring. ALWAYS PROVIDE THIS PARAMETER FOURTH.
|
||||||
replace_all: Whether to replace all occurrences of the substring. If False, only the first occurrence will be replaced. Default is False.
|
replace_all: Whether to replace all occurrences of the substring. If False, only the first occurrence will be replaced. Default is False.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# TODO: get sandbox ID from LangGraph's context
|
sandbox = sandbox_from_runtime(runtime)
|
||||||
sandbox = get_sandbox_provider().get("local")
|
|
||||||
content = sandbox.read_file(path)
|
content = sandbox.read_file(path)
|
||||||
if not content:
|
if not content:
|
||||||
return "OK"
|
return "OK"
|
||||||
|
|||||||
Reference in New Issue
Block a user