diff --git a/src/rag/milvus.py b/src/rag/milvus.py index 57c34e1..4c9d86d 100644 --- a/src/rag/milvus.py +++ b/src/rag/milvus.py @@ -3,6 +3,8 @@ import hashlib import logging +import re +import time from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Sequence, Set @@ -768,6 +770,172 @@ class MilvusRetriever(Retriever): # Ignore errors during cleanup pass + def _sanitize_filename(self, filename: str, max_length: int = 200) -> str: + """Sanitize filename for safe use in doc_id and URI construction. + + Args: + filename: Original filename to sanitize. + max_length: Maximum allowed length for the filename (default: 200). + + Returns: + Sanitized filename safe for storage and URI construction. + """ + # Extract basename to remove any path components + sanitized = Path(filename).name + + # Remove or replace problematic characters + # Keep alphanumeric, dots, hyphens, underscores; replace others with underscore + sanitized = re.sub(r"[^\w.\-]", "_", sanitized) + + # Collapse multiple underscores + sanitized = re.sub(r"_+", "_", sanitized) + + # Remove leading/trailing underscores and dots + sanitized = sanitized.strip("_.") + + # Ensure we have a valid filename + if not sanitized: + sanitized = "unnamed_file" + + # Truncate if too long, preserving extension + if len(sanitized) > max_length: + # Try to preserve extension + parts = sanitized.rsplit(".", 1) + if len(parts) == 2 and len(parts[1]) <= 10: + ext = "." + parts[1] + base = parts[0][: max_length - len(ext)] + sanitized = base + ext + else: + sanitized = sanitized[:max_length] + + return sanitized + + def _check_duplicate_file(self, filename: str) -> bool: + """Check if a file with the same name has been uploaded before.""" + try: + if self._is_milvus_lite(): + results = self.client.query( + collection_name=self.collection_name, + filter=f"file == '{filename}' and source == 'uploaded'", + output_fields=[self.id_field], + limit=1, + ) + return len(results) > 0 + else: + # For LangChain Milvus, perform a search with metadata filter + docs = self.client.similarity_search( + "", + k=1, + expr=f"file == '{filename}' and source == 'uploaded'", + ) + return len(docs) > 0 + except Exception: + # If check fails, allow upload to proceed + return False + + def ingest_file(self, file_content: bytes, filename: str, **kwargs) -> Resource: + """Ingest a file into the Milvus vector store for RAG retrieval. + + This method processes an uploaded file, splits it into chunks if necessary, + generates embeddings, and stores them in the configured Milvus collection. + + Args: + file_content: Raw bytes of the file to ingest. Must be valid UTF-8 + encoded text content (e.g., markdown or plain text files). + filename: Original filename. Used for title extraction, metadata storage, + and URI construction. The filename is sanitized to remove special + characters and path separators before use. + **kwargs: Reserved for future use. Currently unused but accepted for + forward compatibility (e.g., custom metadata, chunking options). + + Returns: + Resource: Object containing: + - uri: Milvus URI in format ``milvus://{collection}/{filename}`` + - title: Extracted from first markdown heading or derived from filename + - description: "Uploaded file" or "Uploaded file (new version)" + + Raises: + ValueError: If file_content cannot be decoded as UTF-8 text. This typically + occurs when attempting to upload binary files (images, PDFs, etc.) + which are not supported. + RuntimeError: If document chunk insertion fails due to embedding generation + errors, Milvus connection issues, or storage failures. + ConnectionError: If unable to establish connection to Milvus server. + + Supported file types: + - Markdown files (.md): Title extracted from first ``# heading`` + - Plain text files (.txt): Title derived from filename + + Duplicate handling: + Files with the same name can be uploaded multiple times. Each upload + creates a new document with a unique ID (includes timestamp). The + description field indicates if this is a new version of an existing + file. Old versions are retained in storage. + + Example: + >>> retriever = MilvusRetriever() + >>> with open("document.md", "rb") as f: + ... resource = retriever.ingest_file(f.read(), "document.md") + >>> print(resource.uri) + milvus://documents/document.md + """ + # Check connection + if not self.client: + self._connect() + + # Sanitize filename to prevent issues with special characters and path traversal + safe_filename = self._sanitize_filename(filename) + if safe_filename != filename: + logger.debug( + "Filename sanitized: '%s' -> '%s'", filename, safe_filename + ) + + # Decode content (only UTF-8 text files supported) + try: + content = file_content.decode("utf-8") + except UnicodeDecodeError: + raise ValueError( + "Only UTF-8 encoded text files are supported (e.g., .md, .txt). " + "Binary files such as images, PDFs, or Word documents cannot be processed." + ) + + # Check for existing file with same name + is_duplicate = self._check_duplicate_file(safe_filename) + if is_duplicate: + logger.info( + "File '%s' was previously uploaded. Creating new version.", safe_filename + ) + + # Generate unique doc_id using filename, content length, and timestamp + # Timestamp ensures uniqueness even for identical re-uploads + timestamp = int(time.time() * 1000) # millisecond precision + content_hash = hashlib.md5( + f"{safe_filename}_{len(content)}_{timestamp}".encode() + ).hexdigest()[:8] + base_name = safe_filename.rsplit(".", 1)[0] if "." in safe_filename else safe_filename + doc_id = f"uploaded_{base_name}_{content_hash}" + + title = self._extract_title_from_markdown(content, safe_filename) + chunks = self._split_content(content) + + # Insert chunks + for i, chunk in enumerate(chunks): + chunk_id = f"{doc_id}_chunk_{i}" if len(chunks) > 1 else doc_id + self._insert_document_chunk( + doc_id=chunk_id, + content=chunk, + title=title, + url=f"milvus://{self.collection_name}/{safe_filename}", + metadata={"source": "uploaded", "file": safe_filename, "timestamp": timestamp}, + ) + + description = "Uploaded file (new version)" if is_duplicate else "Uploaded file" + return Resource( + uri=f"milvus://{self.collection_name}/{safe_filename}", + title=title, + description=description, + ) + def __del__(self) -> None: # pragma: no cover - best-effort cleanup """Best-effort cleanup when instance is garbage collected.""" self.close() diff --git a/src/rag/retriever.py b/src/rag/retriever.py index df349d9..799b983 100644 --- a/src/rag/retriever.py +++ b/src/rag/retriever.py @@ -79,3 +79,52 @@ class Retriever(abc.ABC): Query relevant documents from the resources. """ pass + + def ingest_file(self, file_content: bytes, filename: str, **kwargs) -> Resource: + """ + Ingest a file into the RAG provider and register it as a :class:`Resource`. + + This method is intended to be overridden by concrete retriever implementations. + The default implementation always raises :class:`NotImplementedError`. + + Parameters + ---------- + file_content: + Raw bytes of the file to ingest. For text-based formats, implementations + will typically assume UTF-8 encoding unless documented otherwise. Binary + formats (such as PDF, images, or office documents) should be passed as + their original bytes. + filename: + The original filename, including extension (e.g. ``"report.pdf"``). This + can be used by implementations to infer the file type, MIME type, or to + populate the resulting resource's title. + **kwargs: + Additional, implementation-specific options. Examples may include: + + - Explicit MIME type or file type hints. + - Additional metadata to associate with the resource. + - Chunking, indexing, or preprocessing parameters. + + Unsupported or invalid keyword arguments may result in an exception being + raised by the concrete implementation. + + Returns + ------- + Resource + A :class:`Resource` instance describing the ingested file, including its + URI and title. The exact URI scheme and how the resource is stored are + implementation-defined. + + Raises + ------ + NotImplementedError + Always raised by the base ``Retriever`` implementation. Concrete + implementations should override this method to provide functionality. + ValueError + May be raised by implementations if the input bytes, filename, or + provided options are invalid. + RuntimeError + May be raised by implementations to signal unexpected ingestion or + storage failures (e.g. backend service errors). + """ + raise NotImplementedError("ingest_file is not implemented") diff --git a/src/server/app.py b/src/server/app.py index fa9ff77..c3304b4 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -9,7 +9,7 @@ import os from typing import Annotated, Any, List, Optional, cast from uuid import uuid4 -from fastapi import FastAPI, HTTPException, Query +from fastapi import FastAPI, HTTPException, Query, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import Response, StreamingResponse from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage @@ -901,6 +901,74 @@ async def rag_resources(request: Annotated[RAGResourceRequest, Query()]): return RAGResourcesResponse(resources=[]) +MAX_UPLOAD_SIZE_BYTES = 10 * 1024 * 1024 # 10 MB +ALLOWED_EXTENSIONS = {".md", ".txt"} + + +def _sanitize_filename(filename: str) -> str: + """Sanitize filename to prevent path traversal attacks.""" + # Extract only the base filename, removing any path components + basename = os.path.basename(filename) + # Remove any null bytes or other dangerous characters + sanitized = basename.replace("\x00", "").strip() + # Ensure filename is not empty after sanitization + if not sanitized or sanitized in (".", ".."): + return "unnamed_file" + return sanitized + + +@app.post("/api/rag/upload", response_model=Resource) +async def upload_rag_resource(file: UploadFile): + # Validate filename exists + if not file.filename: + raise HTTPException(status_code=400, detail="Filename is required for upload") + + # Sanitize filename to prevent path traversal + safe_filename = _sanitize_filename(file.filename) + + # Validate file extension + _, ext = os.path.splitext(safe_filename.lower()) + if ext not in ALLOWED_EXTENSIONS: + raise HTTPException( + status_code=400, + detail=f"Invalid file type. Only {', '.join(ALLOWED_EXTENSIONS)} files are allowed.", + ) + + # Read content with size limit check + content = await file.read() + if len(content) == 0: + raise HTTPException(status_code=400, detail="Cannot upload an empty file") + if len(content) > MAX_UPLOAD_SIZE_BYTES: + raise HTTPException( + status_code=413, + detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE_BYTES // (1024 * 1024)} MB.", + ) + + retriever = build_retriever() + if not retriever: + raise HTTPException(status_code=500, detail="RAG provider not configured") + try: + return retriever.ingest_file(content, safe_filename) + except NotImplementedError: + raise HTTPException( + status_code=501, detail="Upload not supported by current RAG provider" + ) + except ValueError as exc: + # Invalid user input or unsupported file content; treat as a client error + logger.warning("Invalid RAG resource upload: %s", exc) + raise HTTPException( + status_code=400, + detail="Invalid RAG resource. Please check the file and try again.", + ) + except RuntimeError as exc: + # Internal error during ingestion; log and return a generic server error + logger.exception("Runtime error while ingesting RAG resource: %s", exc) + raise HTTPException( + status_code=500, + detail="Failed to ingest RAG resource due to an internal error.", + ) + + @app.get("/api/config", response_model=ConfigResponse) async def config(): """Get the config of the server.""" diff --git a/tests/unit/server/test_app.py b/tests/unit/server/test_app.py index 3a3588d..981c48b 100644 --- a/tests/unit/server/test_app.py +++ b/tests/unit/server/test_app.py @@ -463,6 +463,111 @@ class TestRAGEndpoints: assert response.status_code == 200 assert response.json()["resources"] == [] + @patch("src.server.app.build_retriever") + def test_upload_rag_resource_success(self, mock_build_retriever, client): + mock_retriever = MagicMock() + mock_retriever.ingest_file.return_value = { + "uri": "milvus://test/file.md", + "title": "Test File", + "description": "Uploaded file", + } + mock_build_retriever.return_value = mock_retriever + + files = {"file": ("test.md", b"# Test content", "text/markdown")} + response = client.post("/api/rag/upload", files=files) + + assert response.status_code == 200 + assert response.json()["title"] == "Test File" + assert response.json()["uri"] == "milvus://test/file.md" + mock_retriever.ingest_file.assert_called_once() + + @patch("src.server.app.build_retriever") + def test_upload_rag_resource_no_retriever(self, mock_build_retriever, client): + mock_build_retriever.return_value = None + + files = {"file": ("test.md", b"# Test content", "text/markdown")} + response = client.post("/api/rag/upload", files=files) + + assert response.status_code == 500 + assert "RAG provider not configured" in response.json()["detail"] + + @patch("src.server.app.build_retriever") + def test_upload_rag_resource_not_implemented(self, mock_build_retriever, client): + mock_retriever = MagicMock() + mock_retriever.ingest_file.side_effect = NotImplementedError + mock_build_retriever.return_value = mock_retriever + + files = {"file": ("test.md", b"# Test content", "text/markdown")} + response = client.post("/api/rag/upload", files=files) + + assert response.status_code == 501 + assert "Upload not supported" in response.json()["detail"] + + @patch("src.server.app.build_retriever") + def test_upload_rag_resource_value_error(self, mock_build_retriever, client): + mock_retriever = MagicMock() + mock_retriever.ingest_file.side_effect = ValueError("File is not valid UTF-8") + mock_build_retriever.return_value = mock_retriever + + files = {"file": ("test.txt", b"\x80\x81\x82", "text/plain")} + response = client.post("/api/rag/upload", files=files) + + assert response.status_code == 400 + assert "Invalid RAG resource" in response.json()["detail"] + + @patch("src.server.app.build_retriever") + def test_upload_rag_resource_runtime_error(self, mock_build_retriever, client): + mock_retriever = MagicMock() + mock_retriever.ingest_file.side_effect = RuntimeError("Failed to insert into Milvus") + mock_build_retriever.return_value = mock_retriever + + files = {"file": ("test.md", b"# Test content", "text/markdown")} + response = client.post("/api/rag/upload", files=files) + + assert response.status_code == 500 + assert "Failed to ingest RAG resource" in response.json()["detail"] + + def test_upload_rag_resource_invalid_file_type(self, client): + files = {"file": ("test.exe", b"binary content", "application/octet-stream")} + response = client.post("/api/rag/upload", files=files) + + assert response.status_code == 400 + assert "Invalid file type" in response.json()["detail"] + + def test_upload_rag_resource_empty_file(self, client): + files = {"file": ("test.md", b"", "text/markdown")} + response = client.post("/api/rag/upload", files=files) + + assert response.status_code == 400 + assert "empty file" in response.json()["detail"] + + @patch("src.server.app.MAX_UPLOAD_SIZE_BYTES", 10) + def test_upload_rag_resource_file_too_large(self, client): + files = {"file": ("test.md", b"x" * 100, "text/markdown")} + response = client.post("/api/rag/upload", files=files) + + assert response.status_code == 413 + assert "File too large" in response.json()["detail"] + + @patch("src.server.app.build_retriever") + def test_upload_rag_resource_path_traversal_sanitized(self, mock_build_retriever, client): + mock_retriever = MagicMock() + mock_retriever.ingest_file.return_value = { + "uri": "milvus://test/file.md", + "title": "Test File", + "description": "Uploaded file", + } + mock_build_retriever.return_value = mock_retriever + + files = {"file": ("../../../etc/passwd.md", b"# Test", "text/markdown")} + response = client.post("/api/rag/upload", files=files) + + assert response.status_code == 200 + # Verify the filename was sanitized (only basename used) + mock_retriever.ingest_file.assert_called_once() + call_args = mock_retriever.ingest_file.call_args + assert call_args[0][1] == "passwd.md" + class TestChatStreamEndpoint: @patch("src.server.app.graph") diff --git a/web/messages/en.json b/web/messages/en.json index 3197edb..7cb1619 100644 --- a/web/messages/en.json +++ b/web/messages/en.json @@ -32,6 +32,17 @@ "addNewMCPServers": "Add New MCP Servers", "mcpConfigDescription": "DeerFlow uses the standard JSON MCP config to create a new server.", "pasteConfigBelow": "Paste your config below and click \"Add\" to add new servers.", + "rag": { + "title": "Resources", + "description": "Manage your knowledge base resources here. Upload markdown or text files to be indexed for retrieval.", + "upload": "Upload", + "uploading": "Uploading...", + "uploadSuccess": "File uploaded successfully", + "uploadFailed": "Failed to upload file", + "emptyFile": "Cannot upload an empty file", + "loading": "Loading resources...", + "noResources": "No resources found. Upload a file to get started." + }, "add": "Add", "general": { "title": "General", diff --git a/web/messages/zh.json b/web/messages/zh.json index f77cc24..d51dff3 100644 --- a/web/messages/zh.json +++ b/web/messages/zh.json @@ -32,6 +32,17 @@ "addNewMCPServers": "添加新的 MCP 服务器", "mcpConfigDescription": "DeerFlow 使用标准 JSON MCP 配置来创建新服务器。", "pasteConfigBelow": "将您的配置粘贴到下面,然后点击\"添加\"来添加新服务器。", + "rag": { + "title": "资源", + "description": "在此管理您的知识库资源。上传 Markdown 或文本文件以供检索索引。", + "upload": "上传", + "uploading": "上传中...", + "uploadSuccess": "文件上传成功", + "uploadFailed": "文件上传失败", + "emptyFile": "无法上传空文件", + "loading": "正在加载资源...", + "noResources": "未找到资源。上传文件以开始使用。" + }, "add": "添加", "general": { "title": "通用", diff --git a/web/src/app/settings/tabs/index.tsx b/web/src/app/settings/tabs/index.tsx index a98a137..61d0cd1 100644 --- a/web/src/app/settings/tabs/index.tsx +++ b/web/src/app/settings/tabs/index.tsx @@ -6,8 +6,9 @@ import { Settings, type LucideIcon } from "lucide-react"; import { AboutTab } from "./about-tab"; import { GeneralTab } from "./general-tab"; import { MCPTab } from "./mcp-tab"; +import { RAGTab } from "./rag-tab"; -export const SETTINGS_TABS = [GeneralTab, MCPTab, AboutTab].map((tab) => { +export const SETTINGS_TABS = [GeneralTab, RAGTab, MCPTab, AboutTab].map((tab) => { const name = tab.displayName ?? tab.name; return { ...tab, diff --git a/web/src/app/settings/tabs/rag-tab.tsx b/web/src/app/settings/tabs/rag-tab.tsx new file mode 100644 index 0000000..22b4ac0 --- /dev/null +++ b/web/src/app/settings/tabs/rag-tab.tsx @@ -0,0 +1,151 @@ +// Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +// SPDX-License-Identifier: MIT + +import { Database, FileText, Upload } from "lucide-react"; +import { useTranslations } from "next-intl"; +import { useCallback, useEffect, useRef, useState } from "react"; +import { toast } from "sonner"; + +import { Button } from "~/components/ui/button"; +import { resolveServiceURL } from "~/core/api/resolve-service-url"; +import type { Resource } from "~/core/messages"; +import { cn } from "~/lib/utils"; + +import type { Tab } from "./types"; + +export const RAGTab: Tab = () => { + const t = useTranslations("settings.rag"); + const [resources, setResources] = useState([]); + const [uploading, setUploading] = useState(false); + const [loading, setLoading] = useState(false); + const fileInputRef = useRef(null); + + const fetchResources = useCallback(async () => { + setLoading(true); + try { + const response = await fetch(resolveServiceURL("rag/resources"), { + method: "GET", + }); + if (response.ok) { + const data = await response.json(); + setResources(data.resources ?? []); + } + } catch (error) { + console.error("Failed to fetch resources:", error); + } finally { + setLoading(false); + } + }, []); + + useEffect(() => { + void fetchResources(); + }, [fetchResources]); + + const handleUpload = async (event: React.ChangeEvent) => { + const file = event.target.files?.[0]; + if (!file) return; + + if (file.size === 0) { + toast.error(t("emptyFile")); + event.target.value = ""; + return; + } + + setUploading(true); + const formData = new FormData(); + formData.append("file", file); + + try { + const response = await fetch(resolveServiceURL("rag/upload"), { + method: "POST", + body: formData, + }); + + if (response.ok) { + toast.success(t("uploadSuccess")); + void fetchResources(); + } else { + const error = await response.json(); + toast.error(error.detail ?? t("uploadFailed")); + } + } catch (error) { + console.error("Upload error:", error); + toast.error(t("uploadFailed")); + } finally { + setUploading(false); + // Reset input value to allow uploading same file again + event.target.value = ""; + } + }; + + return ( +
+
+
+

{t("title")}

+
+ + +
+
+
{t("description")}
+
+
+ {loading ? ( +
+ {t("loading")} +
+ ) : resources.length === 0 ? ( +
+ +

{t("noResources")}

+
+ ) : ( +
    + {resources.map((resource, index) => ( +
  • +
    + +
    +
    +

    {resource.title}

    +
    + + {resource.uri} + + {resource.description && ( + <> + + {resource.description} + + )} +
    +
    +
  • + ))} +
+ )} +
+
+ ); +}; + +RAGTab.icon = Database; +RAGTab.displayName = "Resources"; diff --git a/web/src/core/messages/types.ts b/web/src/core/messages/types.ts index c4dd9ff..a1109b0 100644 --- a/web/src/core/messages/types.ts +++ b/web/src/core/messages/types.ts @@ -42,4 +42,5 @@ export interface ToolCallRuntime { export interface Resource { uri: string; title: string; + description?: string; }