feat: add resource upload support for RAG (#768)

* feat: add resource upload support for RAG

- Backend: Added ingest_file method to Retriever and MilvusRetriever
- Backend: Added /api/rag/upload endpoint
- Frontend: Added RAGTab in settings for uploading resources
- Frontend: Updated translations and settings registration

* Apply suggestions from code review

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply suggestions from code review

* Apply suggestions from code review of src/rag/milvus.py

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Willem Jiang
2025-12-19 09:55:34 +08:00
committed by GitHub
parent 3e8f2ce3ad
commit 04296cdf5a
9 changed files with 567 additions and 2 deletions

View File

@@ -9,7 +9,7 @@ import os
from typing import Annotated, Any, List, Optional, cast
from uuid import uuid4
from fastapi import FastAPI, HTTPException, Query
from fastapi import FastAPI, HTTPException, Query, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, StreamingResponse
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage
@@ -901,6 +901,74 @@ async def rag_resources(request: Annotated[RAGResourceRequest, Query()]):
return RAGResourcesResponse(resources=[])
# Limits applied to files sent to the RAG upload endpoint.
# Upper bound on the accepted payload, in bytes.
MAX_UPLOAD_SIZE_BYTES = 10 * 1024**2  # 10 MB
# Filename suffixes (leading dot included) that may be uploaded.
ALLOWED_EXTENSIONS = {".md", ".txt"}
def _sanitize_filename(filename: str) -> str:
    """Reduce a client-supplied filename to a safe bare name.

    Directory components are dropped for both ``/`` and ``\\`` separators,
    null bytes are removed and surrounding whitespace is stripped.

    Args:
        filename: The filename as sent by the client; may contain path
            components or null bytes.

    Returns:
        The bare filename, or ``"unnamed_file"`` when nothing usable is left
        (empty result, ``"."`` or ``".."``).
    """
    # os.path.basename only splits on the running platform's separators, so on
    # POSIX a Windows-style path such as "..\\..\\evil.txt" would otherwise be
    # returned unchanged. Normalize backslashes first.
    normalized = filename.replace("\\", "/")
    # Extract only the base filename, removing any path components
    basename = os.path.basename(normalized)
    # Remove null bytes and surrounding whitespace
    sanitized = basename.replace("\x00", "").strip()
    # Ensure the result is a real name rather than empty or a directory alias
    if not sanitized or sanitized in (".", ".."):
        return "unnamed_file"
    return sanitized
@app.post("/api/rag/upload", response_model=Resource)
async def upload_rag_resource(file: UploadFile):
    """Ingest an uploaded file into the configured RAG provider.

    The upload must carry a filename whose extension is in
    ``ALLOWED_EXTENSIONS``, must not be empty, and must not exceed
    ``MAX_UPLOAD_SIZE_BYTES``. The sanitized filename and the raw bytes are
    passed to the retriever's ``ingest_file`` and its result is returned.

    Args:
        file: The uploaded file.

    Returns:
        Whatever ``retriever.ingest_file`` returns (the route declares
        ``Resource`` as its response model).

    Raises:
        HTTPException: 400 for a missing filename, a disallowed extension,
            an empty file, or a ``ValueError`` during ingestion; 413 when
            the file exceeds the size limit; 500 when no retriever is
            configured or ingestion raises ``RuntimeError``; 501 when the
            retriever raises ``NotImplementedError``.
    """
    # Validate filename exists
    if not file.filename:
        raise HTTPException(status_code=400, detail="Filename is required for upload")

    # Sanitize filename to prevent path traversal
    safe_filename = _sanitize_filename(file.filename)

    # Validate file extension (case-insensitive)
    _, ext = os.path.splitext(safe_filename.lower())
    if ext not in ALLOWED_EXTENSIONS:
        # Sort so the message is deterministic; set iteration order is not.
        allowed = ", ".join(sorted(ALLOWED_EXTENSIONS))
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file type. Only {allowed} files are allowed.",
        )

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without pulling an arbitrarily large payload into memory.
    content = await file.read(MAX_UPLOAD_SIZE_BYTES + 1)
    if not content:
        raise HTTPException(status_code=400, detail="Cannot upload an empty file")
    if len(content) > MAX_UPLOAD_SIZE_BYTES:
        raise HTTPException(
            status_code=413,
            detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE_BYTES // (1024 * 1024)} MB.",
        )

    retriever = build_retriever()
    if not retriever:
        raise HTTPException(status_code=500, detail="RAG provider not configured")

    try:
        return retriever.ingest_file(content, safe_filename)
    except NotImplementedError as exc:
        raise HTTPException(
            status_code=501, detail="Upload not supported by current RAG provider"
        ) from exc
    except ValueError as exc:
        # Invalid user input or unsupported file content; treat as a client error
        logger.warning("Invalid RAG resource upload: %s", exc)
        raise HTTPException(
            status_code=400,
            detail="Invalid RAG resource. Please check the file and try again.",
        ) from exc
    except RuntimeError as exc:
        # Internal error during ingestion; log and return a generic server error
        logger.exception("Runtime error while ingesting RAG resource: %s", exc)
        raise HTTPException(
            status_code=500,
            detail="Failed to ingest RAG resource due to an internal error.",
        ) from exc
@app.get("/api/config", response_model=ConfigResponse)
async def config():
"""Get the config of the server."""