mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-21 13:24:44 +08:00
feat: add artifacts logic (#8)
This commit is contained in:
@@ -17,3 +17,4 @@ class ThreadState(AgentState):
|
|||||||
sandbox: NotRequired[SandboxState | None]
|
sandbox: NotRequired[SandboxState | None]
|
||||||
thread_data: NotRequired[ThreadDataState | None]
|
thread_data: NotRequired[ThreadDataState | None]
|
||||||
title: NotRequired[str | None]
|
title: NotRequired[str | None]
|
||||||
|
artifacts: NotRequired[list[str] | None]
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
self._sandboxes: dict[str, AioSandbox] = {}
|
self._sandboxes: dict[str, AioSandbox] = {}
|
||||||
self._containers: dict[str, str] = {} # sandbox_id -> container_id
|
self._containers: dict[str, str] = {} # sandbox_id -> container_id
|
||||||
self._ports: dict[str, int] = {} # sandbox_id -> port
|
self._ports: dict[str, int] = {} # sandbox_id -> port
|
||||||
|
self._thread_sandboxes: dict[str, str] = {} # thread_id -> sandbox_id (for reusing sandbox across turns)
|
||||||
self._config = self._load_config()
|
self._config = self._load_config()
|
||||||
self._shutdown_called = False
|
self._shutdown_called = False
|
||||||
|
|
||||||
@@ -230,16 +231,33 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
If base_url is configured, uses the existing sandbox.
|
If base_url is configured, uses the existing sandbox.
|
||||||
Otherwise, starts a new Docker container.
|
Otherwise, starts a new Docker container.
|
||||||
|
|
||||||
|
For the same thread_id, this method will return the same sandbox_id,
|
||||||
|
allowing sandbox reuse across multiple turns in a conversation.
|
||||||
|
|
||||||
This method is thread-safe.
|
This method is thread-safe.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
thread_id: Optional thread ID for thread-specific configurations.
|
thread_id: Optional thread ID for thread-specific configurations.
|
||||||
If provided, the sandbox will be configured with thread-specific
|
If provided, the sandbox will be configured with thread-specific
|
||||||
mounts for workspace, uploads, and outputs directories.
|
mounts for workspace, uploads, and outputs directories.
|
||||||
|
The same thread_id will reuse the same sandbox.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The ID of the acquired sandbox environment.
|
The ID of the acquired sandbox environment.
|
||||||
"""
|
"""
|
||||||
|
# Check if we already have a sandbox for this thread
|
||||||
|
if thread_id:
|
||||||
|
with self._lock:
|
||||||
|
if thread_id in self._thread_sandboxes:
|
||||||
|
existing_sandbox_id = self._thread_sandboxes[thread_id]
|
||||||
|
# Verify the sandbox still exists
|
||||||
|
if existing_sandbox_id in self._sandboxes:
|
||||||
|
logger.info(f"Reusing existing sandbox {existing_sandbox_id} for thread {thread_id}")
|
||||||
|
return existing_sandbox_id
|
||||||
|
else:
|
||||||
|
# Sandbox was released, remove stale mapping
|
||||||
|
del self._thread_sandboxes[thread_id]
|
||||||
|
|
||||||
sandbox_id = str(uuid.uuid4())[:8]
|
sandbox_id = str(uuid.uuid4())[:8]
|
||||||
|
|
||||||
# Get thread-specific mounts if thread_id is provided
|
# Get thread-specific mounts if thread_id is provided
|
||||||
@@ -265,6 +283,8 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
sandbox = AioSandbox(id=sandbox_id, base_url=base_url)
|
sandbox = AioSandbox(id=sandbox_id, base_url=base_url)
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._sandboxes[sandbox_id] = sandbox
|
self._sandboxes[sandbox_id] = sandbox
|
||||||
|
if thread_id:
|
||||||
|
self._thread_sandboxes[thread_id] = sandbox_id
|
||||||
return sandbox_id
|
return sandbox_id
|
||||||
|
|
||||||
# Otherwise, start a new container
|
# Otherwise, start a new container
|
||||||
@@ -294,7 +314,9 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
self._sandboxes[sandbox_id] = sandbox
|
self._sandboxes[sandbox_id] = sandbox
|
||||||
self._containers[sandbox_id] = container_id
|
self._containers[sandbox_id] = container_id
|
||||||
self._ports[sandbox_id] = port
|
self._ports[sandbox_id] = port
|
||||||
logger.info(f"Acquired sandbox {sandbox_id} at {base_url}")
|
if thread_id:
|
||||||
|
self._thread_sandboxes[thread_id] = sandbox_id
|
||||||
|
logger.info(f"Acquired sandbox {sandbox_id} for thread {thread_id} at {base_url}")
|
||||||
return sandbox_id
|
return sandbox_id
|
||||||
|
|
||||||
def get(self, sandbox_id: str) -> Sandbox | None:
|
def get(self, sandbox_id: str) -> Sandbox | None:
|
||||||
@@ -330,6 +352,11 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
del self._sandboxes[sandbox_id]
|
del self._sandboxes[sandbox_id]
|
||||||
logger.info(f"Released sandbox {sandbox_id}")
|
logger.info(f"Released sandbox {sandbox_id}")
|
||||||
|
|
||||||
|
# Remove thread_id -> sandbox_id mapping
|
||||||
|
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
|
||||||
|
for tid in thread_ids_to_remove:
|
||||||
|
del self._thread_sandboxes[tid]
|
||||||
|
|
||||||
# Get container and port info while holding the lock
|
# Get container and port info while holding the lock
|
||||||
if sandbox_id in self._containers:
|
if sandbox_id in self._containers:
|
||||||
container_id = self._containers.pop(sandbox_id)
|
container_id = self._containers.pop(sandbox_id)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from fastapi import FastAPI
|
|||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
from src.gateway.config import get_gateway_config
|
from src.gateway.config import get_gateway_config
|
||||||
from src.gateway.routers import models, proxy
|
from src.gateway.routers import artifacts, models, proxy
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -49,6 +49,9 @@ def create_app() -> FastAPI:
|
|||||||
# Models API is mounted at /api/models
|
# Models API is mounted at /api/models
|
||||||
app.include_router(models.router)
|
app.include_router(models.router)
|
||||||
|
|
||||||
|
# Artifacts API is mounted at /api/threads/{thread_id}/artifacts
|
||||||
|
app.include_router(artifacts.router)
|
||||||
|
|
||||||
# Proxy router handles all LangGraph paths (must be last due to catch-all)
|
# Proxy router handles all LangGraph paths (must be last due to catch-all)
|
||||||
app.include_router(proxy.router)
|
app.include_router(proxy.router)
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
from . import models, proxy
|
from . import artifacts, models, proxy
|
||||||
|
|
||||||
__all__ = ["models", "proxy"]
|
__all__ = ["artifacts", "models", "proxy"]
|
||||||
|
|||||||
75
backend/src/gateway/routers/artifacts.py
Normal file
75
backend/src/gateway/routers/artifacts.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
|
||||||
|
# Base directory for thread data (relative to backend/)
|
||||||
|
THREAD_DATA_BASE_DIR = ".deer-flow/threads"
|
||||||
|
|
||||||
|
# Virtual path prefix used in sandbox environments (without leading slash for URL path matching)
|
||||||
|
VIRTUAL_PATH_PREFIX = "mnt/user-data"
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api", tags=["artifacts"])
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_artifact_path(thread_id: str, artifact_path: str) -> Path:
    """Resolve a virtual artifact path to the actual filesystem path.

    The virtual prefix (``VIRTUAL_PATH_PREFIX``) is stripped from
    ``artifact_path`` and the remainder is joined onto
    ``<cwd>/<THREAD_DATA_BASE_DIR>/<thread_id>/user-data``. The result is
    resolved and must stay inside that directory.

    Args:
        thread_id: The thread ID. Used as a single directory name, so it must
            not be empty, ``.``/``..``, or contain path separators.
        artifact_path: The virtual path (e.g., mnt/user-data/outputs/file.txt).

    Returns:
        The resolved filesystem path.

    Raises:
        HTTPException: 400 if the thread ID or path is invalid, 403 if the
            resolved path escapes the thread's user-data directory.
    """
    # thread_id becomes a directory component; a value such as ".." would move
    # the base directory itself outside the threads root.
    if not thread_id or thread_id in (".", "..") or any(ch in thread_id for ch in ("/", "\\", "\x00")):
        raise HTTPException(status_code=400, detail="Invalid thread ID")

    # Validate and remove virtual path prefix. The prefix must end on a path
    # segment boundary so that e.g. "mnt/user-database/x" is not accepted.
    prefix_len = len(VIRTUAL_PATH_PREFIX)
    has_prefix = artifact_path.startswith(VIRTUAL_PATH_PREFIX)
    if not has_prefix or artifact_path[prefix_len : prefix_len + 1] not in ("", "/"):
        raise HTTPException(status_code=400, detail=f"Path must start with /{VIRTUAL_PATH_PREFIX}")
    relative_path = artifact_path[prefix_len:].lstrip("/")

    # Build the actual path
    base_dir = Path(os.getcwd()) / THREAD_DATA_BASE_DIR / thread_id / "user-data"
    actual_path = base_dir / relative_path

    # Normalise both paths ("..", symlinks) before comparing them.
    try:
        actual_path = actual_path.resolve()
        base_dir = base_dir.resolve()
    except (ValueError, RuntimeError):
        raise HTTPException(status_code=400, detail="Invalid path") from None

    # Security check: compare path components, not string prefixes. A plain
    # str.startswith() test would let a sibling such as "user-data-evil" pass.
    try:
        actual_path.relative_to(base_dir)
    except ValueError:
        raise HTTPException(status_code=403, detail="Access denied: path traversal detected") from None

    return actual_path
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/threads/{thread_id}/artifacts/{path:path}")
async def get_artifact(thread_id: str, path: str) -> FileResponse:
    """Serve a single artifact file belonging to a thread.

    The virtual path is mapped to a location on disk by
    ``_resolve_artifact_path``; any HTTPException it raises propagates to the
    client unchanged.

    Args:
        thread_id: The thread ID.
        path: The artifact path with virtual prefix (e.g., mnt/user-data/outputs/file.txt).

    Returns:
        A FileResponse for the resolved file, using the file's own name as
        the download filename.

    Raises:
        HTTPException: 404 if nothing exists at the resolved path, 400 if it
            exists but is not a regular file.
    """
    resolved = _resolve_artifact_path(thread_id, path)

    # Missing entirely -> 404
    is_present = resolved.exists()
    if not is_present:
        raise HTTPException(status_code=404, detail=f"Artifact not found: {path}")

    # Present but a directory (or other non-file) -> 400
    is_regular_file = resolved.is_file()
    if not is_regular_file:
        raise HTTPException(status_code=400, detail=f"Path is not a file: {path}")

    download_name = resolved.name
    return FileResponse(path=resolved, filename=download_name)
|
||||||
@@ -1,8 +1,19 @@
|
|||||||
from langchain.tools import tool
|
from typing import Annotated
|
||||||
|
|
||||||
|
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
|
||||||
|
from langchain_core.messages import ToolMessage
|
||||||
|
from langgraph.types import Command
|
||||||
|
from langgraph.typing import ContextT
|
||||||
|
|
||||||
|
from src.agents.thread_state import ThreadState
|
||||||
|
|
||||||
|
|
||||||
@tool("present_files", parse_docstring=True)
|
@tool("present_files", parse_docstring=True)
|
||||||
def present_file_tool(filepaths: list[str]) -> str:
|
def present_file_tool(
|
||||||
|
runtime: ToolRuntime[ContextT, ThreadState],
|
||||||
|
filepaths: list[str],
|
||||||
|
tool_call_id: Annotated[str, InjectedToolCallId],
|
||||||
|
) -> Command:
|
||||||
"""Make files visible to the user for viewing and rendering in the client interface.
|
"""Make files visible to the user for viewing and rendering in the client interface.
|
||||||
|
|
||||||
When to use the present_files tool:
|
When to use the present_files tool:
|
||||||
@@ -22,4 +33,9 @@ def present_file_tool(filepaths: list[str]) -> str:
|
|||||||
Args:
|
Args:
|
||||||
filepaths: List of absolute file paths to present to the user. **Only** files in `/mnt/user-data/outputs` can be presented.
|
filepaths: List of absolute file paths to present to the user. **Only** files in `/mnt/user-data/outputs` can be presented.
|
||||||
"""
|
"""
|
||||||
return "OK"
|
existing_artifacts = runtime.state.get("artifacts") or []
|
||||||
|
new_artifacts = existing_artifacts + filepaths
|
||||||
|
runtime.state["artifacts"] = new_artifacts
|
||||||
|
return Command(
|
||||||
|
update={"artifacts": new_artifacts, "messages": [ToolMessage("Successfully presented files", tool_call_id=tool_call_id)]},
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user