mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-08 16:24:45 +08:00
feat: add resource upload support for RAG (#768)
* feat: add resource upload support for RAG - Backend: Added ingest_file method to Retriever and MilvusRetriever - Backend: Added /api/rag/upload endpoint - Frontend: Added RAGTab in settings for uploading resources - Frontend: Updated translations and settings registration * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestions from code review * Apply suggestions from code review of src/rag/milvus.py --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -3,6 +3,8 @@
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set
|
||||
|
||||
@@ -768,6 +770,172 @@ class MilvusRetriever(Retriever):
|
||||
# Ignore errors during cleanup
|
||||
pass
|
||||
|
||||
def _sanitize_filename(self, filename: str, max_length: int = 200) -> str:
|
||||
"""Sanitize filename for safe use in doc_id and URI construction.
|
||||
|
||||
Args:
|
||||
filename: Original filename to sanitize.
|
||||
max_length: Maximum allowed length for the filename (default: 200).
|
||||
|
||||
Returns:
|
||||
Sanitized filename safe for storage and URI construction.
|
||||
"""
|
||||
# Extract basename to remove any path components
|
||||
sanitized = Path(filename).name
|
||||
|
||||
# Remove or replace problematic characters
|
||||
# Keep alphanumeric, dots, hyphens, underscores; replace others with underscore
|
||||
sanitized = re.sub(r"[^\w.\-]", "_", sanitized)
|
||||
|
||||
# Collapse multiple underscores
|
||||
sanitized = re.sub(r"_+", "_", sanitized)
|
||||
|
||||
# Remove leading/trailing underscores and dots
|
||||
sanitized = sanitized.strip("_.")
|
||||
|
||||
# Ensure we have a valid filename
|
||||
if not sanitized:
|
||||
sanitized = "unnamed_file"
|
||||
|
||||
# Truncate if too long, preserving extension
|
||||
if len(sanitized) > max_length:
|
||||
# Try to preserve extension
|
||||
parts = sanitized.rsplit(".", 1)
|
||||
if len(parts) == 2 and len(parts[1]) <= 10:
|
||||
ext = "." + parts[1]
|
||||
base = parts[0][: max_length - len(ext)]
|
||||
sanitized = base + ext
|
||||
else:
|
||||
sanitized = sanitized[:max_length]
|
||||
|
||||
return sanitized
|
||||
|
||||
def _check_duplicate_file(self, filename: str) -> bool:
|
||||
"""Check if a file with the same name has been uploaded before."""
|
||||
try:
|
||||
if self._is_milvus_lite():
|
||||
results = self.client.query(
|
||||
collection_name=self.collection_name,
|
||||
filter=f"file == '{filename}' and source == 'uploaded'",
|
||||
output_fields=[self.id_field],
|
||||
limit=1,
|
||||
)
|
||||
return len(results) > 0
|
||||
else:
|
||||
# For LangChain Milvus, perform a search with metadata filter
|
||||
docs = self.client.similarity_search(
|
||||
"",
|
||||
k=1,
|
||||
expr=f"file == '{filename}' and source == 'uploaded'",
|
||||
)
|
||||
return len(docs) > 0
|
||||
except Exception:
|
||||
# If check fails, allow upload to proceed
|
||||
return False
|
||||
|
||||
def ingest_file(self, file_content: bytes, filename: str, **kwargs) -> Resource:
|
||||
"""Ingest a file into the Milvus vector store for RAG retrieval.
|
||||
|
||||
This method processes an uploaded file, splits it into chunks if necessary,
|
||||
generates embeddings, and stores them in the configured Milvus collection.
|
||||
|
||||
Args:
|
||||
file_content: Raw bytes of the file to ingest. Must be valid UTF-8
|
||||
encoded text content (e.g., markdown or plain text files).
|
||||
filename: Original filename. Used for title extraction, metadata storage,
|
||||
and URI construction. The filename is sanitized to remove special
|
||||
characters and path separators before use.
|
||||
**kwargs: Reserved for future use. Currently unused but accepted for
|
||||
forward compatibility (e.g., custom metadata, chunking options).
|
||||
|
||||
Returns:
|
||||
Resource: Object containing:
|
||||
- uri: Milvus URI in format ``milvus://{collection}/{filename}``
|
||||
- title: Extracted from first markdown heading or derived from filename
|
||||
- description: "Uploaded file" or "Uploaded file (new version)"
|
||||
|
||||
Raises:
|
||||
ValueError: If file_content cannot be decoded as UTF-8 text. This typically
|
||||
occurs when attempting to upload binary files (images, PDFs, etc.)
|
||||
which are not supported.
|
||||
RuntimeError: If document chunk insertion fails due to embedding generation
|
||||
errors, Milvus connection issues, or storage failures.
|
||||
ConnectionError: If unable to establish connection to Milvus server.
|
||||
|
||||
Supported file types:
|
||||
- Markdown files (.md): Title extracted from first ``# heading``
|
||||
- Plain text files (.txt): Title derived from filename
|
||||
|
||||
Duplicate handling:
|
||||
Files with the same name can be uploaded multiple times. Each upload
|
||||
creates a new document with a unique ID (includes timestamp). The
|
||||
description field indicates if this is a new version of an existing
|
||||
file. Old versions are retained in storage.
|
||||
|
||||
Example:
|
||||
>>> retriever = MilvusRetriever()
|
||||
>>> with open("document.md", "rb") as f:
|
||||
... resource = retriever.ingest_file(f.read(), "document.md")
|
||||
>>> print(resource.uri)
|
||||
milvus://documents/document.md
|
||||
"""
|
||||
# Check connection
|
||||
if not self.client:
|
||||
self._connect()
|
||||
|
||||
# Sanitize filename to prevent issues with special characters and path traversal
|
||||
safe_filename = self._sanitize_filename(filename)
|
||||
if safe_filename != filename:
|
||||
logger.debug(
|
||||
"Filename sanitized: '%s' -> '%s'", filename, safe_filename
|
||||
)
|
||||
|
||||
# Decode content (only UTF-8 text files supported)
|
||||
try:
|
||||
content = file_content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
raise ValueError(
|
||||
"Only UTF-8 encoded text files are supported (e.g., .md, .txt). "
|
||||
"Binary files such as images, PDFs, or Word documents cannot be processed."
|
||||
)
|
||||
|
||||
# Check for existing file with same name
|
||||
is_duplicate = self._check_duplicate_file(safe_filename)
|
||||
if is_duplicate:
|
||||
logger.info(
|
||||
"File '%s' was previously uploaded. Creating new version.", safe_filename
|
||||
)
|
||||
|
||||
# Generate unique doc_id using filename, content length, and timestamp
|
||||
# Timestamp ensures uniqueness even for identical re-uploads
|
||||
timestamp = int(time.time() * 1000) # millisecond precision
|
||||
content_hash = hashlib.md5(
|
||||
f"{safe_filename}_{len(content)}_{timestamp}".encode()
|
||||
).hexdigest()[:8]
|
||||
base_name = safe_filename.rsplit(".", 1)[0] if "." in safe_filename else safe_filename
|
||||
doc_id = f"uploaded_{base_name}_{content_hash}"
|
||||
|
||||
title = self._extract_title_from_markdown(content, safe_filename)
|
||||
chunks = self._split_content(content)
|
||||
|
||||
# Insert chunks
|
||||
for i, chunk in enumerate(chunks):
|
||||
chunk_id = f"{doc_id}_chunk_{i}" if len(chunks) > 1 else doc_id
|
||||
self._insert_document_chunk(
|
||||
doc_id=chunk_id,
|
||||
content=chunk,
|
||||
title=title,
|
||||
url=f"milvus://{self.collection_name}/{safe_filename}",
|
||||
metadata={"source": "uploaded", "file": safe_filename, "timestamp": timestamp},
|
||||
)
|
||||
|
||||
description = "Uploaded file (new version)" if is_duplicate else "Uploaded file"
|
||||
return Resource(
|
||||
uri=f"milvus://{self.collection_name}/{safe_filename}",
|
||||
title=title,
|
||||
description=description,
|
||||
)
|
||||
|
||||
def __del__(self) -> None: # pragma: no cover - best-effort cleanup
|
||||
"""Best-effort cleanup when instance is garbage collected."""
|
||||
self.close()
|
||||
|
||||
@@ -79,3 +79,52 @@ class Retriever(abc.ABC):
|
||||
Query relevant documents from the resources.
|
||||
"""
|
||||
pass
|
||||
|
||||
def ingest_file(self, file_content: bytes, filename: str, **kwargs) -> Resource:
|
||||
"""
|
||||
Ingest a file into the RAG provider and register it as a :class:`Resource`.
|
||||
|
||||
This method is intended to be overridden by concrete retriever implementations.
|
||||
The default implementation always raises :class:`NotImplementedError`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file_content:
|
||||
Raw bytes of the file to ingest. For text-based formats, implementations
|
||||
will typically assume UTF-8 encoding unless documented otherwise. Binary
|
||||
formats (such as PDF, images, or office documents) should be passed as
|
||||
their original bytes.
|
||||
filename:
|
||||
The original filename, including extension (e.g. ``"report.pdf"``). This
|
||||
can be used by implementations to infer the file type, MIME type, or to
|
||||
populate the resulting resource's title.
|
||||
**kwargs:
|
||||
Additional, implementation-specific options. Examples may include:
|
||||
|
||||
- Explicit MIME type or file type hints.
|
||||
- Additional metadata to associate with the resource.
|
||||
- Chunking, indexing, or preprocessing parameters.
|
||||
|
||||
Unsupported or invalid keyword arguments may result in an exception being
|
||||
raised by the concrete implementation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Resource
|
||||
A :class:`Resource` instance describing the ingested file, including its
|
||||
URI and title. The exact URI scheme and how the resource is stored are
|
||||
implementation-defined.
|
||||
|
||||
Raises
|
||||
------
|
||||
NotImplementedError
|
||||
Always raised by the base ``Retriever`` implementation. Concrete
|
||||
implementations should override this method to provide functionality.
|
||||
ValueError
|
||||
May be raised by implementations if the input bytes, filename, or
|
||||
provided options are invalid.
|
||||
RuntimeError
|
||||
May be raised by implementations to signal unexpected ingestion or
|
||||
storage failures (e.g. backend service errors).
|
||||
"""
|
||||
raise NotImplementedError("ingest_file is not implemented")
|
||||
|
||||
@@ -9,7 +9,7 @@ import os
|
||||
from typing import Annotated, Any, List, Optional, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Query
|
||||
from fastapi import FastAPI, HTTPException, Query, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage
|
||||
@@ -901,6 +901,74 @@ async def rag_resources(request: Annotated[RAGResourceRequest, Query()]):
|
||||
return RAGResourcesResponse(resources=[])
|
||||
|
||||
|
||||
MAX_UPLOAD_SIZE_BYTES = 10 * 1024 * 1024 # 10 MB
|
||||
ALLOWED_EXTENSIONS = {".md", ".txt"}
|
||||
|
||||
|
||||
def _sanitize_filename(filename: str) -> str:
|
||||
"""Sanitize filename to prevent path traversal attacks."""
|
||||
# Extract only the base filename, removing any path components
|
||||
basename = os.path.basename(filename)
|
||||
# Remove any null bytes or other dangerous characters
|
||||
sanitized = basename.replace("\x00", "").strip()
|
||||
# Ensure filename is not empty after sanitization
|
||||
if not sanitized or sanitized in (".", ".."):
|
||||
return "unnamed_file"
|
||||
return sanitized
|
||||
|
||||
|
||||
@app.post("/api/rag/upload", response_model=Resource)
|
||||
async def upload_rag_resource(file: UploadFile):
|
||||
# Validate filename exists
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="Filename is required for upload")
|
||||
|
||||
# Sanitize filename to prevent path traversal
|
||||
safe_filename = _sanitize_filename(file.filename)
|
||||
|
||||
# Validate file extension
|
||||
_, ext = os.path.splitext(safe_filename.lower())
|
||||
if ext not in ALLOWED_EXTENSIONS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid file type. Only {', '.join(ALLOWED_EXTENSIONS)} files are allowed.",
|
||||
)
|
||||
|
||||
# Read content with size limit check
|
||||
content = await file.read()
|
||||
if len(content) == 0:
|
||||
raise HTTPException(status_code=400, detail="Cannot upload an empty file")
|
||||
if len(content) > MAX_UPLOAD_SIZE_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE_BYTES // (1024 * 1024)} MB.",
|
||||
)
|
||||
|
||||
retriever = build_retriever()
|
||||
if not retriever:
|
||||
raise HTTPException(status_code=500, detail="RAG provider not configured")
|
||||
try:
|
||||
return retriever.ingest_file(content, safe_filename)
|
||||
except NotImplementedError:
|
||||
raise HTTPException(
|
||||
status_code=501, detail="Upload not supported by current RAG provider"
|
||||
)
|
||||
except ValueError as exc:
|
||||
# Invalid user input or unsupported file content; treat as a client error
|
||||
logger.warning("Invalid RAG resource upload: %s", exc)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid RAG resource. Please check the file and try again.",
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
# Internal error during ingestion; log and return a generic server error
|
||||
logger.exception("Runtime error while ingesting RAG resource: %s", exc)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to ingest RAG resource due to an internal error.",
|
||||
)
|
||||
|
||||
|
||||
@app.get("/api/config", response_model=ConfigResponse)
|
||||
async def config():
|
||||
"""Get the config of the server."""
|
||||
|
||||
Reference in New Issue
Block a user