mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-08 08:20:20 +08:00
feat: add resource upload support for RAG (#768)
* feat: add resource upload support for RAG - Backend: Added ingest_file method to Retriever and MilvusRetriever - Backend: Added /api/rag/upload endpoint - Frontend: Added RAGTab in settings for uploading resources - Frontend: Updated translations and settings registration * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestions from code review * Apply suggestions from code review of src/rag/milvus.py --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -9,7 +9,7 @@ import os
|
||||
from typing import Annotated, Any, List, Optional, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Query
|
||||
from fastapi import FastAPI, HTTPException, Query, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage
|
||||
@@ -901,6 +901,74 @@ async def rag_resources(request: Annotated[RAGResourceRequest, Query()]):
|
||||
return RAGResourcesResponse(resources=[])
|
||||
|
||||
|
||||
# Largest upload accepted by /api/rag/upload; bigger payloads are rejected with HTTP 413.
MAX_UPLOAD_SIZE_BYTES = 10 * 1024 * 1024 # 10 MB
|
||||
# Lower-cased file extensions (including the leading dot) accepted by /api/rag/upload.
ALLOWED_EXTENSIONS = {".md", ".txt"}
|
||||
|
||||
|
||||
def _sanitize_filename(filename: str) -> str:
|
||||
"""Sanitize filename to prevent path traversal attacks."""
|
||||
# Extract only the base filename, removing any path components
|
||||
basename = os.path.basename(filename)
|
||||
# Remove any null bytes or other dangerous characters
|
||||
sanitized = basename.replace("\x00", "").strip()
|
||||
# Ensure filename is not empty after sanitization
|
||||
if not sanitized or sanitized in (".", ".."):
|
||||
return "unnamed_file"
|
||||
return sanitized
|
||||
|
||||
|
||||
@app.post("/api/rag/upload", response_model=Resource)
async def upload_rag_resource(file: UploadFile):
    """Validate an uploaded file and hand it to the RAG retriever for ingestion.

    The filename is sanitized, its extension is checked against
    ``ALLOWED_EXTENSIONS``, and the payload must be non-empty and no larger
    than ``MAX_UPLOAD_SIZE_BYTES`` before ``retriever.ingest_file`` is called.

    Args:
        file: The uploaded file from the multipart request.

    Returns:
        Whatever ``retriever.ingest_file`` returns, serialized through the
        ``Resource`` response model.

    Raises:
        HTTPException: 400 for a missing filename, disallowed extension,
            empty file, or a ``ValueError`` from ingestion; 413 when the
            payload exceeds the size limit; 500 when no retriever is
            configured or ingestion raises ``RuntimeError``; 501 when the
            retriever raises ``NotImplementedError``.
    """
    # Validate filename exists
    if not file.filename:
        raise HTTPException(status_code=400, detail="Filename is required for upload")

    # Sanitize filename to prevent path traversal
    safe_filename = _sanitize_filename(file.filename)

    # Validate file extension (case-insensitive)
    _, ext = os.path.splitext(safe_filename.lower())
    if ext not in ALLOWED_EXTENSIONS:
        # Sort so the message is deterministic; set iteration order is not.
        allowed = ", ".join(sorted(ALLOWED_EXTENSIONS))
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file type. Only {allowed} files are allowed.",
        )

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without loading an arbitrarily large payload into memory.
    content = await file.read(MAX_UPLOAD_SIZE_BYTES + 1)
    if not content:
        raise HTTPException(status_code=400, detail="Cannot upload an empty file")
    if len(content) > MAX_UPLOAD_SIZE_BYTES:
        raise HTTPException(
            status_code=413,
            detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE_BYTES // (1024 * 1024)} MB.",
        )

    retriever = build_retriever()
    if not retriever:
        raise HTTPException(status_code=500, detail="RAG provider not configured")

    try:
        return retriever.ingest_file(content, safe_filename)
    except NotImplementedError as exc:
        raise HTTPException(
            status_code=501, detail="Upload not supported by current RAG provider"
        ) from exc
    except ValueError as exc:
        # Invalid user input or unsupported file content; treat as a client error
        logger.warning("Invalid RAG resource upload: %s", exc)
        raise HTTPException(
            status_code=400,
            detail="Invalid RAG resource. Please check the file and try again.",
        ) from exc
    except RuntimeError as exc:
        # Internal error during ingestion; log and return a generic server error
        logger.exception("Runtime error while ingesting RAG resource: %s", exc)
        raise HTTPException(
            status_code=500,
            detail="Failed to ingest RAG resource due to an internal error.",
        ) from exc
|
||||
|
||||
|
||||
@app.get("/api/config", response_model=ConfigResponse)
|
||||
async def config():
|
||||
"""Get the config of the server."""
|
||||
|
||||
Reference in New Issue
Block a user