diff --git a/backend/docs/FILE_UPLOAD.md b/backend/docs/FILE_UPLOAD.md index 19a5ff7..b975e20 100644 --- a/backend/docs/FILE_UPLOAD.md +++ b/backend/docs/FILE_UPLOAD.md @@ -131,6 +131,11 @@ read_file(path="/mnt/user-data/uploads/document.md") - 实际存储:`backend/.deer-flow/threads/{thread_id}/user-data/uploads/document.pdf` - 前端访问:`/api/threads/{thread_id}/artifacts/mnt/user-data/uploads/document.pdf`(HTTP URL) +上传流程采用“线程目录优先”策略: +- 先写入 `backend/.deer-flow/threads/{thread_id}/user-data/uploads/` 作为权威存储 +- 本地沙箱(`sandbox_id=local`)直接使用线程目录内容 +- 非本地沙箱会额外同步到 `/mnt/user-data/uploads/*`,确保运行时可见 + ## 测试示例 ### 使用 curl 测试 @@ -243,7 +248,8 @@ backend/.deer-flow/threads/ 1. 确认 UploadsMiddleware 已在 agent.py 中注册 2. 检查 thread_id 是否正确 -3. 确认文件确实已上传到正确的目录 +3. 确认文件确实已上传到 `backend/.deer-flow/threads/{thread_id}/user-data/uploads/` +4. 非本地沙箱场景下,确认上传接口没有报错(需要成功完成 sandbox 同步) ## 开发建议 diff --git a/backend/src/gateway/routers/uploads.py b/backend/src/gateway/routers/uploads.py index 36f4aad..cf2a724 100644 --- a/backend/src/gateway/routers/uploads.py +++ b/backend/src/gateway/routers/uploads.py @@ -94,6 +94,7 @@ async def upload_files( raise HTTPException(status_code=400, detail="No files provided") uploads_dir = get_uploads_dir(thread_id) + paths = get_paths() uploaded_files = [] sandbox_provider = get_sandbox_provider() @@ -107,18 +108,22 @@ async def upload_files( try: # Normalize filename to prevent path traversal safe_filename = Path(file.filename).name - if not safe_filename: + if not safe_filename or safe_filename in {".", ".."} or "/" in safe_filename or "\\" in safe_filename: logger.warning(f"Skipping file with unsafe filename: {file.filename!r}") continue - # Save the original file - file_path = uploads_dir / safe_filename content = await file.read() + file_path = uploads_dir / safe_filename + file_path.write_bytes(content) # Build relative path from backend root - relative_path = str(get_paths().sandbox_uploads_dir(thread_id) / safe_filename) + relative_path = str(paths.sandbox_uploads_dir(thread_id) / safe_filename) virtual_path = f"{VIRTUAL_PATH_PREFIX}/uploads/{safe_filename}" - sandbox.update_file(virtual_path, content) + + # Keep local sandbox source of truth in thread-scoped host storage. + # For non-local sandboxes, also sync to virtual path for runtime visibility. + if sandbox_id != "local": + sandbox.update_file(virtual_path, content) file_info = { "filename": safe_filename, @@ -135,10 +140,15 @@ async def upload_files( if file_ext in CONVERTIBLE_EXTENSIONS: md_path = await convert_file_to_markdown(file_path) if md_path: - md_relative_path = str(get_paths().sandbox_uploads_dir(thread_id) / md_path.name) + md_relative_path = str(paths.sandbox_uploads_dir(thread_id) / md_path.name) + md_virtual_path = f"{VIRTUAL_PATH_PREFIX}/uploads/{md_path.name}" + + if sandbox_id != "local": + sandbox.update_file(md_virtual_path, md_path.read_bytes()) + file_info["markdown_file"] = md_path.name file_info["markdown_path"] = md_relative_path - file_info["markdown_virtual_path"] = f"{VIRTUAL_PATH_PREFIX}/uploads/{md_path.name}" + file_info["markdown_virtual_path"] = md_virtual_path file_info["markdown_artifact_url"] = f"/api/threads/{thread_id}/artifacts/mnt/user-data/uploads/{md_path.name}" uploaded_files.append(file_info) diff --git a/backend/tests/test_uploads_router.py b/backend/tests/test_uploads_router.py new file mode 100644 index 0000000..a7ce875 --- /dev/null +++ b/backend/tests/test_uploads_router.py @@ -0,0 +1,100 @@ +import asyncio +from io import BytesIO +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +from fastapi import UploadFile + +from src.gateway.routers import uploads + + +def test_upload_files_writes_thread_storage_and_skips_local_sandbox_sync(tmp_path): + thread_uploads_dir = tmp_path / "uploads" + thread_uploads_dir.mkdir(parents=True) + + provider = MagicMock() + provider.acquire.return_value = "local" + sandbox = MagicMock() + provider.get.return_value = sandbox + + with ( + patch.object(uploads, "get_uploads_dir", return_value=thread_uploads_dir), + patch.object(uploads, "get_sandbox_provider", return_value=provider), + ): + + file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads")) + result = asyncio.run(uploads.upload_files("thread-local", files=[file])) + + assert result.success is True + assert len(result.files) == 1 + assert result.files[0]["filename"] == "notes.txt" + assert (thread_uploads_dir / "notes.txt").read_bytes() == b"hello uploads" + + sandbox.update_file.assert_not_called() + + +def test_upload_files_syncs_non_local_sandbox_and_marks_markdown_file(tmp_path): + thread_uploads_dir = tmp_path / "uploads" + thread_uploads_dir.mkdir(parents=True) + + provider = MagicMock() + provider.acquire.return_value = "aio-1" + sandbox = MagicMock() + provider.get.return_value = sandbox + + async def fake_convert(file_path: Path) -> Path: + md_path = file_path.with_suffix(".md") + md_path.write_text("converted", encoding="utf-8") + return md_path + + with ( + patch.object(uploads, "get_uploads_dir", return_value=thread_uploads_dir), + patch.object(uploads, "get_sandbox_provider", return_value=provider), + patch.object(uploads, "convert_file_to_markdown", AsyncMock(side_effect=fake_convert)), + ): + + file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes")) + result = asyncio.run(uploads.upload_files("thread-aio", files=[file])) + + assert result.success is True + assert len(result.files) == 1 + file_info = result.files[0] + assert file_info["filename"] == "report.pdf" + assert file_info["markdown_file"] == "report.md" + + assert (thread_uploads_dir / "report.pdf").read_bytes() == b"pdf-bytes" + assert (thread_uploads_dir / "report.md").read_text(encoding="utf-8") == "converted" + + sandbox.update_file.assert_any_call("/mnt/user-data/uploads/report.pdf", b"pdf-bytes") + sandbox.update_file.assert_any_call("/mnt/user-data/uploads/report.md", b"converted") + + +def test_upload_files_rejects_dotdot_and_dot_filenames(tmp_path): + thread_uploads_dir = tmp_path / "uploads" + thread_uploads_dir.mkdir(parents=True) + + provider = MagicMock() + provider.acquire.return_value = "local" + sandbox = MagicMock() + provider.get.return_value = sandbox + + with ( + patch.object(uploads, "get_uploads_dir", return_value=thread_uploads_dir), + patch.object(uploads, "get_sandbox_provider", return_value=provider), + ): + # These filenames must be rejected outright + for bad_name in ["..", "."]: + file = UploadFile(filename=bad_name, file=BytesIO(b"data")) + result = asyncio.run(uploads.upload_files("thread-local", files=[file])) + assert result.success is True + assert result.files == [], f"Expected no files for unsafe filename {bad_name!r}" + + # Path-traversal prefixes are stripped to the basename and accepted safely + file = UploadFile(filename="../etc/passwd", file=BytesIO(b"data")) + result = asyncio.run(uploads.upload_files("thread-local", files=[file])) + assert result.success is True + assert len(result.files) == 1 + assert result.files[0]["filename"] == "passwd" + + # Only the safely normalised file should exist + assert [f.name for f in thread_uploads_dir.iterdir()] == ["passwd"] diff --git a/frontend/src/core/threads/hooks.ts b/frontend/src/core/threads/hooks.ts index 078b9a6..cc801ec 100644 --- a/frontend/src/core/threads/hooks.ts +++ b/frontend/src/core/threads/hooks.ts @@ -4,6 +4,7 @@ import type { ThreadsClient } from "@langchain/langgraph-sdk/client"; import { useStream, type UseStream } from "@langchain/langgraph-sdk/react"; import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; import { useCallback } from "react"; +import { toast } from "sonner"; import type { PromptInputMessage } from "@/components/ai-elements/prompt-input"; @@ -122,17 +123,31 @@ export function useSubmitThread({ return null; }); - const files = (await Promise.all(filePromises)).filter( + const conversionResults = await Promise.all(filePromises); + const files = conversionResults.filter( (file): file is File => file !== null, ); + const failedConversions = conversionResults.length - files.length; - if (files.length > 0 && threadId) { + if (failedConversions > 0) { + throw new Error( + `Failed to prepare ${failedConversions} attachment(s) for upload. Please retry.`, + ); + } + + if (!threadId) { + throw new Error("Thread is not ready for file upload."); + } + + if (files.length > 0) { await uploadFiles(threadId, files); } } catch (error) { console.error("Failed to upload files:", error); - // Continue with message submission even if upload fails - // You might want to show an error toast here + const errorMessage = + error instanceof Error ? error.message : "Failed to upload files."; + toast.error(errorMessage); + throw error; } }