mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 14:22:13 +08:00
feat: add artifacts logic (#8)
This commit is contained in:
@@ -17,3 +17,4 @@ class ThreadState(AgentState):
|
||||
sandbox: NotRequired[SandboxState | None]
|
||||
thread_data: NotRequired[ThreadDataState | None]
|
||||
title: NotRequired[str | None]
|
||||
artifacts: NotRequired[list[str] | None]
|
||||
|
||||
@@ -50,6 +50,7 @@ class AioSandboxProvider(SandboxProvider):
|
||||
self._sandboxes: dict[str, AioSandbox] = {}
|
||||
self._containers: dict[str, str] = {} # sandbox_id -> container_id
|
||||
self._ports: dict[str, int] = {} # sandbox_id -> port
|
||||
self._thread_sandboxes: dict[str, str] = {} # thread_id -> sandbox_id (for reusing sandbox across turns)
|
||||
self._config = self._load_config()
|
||||
self._shutdown_called = False
|
||||
|
||||
@@ -230,16 +231,33 @@ class AioSandboxProvider(SandboxProvider):
|
||||
If base_url is configured, uses the existing sandbox.
|
||||
Otherwise, starts a new Docker container.
|
||||
|
||||
For the same thread_id, this method will return the same sandbox_id,
|
||||
allowing sandbox reuse across multiple turns in a conversation.
|
||||
|
||||
This method is thread-safe.
|
||||
|
||||
Args:
|
||||
thread_id: Optional thread ID for thread-specific configurations.
|
||||
If provided, the sandbox will be configured with thread-specific
|
||||
mounts for workspace, uploads, and outputs directories.
|
||||
The same thread_id will reuse the same sandbox.
|
||||
|
||||
Returns:
|
||||
The ID of the acquired sandbox environment.
|
||||
"""
|
||||
# Check if we already have a sandbox for this thread
|
||||
if thread_id:
|
||||
with self._lock:
|
||||
if thread_id in self._thread_sandboxes:
|
||||
existing_sandbox_id = self._thread_sandboxes[thread_id]
|
||||
# Verify the sandbox still exists
|
||||
if existing_sandbox_id in self._sandboxes:
|
||||
logger.info(f"Reusing existing sandbox {existing_sandbox_id} for thread {thread_id}")
|
||||
return existing_sandbox_id
|
||||
else:
|
||||
# Sandbox was released, remove stale mapping
|
||||
del self._thread_sandboxes[thread_id]
|
||||
|
||||
sandbox_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# Get thread-specific mounts if thread_id is provided
|
||||
@@ -265,6 +283,8 @@ class AioSandboxProvider(SandboxProvider):
|
||||
sandbox = AioSandbox(id=sandbox_id, base_url=base_url)
|
||||
with self._lock:
|
||||
self._sandboxes[sandbox_id] = sandbox
|
||||
if thread_id:
|
||||
self._thread_sandboxes[thread_id] = sandbox_id
|
||||
return sandbox_id
|
||||
|
||||
# Otherwise, start a new container
|
||||
@@ -294,7 +314,9 @@ class AioSandboxProvider(SandboxProvider):
|
||||
self._sandboxes[sandbox_id] = sandbox
|
||||
self._containers[sandbox_id] = container_id
|
||||
self._ports[sandbox_id] = port
|
||||
logger.info(f"Acquired sandbox {sandbox_id} at {base_url}")
|
||||
if thread_id:
|
||||
self._thread_sandboxes[thread_id] = sandbox_id
|
||||
logger.info(f"Acquired sandbox {sandbox_id} for thread {thread_id} at {base_url}")
|
||||
return sandbox_id
|
||||
|
||||
def get(self, sandbox_id: str) -> Sandbox | None:
|
||||
@@ -330,6 +352,11 @@ class AioSandboxProvider(SandboxProvider):
|
||||
del self._sandboxes[sandbox_id]
|
||||
logger.info(f"Released sandbox {sandbox_id}")
|
||||
|
||||
# Remove thread_id -> sandbox_id mapping
|
||||
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
|
||||
for tid in thread_ids_to_remove:
|
||||
del self._thread_sandboxes[tid]
|
||||
|
||||
# Get container and port info while holding the lock
|
||||
if sandbox_id in self._containers:
|
||||
container_id = self._containers.pop(sandbox_id)
|
||||
|
||||
@@ -6,7 +6,7 @@ from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from src.gateway.config import get_gateway_config
|
||||
from src.gateway.routers import models, proxy
|
||||
from src.gateway.routers import artifacts, models, proxy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -49,6 +49,9 @@ def create_app() -> FastAPI:
|
||||
# Models API is mounted at /api/models
|
||||
app.include_router(models.router)
|
||||
|
||||
# Artifacts API is mounted at /api/threads/{thread_id}/artifacts
|
||||
app.include_router(artifacts.router)
|
||||
|
||||
# Proxy router handles all LangGraph paths (must be last due to catch-all)
|
||||
app.include_router(proxy.router)
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from . import models, proxy
|
||||
from . import artifacts, models, proxy
|
||||
|
||||
__all__ = ["models", "proxy"]
|
||||
__all__ = ["artifacts", "models", "proxy"]
|
||||
|
||||
75
backend/src/gateway/routers/artifacts.py
Normal file
75
backend/src/gateway/routers/artifacts.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
# Base directory for thread data (relative to backend/)
|
||||
THREAD_DATA_BASE_DIR = ".deer-flow/threads"
|
||||
|
||||
# Virtual path prefix used in sandbox environments (without leading slash for URL path matching)
|
||||
VIRTUAL_PATH_PREFIX = "mnt/user-data"
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["artifacts"])
|
||||
|
||||
|
||||
def _resolve_artifact_path(thread_id: str, artifact_path: str) -> Path:
|
||||
"""Resolve a virtual artifact path to the actual filesystem path.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
artifact_path: The virtual path (e.g., mnt/user-data/outputs/file.txt).
|
||||
|
||||
Returns:
|
||||
The resolved filesystem path.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the path is invalid or outside allowed directories.
|
||||
"""
|
||||
# Validate and remove virtual path prefix
|
||||
if not artifact_path.startswith(VIRTUAL_PATH_PREFIX):
|
||||
raise HTTPException(status_code=400, detail=f"Path must start with /{VIRTUAL_PATH_PREFIX}")
|
||||
relative_path = artifact_path[len(VIRTUAL_PATH_PREFIX) :].lstrip("/")
|
||||
|
||||
# Build the actual path
|
||||
base_dir = Path(os.getcwd()) / THREAD_DATA_BASE_DIR / thread_id / "user-data"
|
||||
actual_path = base_dir / relative_path
|
||||
|
||||
# Security check: ensure the path is within the thread's user-data directory
|
||||
try:
|
||||
actual_path = actual_path.resolve()
|
||||
base_dir = base_dir.resolve()
|
||||
if not str(actual_path).startswith(str(base_dir)):
|
||||
raise HTTPException(status_code=403, detail="Access denied: path traversal detected")
|
||||
except (ValueError, RuntimeError):
|
||||
raise HTTPException(status_code=400, detail="Invalid path")
|
||||
|
||||
return actual_path
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/artifacts/{path:path}")
|
||||
async def get_artifact(thread_id: str, path: str) -> FileResponse:
|
||||
"""Get an artifact file by its path.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
path: The artifact path with virtual prefix (e.g., mnt/user-data/outputs/file.txt).
|
||||
|
||||
Returns:
|
||||
The file content as a FileResponse.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if file not found, 403 if access denied.
|
||||
"""
|
||||
actual_path = _resolve_artifact_path(thread_id, path)
|
||||
|
||||
if not actual_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Artifact not found: {path}")
|
||||
|
||||
if not actual_path.is_file():
|
||||
raise HTTPException(status_code=400, detail=f"Path is not a file: {path}")
|
||||
|
||||
return FileResponse(
|
||||
path=actual_path,
|
||||
filename=actual_path.name,
|
||||
)
|
||||
@@ -1,8 +1,19 @@
|
||||
from langchain.tools import tool
|
||||
from typing import Annotated
|
||||
|
||||
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
from langgraph.typing import ContextT
|
||||
|
||||
from src.agents.thread_state import ThreadState
|
||||
|
||||
|
||||
@tool("present_files", parse_docstring=True)
|
||||
def present_file_tool(filepaths: list[str]) -> str:
|
||||
def present_file_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
filepaths: list[str],
|
||||
tool_call_id: Annotated[str, InjectedToolCallId],
|
||||
) -> Command:
|
||||
"""Make files visible to the user for viewing and rendering in the client interface.
|
||||
|
||||
When to use the present_files tool:
|
||||
@@ -22,4 +33,9 @@ def present_file_tool(filepaths: list[str]) -> str:
|
||||
Args:
|
||||
filepaths: List of absolute file paths to present to the user. **Only** files in `/mnt/user-data/outputs` can be presented.
|
||||
"""
|
||||
return "OK"
|
||||
existing_artifacts = runtime.state.get("artifacts") or []
|
||||
new_artifacts = existing_artifacts + filepaths
|
||||
runtime.state["artifacts"] = new_artifacts
|
||||
return Command(
|
||||
update={"artifacts": new_artifacts, "messages": [ToolMessage("Successfully presented files", tool_call_id=tool_call_id)]},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user