feat: add artifacts logic (#8)

This commit is contained in:
DanielWalnut
2026-01-16 23:04:38 +08:00
committed by GitHub
parent 6464a67230
commit facde645d7
6 changed files with 129 additions and 7 deletions

View File

@@ -17,3 +17,4 @@ class ThreadState(AgentState):
sandbox: NotRequired[SandboxState | None]
thread_data: NotRequired[ThreadDataState | None]
title: NotRequired[str | None]
artifacts: NotRequired[list[str] | None]

View File

@@ -50,6 +50,7 @@ class AioSandboxProvider(SandboxProvider):
self._sandboxes: dict[str, AioSandbox] = {}
self._containers: dict[str, str] = {} # sandbox_id -> container_id
self._ports: dict[str, int] = {} # sandbox_id -> port
self._thread_sandboxes: dict[str, str] = {} # thread_id -> sandbox_id (for reusing sandbox across turns)
self._config = self._load_config()
self._shutdown_called = False
@@ -230,16 +231,33 @@ class AioSandboxProvider(SandboxProvider):
If base_url is configured, uses the existing sandbox.
Otherwise, starts a new Docker container.
For the same thread_id, this method will return the same sandbox_id,
allowing sandbox reuse across multiple turns in a conversation.
This method is thread-safe.
Args:
thread_id: Optional thread ID for thread-specific configurations.
If provided, the sandbox will be configured with thread-specific
mounts for workspace, uploads, and outputs directories.
The same thread_id will reuse the same sandbox.
Returns:
The ID of the acquired sandbox environment.
"""
# Check if we already have a sandbox for this thread
if thread_id:
with self._lock:
if thread_id in self._thread_sandboxes:
existing_sandbox_id = self._thread_sandboxes[thread_id]
# Verify the sandbox still exists
if existing_sandbox_id in self._sandboxes:
logger.info(f"Reusing existing sandbox {existing_sandbox_id} for thread {thread_id}")
return existing_sandbox_id
else:
# Sandbox was released, remove stale mapping
del self._thread_sandboxes[thread_id]
sandbox_id = str(uuid.uuid4())[:8]
# Get thread-specific mounts if thread_id is provided
@@ -265,6 +283,8 @@ class AioSandboxProvider(SandboxProvider):
sandbox = AioSandbox(id=sandbox_id, base_url=base_url)
with self._lock:
self._sandboxes[sandbox_id] = sandbox
if thread_id:
self._thread_sandboxes[thread_id] = sandbox_id
return sandbox_id
# Otherwise, start a new container
@@ -294,7 +314,9 @@ class AioSandboxProvider(SandboxProvider):
self._sandboxes[sandbox_id] = sandbox
self._containers[sandbox_id] = container_id
self._ports[sandbox_id] = port
logger.info(f"Acquired sandbox {sandbox_id} at {base_url}")
if thread_id:
self._thread_sandboxes[thread_id] = sandbox_id
logger.info(f"Acquired sandbox {sandbox_id} for thread {thread_id} at {base_url}")
return sandbox_id
def get(self, sandbox_id: str) -> Sandbox | None:
@@ -330,6 +352,11 @@ class AioSandboxProvider(SandboxProvider):
del self._sandboxes[sandbox_id]
logger.info(f"Released sandbox {sandbox_id}")
# Remove thread_id -> sandbox_id mapping
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
for tid in thread_ids_to_remove:
del self._thread_sandboxes[tid]
# Get container and port info while holding the lock
if sandbox_id in self._containers:
container_id = self._containers.pop(sandbox_id)

View File

@@ -6,7 +6,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from src.gateway.config import get_gateway_config
from src.gateway.routers import models, proxy
from src.gateway.routers import artifacts, models, proxy
logger = logging.getLogger(__name__)
@@ -49,6 +49,9 @@ def create_app() -> FastAPI:
# Models API is mounted at /api/models
app.include_router(models.router)
# Artifacts API is mounted at /api/threads/{thread_id}/artifacts
app.include_router(artifacts.router)
# Proxy router handles all LangGraph paths (must be last due to catch-all)
app.include_router(proxy.router)

View File

@@ -1,3 +1,3 @@
from . import models, proxy
from . import artifacts, models, proxy
__all__ = ["models", "proxy"]
__all__ = ["artifacts", "models", "proxy"]

View File

@@ -0,0 +1,75 @@
import os
from pathlib import Path
from fastapi import APIRouter, HTTPException
from fastapi.responses import FileResponse
# Base directory for thread data (relative to backend/)
THREAD_DATA_BASE_DIR = ".deer-flow/threads"
# Virtual path prefix used in sandbox environments (without leading slash for URL path matching)
VIRTUAL_PATH_PREFIX = "mnt/user-data"
router = APIRouter(prefix="/api", tags=["artifacts"])
def _resolve_artifact_path(thread_id: str, artifact_path: str) -> Path:
    """Resolve a virtual artifact path to the actual filesystem path.

    The virtual prefix (``VIRTUAL_PATH_PREFIX``) is stripped and the remainder
    is joined onto ``<cwd>/<THREAD_DATA_BASE_DIR>/<thread_id>/user-data``. The
    result is fully resolved and must stay inside that directory.

    Args:
        thread_id: The thread ID. Must be a single path component.
        artifact_path: The virtual path (e.g., mnt/user-data/outputs/file.txt).

    Returns:
        The resolved filesystem path.

    Raises:
        HTTPException: 400 if the thread ID or path is malformed, 403 if the
            resolved path escapes the thread's user-data directory.
    """
    # The thread ID becomes a directory name below; values that are themselves
    # path navigation ("..", embedded separators) would move the "allowed"
    # base directory and defeat the containment check.
    if thread_id in ("", ".", "..") or "/" in thread_id or "\\" in thread_id:
        raise HTTPException(status_code=400, detail="Invalid thread ID")

    # Match the prefix on a path-component boundary so that look-alikes such as
    # "mnt/user-data-other/..." are rejected instead of silently remapped.
    has_prefix = artifact_path == VIRTUAL_PATH_PREFIX or artifact_path.startswith(f"{VIRTUAL_PATH_PREFIX}/")
    if not has_prefix:
        raise HTTPException(status_code=400, detail=f"Path must start with /{VIRTUAL_PATH_PREFIX}")

    relative_path = artifact_path[len(VIRTUAL_PATH_PREFIX) :].lstrip("/")

    # Build the actual path
    base_dir = Path(os.getcwd()) / THREAD_DATA_BASE_DIR / thread_id / "user-data"

    # Only the resolve() calls can raise here (e.g. embedded NUL bytes or
    # symlink loops), so keep the try body limited to them.
    try:
        resolved_base = base_dir.resolve()
        resolved_path = (base_dir / relative_path).resolve()
    except (ValueError, RuntimeError) as exc:
        raise HTTPException(status_code=400, detail="Invalid path") from exc

    # Security check: compare path components, not string prefixes. A plain
    # str.startswith() would accept a sibling such as ".../user-data-evil"
    # reached via "../user-data-evil".
    if not resolved_path.is_relative_to(resolved_base):
        raise HTTPException(status_code=403, detail="Access denied: path traversal detected")

    return resolved_path
@router.get("/threads/{thread_id}/artifacts/{path:path}")
async def get_artifact(thread_id: str, path: str) -> FileResponse:
    """Serve a single artifact file belonging to a thread.

    Args:
        thread_id: The thread ID.
        path: The artifact path including the virtual prefix
            (e.g., mnt/user-data/outputs/file.txt).

    Returns:
        A FileResponse for the resolved file, using the file's own name as
        the download filename.

    Raises:
        HTTPException: 404 if nothing exists at the resolved path, 400 if it
            exists but is not a regular file; errors raised while resolving
            the path propagate unchanged.
    """
    artifact_file = _resolve_artifact_path(thread_id, path)

    # Happy path first: a regular file is served directly.
    if artifact_file.is_file():
        return FileResponse(path=artifact_file, filename=artifact_file.name)

    # Not a regular file: distinguish "exists but is something else"
    # (e.g. a directory) from "does not exist at all".
    if artifact_file.exists():
        raise HTTPException(status_code=400, detail=f"Path is not a file: {path}")

    raise HTTPException(status_code=404, detail=f"Artifact not found: {path}")

View File

@@ -1,8 +1,19 @@
from langchain.tools import tool
from typing import Annotated
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
from langchain_core.messages import ToolMessage
from langgraph.types import Command
from langgraph.typing import ContextT
from src.agents.thread_state import ThreadState
@tool("present_files", parse_docstring=True)
def present_file_tool(filepaths: list[str]) -> str:
def present_file_tool(
runtime: ToolRuntime[ContextT, ThreadState],
filepaths: list[str],
tool_call_id: Annotated[str, InjectedToolCallId],
) -> Command:
"""Make files visible to the user for viewing and rendering in the client interface.
When to use the present_files tool:
@@ -22,4 +33,9 @@ def present_file_tool(filepaths: list[str]) -> str:
Args:
filepaths: List of absolute file paths to present to the user. **Only** files in `/mnt/user-data/outputs` can be presented.
"""
return "OK"
existing_artifacts = runtime.state.get("artifacts") or []
new_artifacts = existing_artifacts + filepaths
runtime.state["artifacts"] = new_artifacts
return Command(
update={"artifacts": new_artifacts, "messages": [ToolMessage("Successfully presented files", tool_call_id=tool_call_id)]},
)