diff --git a/README.md b/README.md index 300991c..54904c3 100644 --- a/README.md +++ b/README.md @@ -273,7 +273,7 @@ FEISHU_APP_SECRET=your_app_secret **Slack Setup** 1. Create a Slack App at [api.slack.com/apps](https://api.slack.com/apps) → Create New App → From scratch. -2. Under **OAuth & Permissions**, add Bot Token Scopes: `app_mentions:read`, `chat:write`, `im:history`, `im:read`, `im:write`. +2. Under **OAuth & Permissions**, add Bot Token Scopes: `app_mentions:read`, `chat:write`, `im:history`, `im:read`, `im:write`, `files:write`. 3. Enable **Socket Mode** → generate an App-Level Token (`xapp-…`) with `connections:write` scope. 4. Under **Event Subscriptions**, subscribe to bot events: `app_mention`, `message.im`. 5. Set `SLACK_BOT_TOKEN` and `SLACK_APP_TOKEN` in `.env` and enable the channel in `config.yaml`. @@ -281,7 +281,7 @@ FEISHU_APP_SECRET=your_app_secret **Feishu / Lark Setup** 1. Create an app on [Feishu Open Platform](https://open.feishu.cn/) → enable **Bot** capability. -2. Add permissions: `im:message`, `im:resource`. +2. Add permissions: `im:message`, `im:message.p2p_msg:readonly`, `im:resource`. 3. Under **Events**, subscribe to `im.message.receive_v1` and select **Long Connection** mode. 4. Copy the App ID and App Secret. Set `FEISHU_APP_ID` and `FEISHU_APP_SECRET` in `.env` and enable the channel in `config.yaml`. diff --git a/backend/src/channels/base.py b/backend/src/channels/base.py index 0d104cc..70111a9 100644 --- a/backend/src/channels/base.py +++ b/backend/src/channels/base.py @@ -6,7 +6,7 @@ import logging from abc import ABC, abstractmethod from typing import Any -from src.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage +from src.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment logger = logging.getLogger(__name__) @@ -51,6 +51,14 @@ class Channel(ABC): to route the reply to the correct conversation/thread. """ + async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool: + """Upload a single file attachment to the platform. + + Returns True if the upload succeeded, False otherwise. + Default implementation returns False (no file upload support). + """ + return False + # -- helpers ----------------------------------------------------------- def _make_inbound( @@ -80,9 +88,21 @@ class Channel(ABC): """Outbound callback registered with the bus. Only forwards messages targeted at this channel. + Sends the text message first, then uploads any file attachments. + File uploads are skipped entirely when the text send fails to avoid + partial deliveries (files without accompanying text). """ if msg.channel_name == self.name: try: await self.send(msg) except Exception: logger.exception("Failed to send outbound message on channel %s", self.name) + return # Do not attempt file uploads when the text message failed + + for attachment in msg.attachments: + try: + success = await self.send_file(msg, attachment) + if not success: + logger.warning("[%s] file upload skipped for %s", self.name, attachment.filename) + except Exception: + logger.exception("[%s] failed to upload file %s", self.name, attachment.filename) diff --git a/backend/src/channels/feishu.py b/backend/src/channels/feishu.py index efb86e8..585e312 100644 --- a/backend/src/channels/feishu.py +++ b/backend/src/channels/feishu.py @@ -9,7 +9,7 @@ import threading from typing import Any from src.channels.base import Channel -from src.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage +from src.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment logger = logging.getLogger(__name__) @@ -40,6 +40,10 @@ class FeishuChannel(Channel): self._CreateMessageReactionRequest = None self._CreateMessageReactionRequestBody = None self._Emoji = None + self._CreateFileRequest = None + self._CreateFileRequestBody = None + self._CreateImageRequest = None + self._CreateImageRequestBody = None async def start(self) -> None: if self._running: @@ -48,6 +52,10 @@ class FeishuChannel(Channel): try: import lark_oapi as lark from lark_oapi.api.im.v1 import ( + CreateFileRequest, + CreateFileRequestBody, + CreateImageRequest, + CreateImageRequestBody, CreateMessageReactionRequest, CreateMessageReactionRequestBody, CreateMessageRequest, @@ -68,6 +76,10 @@ class FeishuChannel(Channel): self._CreateMessageReactionRequest = CreateMessageReactionRequest self._CreateMessageReactionRequestBody = CreateMessageReactionRequestBody self._Emoji = Emoji + self._CreateFileRequest = CreateFileRequest + self._CreateFileRequestBody = CreateFileRequestBody + self._CreateImageRequest = CreateImageRequest + self._CreateImageRequestBody = CreateImageRequestBody app_id = self.config.get("app_id", "") app_secret = self.config.get("app_secret", "") @@ -184,6 +196,71 @@ class FeishuChannel(Channel): logger.error("[Feishu] send failed after %d attempts: %s", _max_retries, last_exc) raise last_exc # type: ignore[misc] + async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool: + if not self._api_client: + return False + + # Check size limits (image: 10MB, file: 30MB) + if attachment.is_image and attachment.size > 10 * 1024 * 1024: + logger.warning("[Feishu] image too large (%d bytes), skipping: %s", attachment.size, attachment.filename) + return False + if not attachment.is_image and attachment.size > 30 * 1024 * 1024: + logger.warning("[Feishu] file too large (%d bytes), skipping: %s", attachment.size, attachment.filename) + return False + + try: + if attachment.is_image: + file_key = await self._upload_image(attachment.actual_path) + msg_type = "image" + content = json.dumps({"image_key": file_key}) + else: + file_key = await self._upload_file(attachment.actual_path, attachment.filename) + msg_type = "file" + content = json.dumps({"file_key": file_key}) + + if msg.thread_ts: + request = self._ReplyMessageRequest.builder().message_id(msg.thread_ts).request_body(self._ReplyMessageRequestBody.builder().msg_type(msg_type).content(content).reply_in_thread(True).build()).build() + await asyncio.to_thread(self._api_client.im.v1.message.reply, request) + else: + request = self._CreateMessageRequest.builder().receive_id_type("chat_id").request_body(self._CreateMessageRequestBody.builder().receive_id(msg.chat_id).msg_type(msg_type).content(content).build()).build() + await asyncio.to_thread(self._api_client.im.v1.message.create, request) + + logger.info("[Feishu] file sent: %s (type=%s)", attachment.filename, msg_type) + return True + except Exception: + logger.exception("[Feishu] failed to upload/send file: %s", attachment.filename) + return False + + async def _upload_image(self, path) -> str: + """Upload an image to Feishu and return the image_key.""" + with open(str(path), "rb") as f: + request = self._CreateImageRequest.builder().request_body(self._CreateImageRequestBody.builder().image_type("message").image(f).build()).build() + response = await asyncio.to_thread(self._api_client.im.v1.image.create, request) + if not response.success(): + raise RuntimeError(f"Feishu image upload failed: code={response.code}, msg={response.msg}") + return response.data.image_key + + async def _upload_file(self, path, filename: str) -> str: + """Upload a file to Feishu and return the file_key.""" + suffix = path.suffix.lower() if hasattr(path, "suffix") else "" + if suffix in (".xls", ".xlsx", ".csv"): + file_type = "xls" + elif suffix in (".ppt", ".pptx"): + file_type = "ppt" + elif suffix == ".pdf": + file_type = "pdf" + elif suffix in (".doc", ".docx"): + file_type = "doc" + else: + file_type = "stream" + + with open(str(path), "rb") as f: + request = self._CreateFileRequest.builder().request_body(self._CreateFileRequestBody.builder().file_type(file_type).file_name(filename).file(f).build()).build() + response = await asyncio.to_thread(self._api_client.im.v1.file.create, request) + if not response.success(): + raise RuntimeError(f"Feishu file upload failed: code={response.code}, msg={response.msg}") + return response.data.file_key + # -- message formatting ------------------------------------------------ @staticmethod diff --git a/backend/src/channels/manager.py b/backend/src/channels/manager.py index 93e64af..07a08ee 100644 --- a/backend/src/channels/manager.py +++ b/backend/src/channels/manager.py @@ -4,10 +4,11 @@ from __future__ import annotations import asyncio import logging +import mimetypes from collections.abc import Mapping from typing import Any -from src.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage +from src.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment from src.channels.store import ChannelStore logger = logging.getLogger(__name__) @@ -54,13 +55,18 @@ def _extract_response_text(result: dict | list) -> str: else: return "" - # Walk backwards to find usable response text + # Walk backwards to find usable response text, but stop at the last + # human message to avoid returning text from a previous turn. for msg in reversed(messages): if not isinstance(msg, dict): continue msg_type = msg.get("type") + # Stop at the last human message — anything before it is a previous turn + if msg_type == "human": + break + # Check for tool messages from ask_clarification (interrupt case) if msg_type == "tool" and msg.get("name") == "ask_clarification": content = msg.get("content", "") @@ -129,6 +135,56 @@ def _format_artifact_text(artifacts: list[str]) -> str: return "Created Files: 📎 " + "、".join(filenames) +_OUTPUTS_VIRTUAL_PREFIX = "/mnt/user-data/outputs/" + + +def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedAttachment]: + """Resolve virtual artifact paths to host filesystem paths with metadata. + + Only paths under ``/mnt/user-data/outputs/`` are accepted; any other + virtual path is rejected with a warning to prevent exfiltrating uploads + or workspace files via IM channels. + + Skips artifacts that cannot be resolved (missing files, invalid paths) + and logs warnings for them. + """ + from src.config.paths import get_paths + + attachments: list[ResolvedAttachment] = [] + paths = get_paths() + outputs_dir = paths.sandbox_outputs_dir(thread_id).resolve() + for virtual_path in artifacts: + # Security: only allow files from the agent outputs directory + if not virtual_path.startswith(_OUTPUTS_VIRTUAL_PREFIX): + logger.warning("[Manager] rejected non-outputs artifact path: %s", virtual_path) + continue + try: + actual = paths.resolve_virtual_path(thread_id, virtual_path) + # Verify the resolved path is actually under the outputs directory + # (guards against path-traversal even after prefix check) + try: + actual.resolve().relative_to(outputs_dir) + except ValueError: + logger.warning("[Manager] artifact path escapes outputs dir: %s -> %s", virtual_path, actual) + continue + if not actual.is_file(): + logger.warning("[Manager] artifact not found on disk: %s -> %s", virtual_path, actual) + continue + mime, _ = mimetypes.guess_type(str(actual)) + mime = mime or "application/octet-stream" + attachments.append(ResolvedAttachment( + virtual_path=virtual_path, + actual_path=actual, + filename=actual.name, + mime_type=mime, + size=actual.stat().st_size, + is_image=mime.startswith("image/"), + )) + except (ValueError, OSError) as exc: + logger.warning("[Manager] failed to resolve artifact %s: %s", virtual_path, exc) + return attachments + + class ChannelManager: """Core dispatcher that bridges IM channels to the DeerFlow agent. @@ -326,16 +382,26 @@ class ChannelManager: len(artifacts), ) - # Append artifact filenames when present + # Resolve artifact virtual paths to actual files for channel upload + attachments: list[ResolvedAttachment] = [] if artifacts: - artifact_text = _format_artifact_text(artifacts) - if response_text: - response_text = response_text + "\n\n" + artifact_text - else: - response_text = artifact_text + attachments = _resolve_attachments(thread_id, artifacts) + resolved_virtuals = {a.virtual_path for a in attachments} + unresolved = [p for p in artifacts if p not in resolved_virtuals] + if unresolved: + artifact_text = _format_artifact_text(unresolved) + response_text = (response_text + "\n\n" + artifact_text) if response_text else artifact_text + # Always include resolved attachment filenames as a text fallback so + # files remain discoverable even when the upload is skipped or fails. + if attachments: + resolved_text = _format_artifact_text([a.virtual_path for a in attachments]) + response_text = (response_text + "\n\n" + resolved_text) if response_text else resolved_text if not response_text: - response_text = "(No response from agent)" + if attachments: + response_text = _format_artifact_text([a.virtual_path for a in attachments]) + else: + response_text = "(No response from agent)" outbound = OutboundMessage( channel_name=msg.channel_name, @@ -343,6 +409,7 @@ class ChannelManager: thread_id=thread_id, text=response_text, artifacts=artifacts, + attachments=attachments, thread_ts=msg.thread_ts, ) logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id) diff --git a/backend/src/channels/message_bus.py b/backend/src/channels/message_bus.py index d252f42..4d0818a 100644 --- a/backend/src/channels/message_bus.py +++ b/backend/src/channels/message_bus.py @@ -8,6 +8,7 @@ import time from collections.abc import Callable, Coroutine from dataclasses import dataclass, field from enum import StrEnum +from pathlib import Path from typing import Any logger = logging.getLogger(__name__) @@ -57,6 +58,27 @@ class InboundMessage: created_at: float = field(default_factory=time.time) +@dataclass +class ResolvedAttachment: + """A file attachment resolved to a host filesystem path, ready for upload. + + Attributes: + virtual_path: Original virtual path (e.g. /mnt/user-data/outputs/report.pdf). + actual_path: Resolved host filesystem path. + filename: Basename of the file. + mime_type: MIME type (e.g. "application/pdf"). + size: File size in bytes. + is_image: True for image/* MIME types (platforms may handle images differently). + """ + + virtual_path: str + actual_path: Path + filename: str + mime_type: str + size: int + is_image: bool + + @dataclass class OutboundMessage: """A message from the agent dispatcher back to a channel. @@ -78,6 +100,7 @@ class OutboundMessage: thread_id: str text: str artifacts: list[str] = field(default_factory=list) + attachments: list[ResolvedAttachment] = field(default_factory=list) is_final: bool = True thread_ts: str | None = None metadata: dict[str, Any] = field(default_factory=dict) diff --git a/backend/src/channels/slack.py b/backend/src/channels/slack.py index 0cbe046..f28f2c1 100644 --- a/backend/src/channels/slack.py +++ b/backend/src/channels/slack.py @@ -9,7 +9,7 @@ from typing import Any from markdown_to_mrkdwn import SlackMarkdownConverter from src.channels.base import Channel -from src.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage +from src.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment logger = logging.getLogger(__name__) @@ -128,6 +128,27 @@ class SlackChannel(Channel): pass raise last_exc # type: ignore[misc] + async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool: + if not self._web_client: + return False + + try: + kwargs: dict[str, Any] = { + "channel": msg.chat_id, + "file": str(attachment.actual_path), + "filename": attachment.filename, + "title": attachment.filename, + } + if msg.thread_ts: + kwargs["thread_ts"] = msg.thread_ts + + await asyncio.to_thread(self._web_client.files_upload_v2, **kwargs) + logger.info("[Slack] file uploaded: %s to channel=%s", attachment.filename, msg.chat_id) + return True + except Exception: + logger.exception("[Slack] failed to upload file: %s", attachment.filename) + return False + # -- internal ---------------------------------------------------------- def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None: diff --git a/backend/src/channels/telegram.py b/backend/src/channels/telegram.py index 14d168a..05d350c 100644 --- a/backend/src/channels/telegram.py +++ b/backend/src/channels/telegram.py @@ -8,7 +8,7 @@ import threading from typing import Any from src.channels.base import Channel -from src.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage +from src.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment logger = logging.getLogger(__name__) @@ -127,6 +127,48 @@ class TelegramChannel(Channel): logger.error("[Telegram] send failed after %d attempts: %s", _max_retries, last_exc) raise last_exc # type: ignore[misc] + async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool: + if not self._application: + return False + + try: + chat_id = int(msg.chat_id) + except (ValueError, TypeError): + logger.error("[Telegram] Invalid chat_id: %s", msg.chat_id) + return False + + # Telegram limits: 10MB for photos, 50MB for documents + if attachment.size > 50 * 1024 * 1024: + logger.warning("[Telegram] file too large (%d bytes), skipping: %s", attachment.size, attachment.filename) + return False + + bot = self._application.bot + reply_to = self._last_bot_message.get(msg.chat_id) + + try: + if attachment.is_image and attachment.size <= 10 * 1024 * 1024: + with open(attachment.actual_path, "rb") as f: + kwargs: dict[str, Any] = {"chat_id": chat_id, "photo": f} + if reply_to: + kwargs["reply_to_message_id"] = reply_to + sent = await bot.send_photo(**kwargs) + else: + from telegram import InputFile + + with open(attachment.actual_path, "rb") as f: + input_file = InputFile(f, filename=attachment.filename) + kwargs = {"chat_id": chat_id, "document": input_file} + if reply_to: + kwargs["reply_to_message_id"] = reply_to + sent = await bot.send_document(**kwargs) + + self._last_bot_message[msg.chat_id] = sent.message_id + logger.info("[Telegram] file sent: %s to chat=%s", attachment.filename, msg.chat_id) + return True + except Exception: + logger.exception("[Telegram] failed to send file: %s", attachment.filename) + return False + # -- helpers ----------------------------------------------------------- async def _send_running_reply(self, chat_id: str, reply_to_message_id: int) -> None: diff --git a/backend/tests/test_channel_file_attachments.py b/backend/tests/test_channel_file_attachments.py new file mode 100644 index 0000000..1d1164b --- /dev/null +++ b/backend/tests/test_channel_file_attachments.py @@ -0,0 +1,435 @@ +"""Tests for channel file attachment support (ResolvedAttachment, resolution, send_file).""" + +from __future__ import annotations + +import asyncio +from pathlib import Path +from unittest.mock import MagicMock, patch + +from src.channels.base import Channel +from src.channels.message_bus import MessageBus, OutboundMessage, ResolvedAttachment + + +def _run(coro): + """Run an async coroutine synchronously.""" + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + +# --------------------------------------------------------------------------- +# ResolvedAttachment tests +# --------------------------------------------------------------------------- + + +class TestResolvedAttachment: + def test_basic_construction(self, tmp_path): + f = tmp_path / "test.pdf" + f.write_bytes(b"PDF content") + + att = ResolvedAttachment( + virtual_path="/mnt/user-data/outputs/test.pdf", + actual_path=f, + filename="test.pdf", + mime_type="application/pdf", + size=11, + is_image=False, + ) + assert att.filename == "test.pdf" + assert att.is_image is False + assert att.size == 11 + + def test_image_detection(self, tmp_path): + f = tmp_path / "photo.png" + f.write_bytes(b"\x89PNG") + + att = ResolvedAttachment( + virtual_path="/mnt/user-data/outputs/photo.png", + actual_path=f, + filename="photo.png", + mime_type="image/png", + size=4, + is_image=True, + ) + assert att.is_image is True + + +# --------------------------------------------------------------------------- +# OutboundMessage.attachments field tests +# --------------------------------------------------------------------------- + + +class TestOutboundMessageAttachments: + def test_default_empty_attachments(self): + msg = OutboundMessage( + channel_name="test", + chat_id="c1", + thread_id="t1", + text="hello", + ) + assert msg.attachments == [] + + def test_attachments_populated(self, tmp_path): + f = tmp_path / "file.txt" + f.write_text("content") + + att = ResolvedAttachment( + virtual_path="/mnt/user-data/outputs/file.txt", + actual_path=f, + filename="file.txt", + mime_type="text/plain", + size=7, + is_image=False, + ) + msg = OutboundMessage( + channel_name="test", + chat_id="c1", + thread_id="t1", + text="hello", + attachments=[att], + ) + assert len(msg.attachments) == 1 + assert msg.attachments[0].filename == "file.txt" + + +# --------------------------------------------------------------------------- +# _resolve_attachments tests +# --------------------------------------------------------------------------- + + +class TestResolveAttachments: + def test_resolves_existing_file(self, tmp_path): + """Successfully resolves a virtual path to an existing file.""" + from src.channels.manager import _resolve_attachments + + # Create the directory structure: threads/{thread_id}/user-data/outputs/ + thread_id = "test-thread-123" + outputs_dir = tmp_path / "threads" / thread_id / "user-data" / "outputs" + outputs_dir.mkdir(parents=True) + test_file = outputs_dir / "report.pdf" + test_file.write_bytes(b"%PDF-1.4 fake content") + + mock_paths = MagicMock() + mock_paths.resolve_virtual_path.return_value = test_file + mock_paths.sandbox_outputs_dir.return_value = outputs_dir + + with patch("src.config.paths.get_paths", return_value=mock_paths): + result = _resolve_attachments(thread_id, ["/mnt/user-data/outputs/report.pdf"]) + + assert len(result) == 1 + assert result[0].filename == "report.pdf" + assert result[0].mime_type == "application/pdf" + assert result[0].is_image is False + assert result[0].size == len(b"%PDF-1.4 fake content") + + def test_resolves_image_file(self, tmp_path): + """Images are detected by MIME type.""" + from src.channels.manager import _resolve_attachments + + thread_id = "test-thread" + outputs_dir = tmp_path / "threads" / thread_id / "user-data" / "outputs" + outputs_dir.mkdir(parents=True) + img = outputs_dir / "chart.png" + img.write_bytes(b"\x89PNG fake image") + + mock_paths = MagicMock() + mock_paths.resolve_virtual_path.return_value = img + mock_paths.sandbox_outputs_dir.return_value = outputs_dir + + with patch("src.config.paths.get_paths", return_value=mock_paths): + result = _resolve_attachments(thread_id, ["/mnt/user-data/outputs/chart.png"]) + + assert len(result) == 1 + assert result[0].is_image is True + assert result[0].mime_type == "image/png" + + def test_skips_missing_file(self, tmp_path): + """Missing files are skipped with a warning.""" + from src.channels.manager import _resolve_attachments + + outputs_dir = tmp_path / "outputs" + outputs_dir.mkdir() + + mock_paths = MagicMock() + mock_paths.resolve_virtual_path.return_value = outputs_dir / "nonexistent.txt" + mock_paths.sandbox_outputs_dir.return_value = outputs_dir + + with patch("src.config.paths.get_paths", return_value=mock_paths): + result = _resolve_attachments("t1", ["/mnt/user-data/outputs/nonexistent.txt"]) + + assert result == [] + + def test_skips_invalid_path(self): + """Invalid paths (ValueError from resolve) are skipped.""" + from src.channels.manager import _resolve_attachments + + mock_paths = MagicMock() + mock_paths.resolve_virtual_path.side_effect = ValueError("bad path") + + with patch("src.config.paths.get_paths", return_value=mock_paths): + result = _resolve_attachments("t1", ["/invalid/path"]) + + assert result == [] + + def test_rejects_uploads_path(self): + """Paths under /mnt/user-data/uploads/ are rejected (security).""" + from src.channels.manager import _resolve_attachments + + mock_paths = MagicMock() + + with patch("src.config.paths.get_paths", return_value=mock_paths): + result = _resolve_attachments("t1", ["/mnt/user-data/uploads/secret.pdf"]) + + assert result == [] + mock_paths.resolve_virtual_path.assert_not_called() + + def test_rejects_workspace_path(self): + """Paths under /mnt/user-data/workspace/ are rejected (security).""" + from src.channels.manager import _resolve_attachments + + mock_paths = MagicMock() + + with patch("src.config.paths.get_paths", return_value=mock_paths): + result = _resolve_attachments("t1", ["/mnt/user-data/workspace/config.py"]) + + assert result == [] + mock_paths.resolve_virtual_path.assert_not_called() + + def test_rejects_path_traversal_escape(self, tmp_path): + """Paths that escape the outputs directory after resolution are rejected.""" + from src.channels.manager import _resolve_attachments + + thread_id = "t1" + outputs_dir = tmp_path / "threads" / thread_id / "user-data" / "outputs" + outputs_dir.mkdir(parents=True) + # Simulate a resolved path that escapes outside the outputs directory + escaped_file = tmp_path / "threads" / thread_id / "user-data" / "uploads" / "stolen.txt" + escaped_file.parent.mkdir(parents=True, exist_ok=True) + escaped_file.write_text("sensitive") + + mock_paths = MagicMock() + mock_paths.resolve_virtual_path.return_value = escaped_file + mock_paths.sandbox_outputs_dir.return_value = outputs_dir + + with patch("src.config.paths.get_paths", return_value=mock_paths): + result = _resolve_attachments(thread_id, ["/mnt/user-data/outputs/../uploads/stolen.txt"]) + + assert result == [] + + def test_multiple_artifacts_partial_resolution(self, tmp_path): + """Mixed valid/invalid artifacts: only valid ones are returned.""" + from src.channels.manager import _resolve_attachments + + thread_id = "t1" + outputs_dir = tmp_path / "outputs" + outputs_dir.mkdir() + good_file = outputs_dir / "data.csv" + good_file.write_text("a,b,c") + + mock_paths = MagicMock() + mock_paths.sandbox_outputs_dir.return_value = outputs_dir + + def resolve_side_effect(tid, vpath): + if "data.csv" in vpath: + return good_file + return tmp_path / "missing.txt" + + mock_paths.resolve_virtual_path.side_effect = resolve_side_effect + + with patch("src.config.paths.get_paths", return_value=mock_paths): + result = _resolve_attachments( + thread_id, + ["/mnt/user-data/outputs/data.csv", "/mnt/user-data/outputs/missing.txt"], + ) + + assert len(result) == 1 + assert result[0].filename == "data.csv" + + +# --------------------------------------------------------------------------- +# Channel base class _on_outbound with attachments +# --------------------------------------------------------------------------- + + +class _DummyChannel(Channel): + """Concrete channel for testing the base class behavior.""" + + def __init__(self, bus): + super().__init__(name="dummy", bus=bus, config={}) + self.sent_messages: list[OutboundMessage] = [] + self.sent_files: list[tuple[OutboundMessage, ResolvedAttachment]] = [] + + async def start(self): + pass + + async def stop(self): + pass + + async def send(self, msg: OutboundMessage) -> None: + self.sent_messages.append(msg) + + async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool: + self.sent_files.append((msg, attachment)) + return True + + +class TestBaseChannelOnOutbound: + def test_send_file_called_for_each_attachment(self, tmp_path): + """_on_outbound sends text first, then uploads each attachment.""" + bus = MessageBus() + ch = _DummyChannel(bus) + + f1 = tmp_path / "a.txt" + f1.write_text("aaa") + f2 = tmp_path / "b.png" + f2.write_bytes(b"\x89PNG") + + att1 = ResolvedAttachment("/mnt/user-data/outputs/a.txt", f1, "a.txt", "text/plain", 3, False) + att2 = ResolvedAttachment("/mnt/user-data/outputs/b.png", f2, "b.png", "image/png", 4, True) + + msg = OutboundMessage( + channel_name="dummy", + chat_id="c1", + thread_id="t1", + text="Here are your files", + attachments=[att1, att2], + ) + + _run(ch._on_outbound(msg)) + + assert len(ch.sent_messages) == 1 + assert len(ch.sent_files) == 2 + assert ch.sent_files[0][1].filename == "a.txt" + assert ch.sent_files[1][1].filename == "b.png" + + def test_no_attachments_no_send_file(self): + """When there are no attachments, send_file is not called.""" + bus = MessageBus() + ch = _DummyChannel(bus) + + msg = OutboundMessage( + channel_name="dummy", + chat_id="c1", + thread_id="t1", + text="No files here", + ) + + _run(ch._on_outbound(msg)) + + assert len(ch.sent_messages) == 1 + assert len(ch.sent_files) == 0 + + def test_send_file_failure_does_not_block_others(self, tmp_path): + """If one attachment upload fails, remaining attachments still get sent.""" + bus = MessageBus() + ch = _DummyChannel(bus) + + # Override send_file to fail on first call, succeed on second + call_count = 0 + original_send_file = ch.send_file + + async def flaky_send_file(msg, att): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError("upload failed") + return await original_send_file(msg, att) + + ch.send_file = flaky_send_file # type: ignore + + f1 = tmp_path / "fail.txt" + f1.write_text("x") + f2 = tmp_path / "ok.txt" + f2.write_text("y") + + att1 = ResolvedAttachment("/mnt/user-data/outputs/fail.txt", f1, "fail.txt", "text/plain", 1, False) + att2 = ResolvedAttachment("/mnt/user-data/outputs/ok.txt", f2, "ok.txt", "text/plain", 1, False) + + msg = OutboundMessage( + channel_name="dummy", + chat_id="c1", + thread_id="t1", + text="files", + attachments=[att1, att2], + ) + + _run(ch._on_outbound(msg)) + + # First upload failed, second succeeded + assert len(ch.sent_files) == 1 + assert ch.sent_files[0][1].filename == "ok.txt" + + def test_send_raises_skips_file_uploads(self, tmp_path): + """When send() raises, file uploads are skipped entirely.""" + bus = MessageBus() + ch = _DummyChannel(bus) + + async def failing_send(msg): + raise RuntimeError("network error") + + ch.send = failing_send # type: ignore + + f = tmp_path / "a.pdf" + f.write_bytes(b"%PDF") + att = ResolvedAttachment("/mnt/user-data/outputs/a.pdf", f, "a.pdf", "application/pdf", 4, False) + msg = OutboundMessage( + channel_name="dummy", + chat_id="c1", + thread_id="t1", + text="Here is the file", + attachments=[att], + ) + + _run(ch._on_outbound(msg)) + + # send() raised, so send_file should never be called + assert len(ch.sent_files) == 0 + + def test_default_send_file_returns_false(self): + """The base Channel.send_file returns False by default.""" + + class MinimalChannel(Channel): + async def start(self): + pass + + async def stop(self): + pass + + async def send(self, msg): + pass + + bus = MessageBus() + ch = MinimalChannel(name="minimal", bus=bus, config={}) + att = ResolvedAttachment("/x", Path("/x"), "x", "text/plain", 0, False) + msg = OutboundMessage(channel_name="minimal", chat_id="c", thread_id="t", text="t") + + result = _run(ch.send_file(msg, att)) + assert result is False + + +# --------------------------------------------------------------------------- +# ChannelManager artifact resolution integration +# --------------------------------------------------------------------------- + + +class TestManagerArtifactResolution: + def test_handle_chat_populates_attachments(self): + """Verify _resolve_attachments is importable and works with the manager module.""" + from src.channels.manager import _resolve_attachments + + # Basic smoke test: empty artifacts returns empty list + mock_paths = MagicMock() + with patch("src.config.paths.get_paths", return_value=mock_paths): + result = _resolve_attachments("t1", []) + assert result == [] + + def test_format_artifact_text_for_unresolved(self): + """_format_artifact_text produces expected output.""" + from src.channels.manager import _format_artifact_text + + assert "report.pdf" in _format_artifact_text(["/mnt/user-data/outputs/report.pdf"]) + result = _format_artifact_text(["/mnt/user-data/outputs/a.txt", "/mnt/user-data/outputs/b.txt"]) + assert "a.txt" in result + assert "b.txt" in result diff --git a/backend/tests/test_channels.py b/backend/tests/test_channels.py index 131476d..04a3cf7 100644 --- a/backend/tests/test_channels.py +++ b/backend/tests/test_channels.py @@ -350,6 +350,26 @@ class TestExtractResponseText: } assert _extract_response_text(result) == "Could you clarify?" + def test_does_not_leak_previous_turn_text(self): + """When current turn AI has no text (only tool calls), do not return previous turn's text.""" + from src.channels.manager import _extract_response_text + + result = { + "messages": [ + {"type": "human", "content": "hello"}, + {"type": "ai", "content": "Hi there!"}, + {"type": "human", "content": "export data"}, + { + "type": "ai", + "content": "", + "tool_calls": [{"name": "present_files", "args": {"filepaths": ["/mnt/user-data/outputs/data.csv"]}}], + }, + {"type": "tool", "name": "present_files", "content": "ok"}, + ] + } + # Should return "" (no text in current turn), NOT "Hi there!" from previous turn + assert _extract_response_text(result) == "" + # --------------------------------------------------------------------------- # ChannelManager tests