feat: add artifacts logic (#8)

This commit is contained in:
DanielWalnut
2026-01-16 23:04:38 +08:00
committed by GitHub
parent 0d11b21c84
commit d5b3052cda
6 changed files with 129 additions and 7 deletions

View File

@@ -17,3 +17,4 @@ class ThreadState(AgentState):
sandbox: NotRequired[SandboxState | None] sandbox: NotRequired[SandboxState | None]
thread_data: NotRequired[ThreadDataState | None] thread_data: NotRequired[ThreadDataState | None]
title: NotRequired[str | None] title: NotRequired[str | None]
artifacts: NotRequired[list[str] | None]

View File

@@ -50,6 +50,7 @@ class AioSandboxProvider(SandboxProvider):
self._sandboxes: dict[str, AioSandbox] = {} self._sandboxes: dict[str, AioSandbox] = {}
self._containers: dict[str, str] = {} # sandbox_id -> container_id self._containers: dict[str, str] = {} # sandbox_id -> container_id
self._ports: dict[str, int] = {} # sandbox_id -> port self._ports: dict[str, int] = {} # sandbox_id -> port
self._thread_sandboxes: dict[str, str] = {} # thread_id -> sandbox_id (for reusing sandbox across turns)
self._config = self._load_config() self._config = self._load_config()
self._shutdown_called = False self._shutdown_called = False
@@ -230,16 +231,33 @@ class AioSandboxProvider(SandboxProvider):
If base_url is configured, uses the existing sandbox. If base_url is configured, uses the existing sandbox.
Otherwise, starts a new Docker container. Otherwise, starts a new Docker container.
For the same thread_id, this method will return the same sandbox_id,
allowing sandbox reuse across multiple turns in a conversation.
This method is thread-safe. This method is thread-safe.
Args: Args:
thread_id: Optional thread ID for thread-specific configurations. thread_id: Optional thread ID for thread-specific configurations.
If provided, the sandbox will be configured with thread-specific If provided, the sandbox will be configured with thread-specific
mounts for workspace, uploads, and outputs directories. mounts for workspace, uploads, and outputs directories.
The same thread_id will reuse the same sandbox.
Returns: Returns:
The ID of the acquired sandbox environment. The ID of the acquired sandbox environment.
""" """
# Check if we already have a sandbox for this thread
if thread_id:
with self._lock:
if thread_id in self._thread_sandboxes:
existing_sandbox_id = self._thread_sandboxes[thread_id]
# Verify the sandbox still exists
if existing_sandbox_id in self._sandboxes:
logger.info(f"Reusing existing sandbox {existing_sandbox_id} for thread {thread_id}")
return existing_sandbox_id
else:
# Sandbox was released, remove stale mapping
del self._thread_sandboxes[thread_id]
sandbox_id = str(uuid.uuid4())[:8] sandbox_id = str(uuid.uuid4())[:8]
# Get thread-specific mounts if thread_id is provided # Get thread-specific mounts if thread_id is provided
@@ -265,6 +283,8 @@ class AioSandboxProvider(SandboxProvider):
sandbox = AioSandbox(id=sandbox_id, base_url=base_url) sandbox = AioSandbox(id=sandbox_id, base_url=base_url)
with self._lock: with self._lock:
self._sandboxes[sandbox_id] = sandbox self._sandboxes[sandbox_id] = sandbox
if thread_id:
self._thread_sandboxes[thread_id] = sandbox_id
return sandbox_id return sandbox_id
# Otherwise, start a new container # Otherwise, start a new container
@@ -294,7 +314,9 @@ class AioSandboxProvider(SandboxProvider):
self._sandboxes[sandbox_id] = sandbox self._sandboxes[sandbox_id] = sandbox
self._containers[sandbox_id] = container_id self._containers[sandbox_id] = container_id
self._ports[sandbox_id] = port self._ports[sandbox_id] = port
logger.info(f"Acquired sandbox {sandbox_id} at {base_url}") if thread_id:
self._thread_sandboxes[thread_id] = sandbox_id
logger.info(f"Acquired sandbox {sandbox_id} for thread {thread_id} at {base_url}")
return sandbox_id return sandbox_id
def get(self, sandbox_id: str) -> Sandbox | None: def get(self, sandbox_id: str) -> Sandbox | None:
@@ -330,6 +352,11 @@ class AioSandboxProvider(SandboxProvider):
del self._sandboxes[sandbox_id] del self._sandboxes[sandbox_id]
logger.info(f"Released sandbox {sandbox_id}") logger.info(f"Released sandbox {sandbox_id}")
# Remove thread_id -> sandbox_id mapping
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
for tid in thread_ids_to_remove:
del self._thread_sandboxes[tid]
# Get container and port info while holding the lock # Get container and port info while holding the lock
if sandbox_id in self._containers: if sandbox_id in self._containers:
container_id = self._containers.pop(sandbox_id) container_id = self._containers.pop(sandbox_id)

View File

@@ -6,7 +6,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from src.gateway.config import get_gateway_config from src.gateway.config import get_gateway_config
from src.gateway.routers import models, proxy from src.gateway.routers import artifacts, models, proxy
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -49,6 +49,9 @@ def create_app() -> FastAPI:
# Models API is mounted at /api/models # Models API is mounted at /api/models
app.include_router(models.router) app.include_router(models.router)
# Artifacts API is mounted at /api/threads/{thread_id}/artifacts
app.include_router(artifacts.router)
# Proxy router handles all LangGraph paths (must be last due to catch-all) # Proxy router handles all LangGraph paths (must be last due to catch-all)
app.include_router(proxy.router) app.include_router(proxy.router)

View File

@@ -1,3 +1,3 @@
from . import models, proxy from . import artifacts, models, proxy
__all__ = ["models", "proxy"] __all__ = ["artifacts", "models", "proxy"]

View File

@@ -0,0 +1,75 @@
import os
from pathlib import Path
from fastapi import APIRouter, HTTPException
from fastapi.responses import FileResponse
# Base directory for thread data. This is a relative path: _resolve_artifact_path
# joins it onto os.getcwd(), so it is resolved against the process's current
# working directory (the original note states this is backend/ -- confirm for
# the deployment's launch directory).
THREAD_DATA_BASE_DIR = ".deer-flow/threads"
# Virtual path prefix used in sandbox environments (without leading slash for URL path matching)
VIRTUAL_PATH_PREFIX = "mnt/user-data"
# Router for the artifact endpoints; routes declared on it are served under the /api prefix.
router = APIRouter(prefix="/api", tags=["artifacts"])
def _resolve_artifact_path(thread_id: str, artifact_path: str) -> Path:
    """Resolve a virtual artifact path to the actual filesystem path.

    Both arguments come straight from the request URL and are treated as
    untrusted input.

    Args:
        thread_id: The thread ID. Must be a single, plain path component.
        artifact_path: The virtual path (e.g., mnt/user-data/outputs/file.txt).

    Returns:
        The resolved (absolute, symlink-free) filesystem path, guaranteed to
        lie inside the thread's user-data directory.

    Raises:
        HTTPException: 400 if the thread ID or path is malformed or does not
            carry the virtual prefix; 403 if the resolved path escapes the
            thread's user-data directory.
    """
    # The thread ID becomes a directory name below; refuse anything that could
    # move the base directory itself (e.g. ".." or an embedded separator).
    if thread_id in ("", ".", "..") or "/" in thread_id or "\\" in thread_id:
        raise HTTPException(status_code=400, detail="Invalid thread ID")

    # Validate the virtual prefix on a path-component boundary so that
    # look-alikes such as "mnt/user-data-other/..." are not accepted.
    if artifact_path != VIRTUAL_PATH_PREFIX and not artifact_path.startswith(f"{VIRTUAL_PATH_PREFIX}/"):
        raise HTTPException(status_code=400, detail=f"Path must start with /{VIRTUAL_PATH_PREFIX}")
    relative_path = artifact_path[len(VIRTUAL_PATH_PREFIX) :].lstrip("/")

    # Build the actual path
    base_dir = Path(os.getcwd()) / THREAD_DATA_BASE_DIR / thread_id / "user-data"

    try:
        resolved_base = base_dir.resolve()
        resolved_path = (base_dir / relative_path).resolve()
    except (ValueError, RuntimeError) as exc:
        # e.g. embedded NUL byte (ValueError) or a symlink loop (RuntimeError)
        raise HTTPException(status_code=400, detail="Invalid path") from exc

    # Security check: compare whole path components, not string prefixes.
    # A plain str.startswith() would let ".../user-data-evil" pass as being
    # "inside" ".../user-data".
    try:
        resolved_path.relative_to(resolved_base)
    except ValueError:
        raise HTTPException(status_code=403, detail="Access denied: path traversal detected") from None

    return resolved_path
@router.get("/threads/{thread_id}/artifacts/{path:path}")
async def get_artifact(thread_id: str, path: str) -> FileResponse:
    """Serve a single artifact file belonging to a thread.

    Args:
        thread_id: The thread ID.
        path: The artifact path with virtual prefix (e.g., mnt/user-data/outputs/file.txt).

    Returns:
        A FileResponse streaming the file, named after the file's basename.

    Raises:
        HTTPException: 404 if nothing exists at the path, 400 if the path
            exists but is not a regular file; path resolution may additionally
            raise 400/403 for invalid or out-of-bounds paths.
    """
    target = _resolve_artifact_path(thread_id, path)

    # Happy path first: a regular file is returned as-is.
    if target.is_file():
        return FileResponse(path=target, filename=target.name)

    # Not a file: distinguish "something else is there" from "nothing there".
    if target.exists():
        raise HTTPException(status_code=400, detail=f"Path is not a file: {path}")

    raise HTTPException(status_code=404, detail=f"Artifact not found: {path}")

View File

@@ -1,8 +1,19 @@
from langchain.tools import tool from typing import Annotated
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
from langchain_core.messages import ToolMessage
from langgraph.types import Command
from langgraph.typing import ContextT
from src.agents.thread_state import ThreadState
@tool("present_files", parse_docstring=True) @tool("present_files", parse_docstring=True)
def present_file_tool(filepaths: list[str]) -> str: def present_file_tool(
runtime: ToolRuntime[ContextT, ThreadState],
filepaths: list[str],
tool_call_id: Annotated[str, InjectedToolCallId],
) -> Command:
"""Make files visible to the user for viewing and rendering in the client interface. """Make files visible to the user for viewing and rendering in the client interface.
When to use the present_files tool: When to use the present_files tool:
@@ -22,4 +33,9 @@ def present_file_tool(filepaths: list[str]) -> str:
Args: Args:
filepaths: List of absolute file paths to present to the user. **Only** files in `/mnt/user-data/outputs` can be presented. filepaths: List of absolute file paths to present to the user. **Only** files in `/mnt/user-data/outputs` can be presented.
""" """
return "OK" existing_artifacts = runtime.state.get("artifacts") or []
new_artifacts = existing_artifacts + filepaths
runtime.state["artifacts"] = new_artifacts
return Command(
update={"artifacts": new_artifacts, "messages": [ToolMessage("Successfully presented files", tool_call_id=tool_call_id)]},
)