mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-02 22:02:13 +08:00
feat(channels): upload file attachments via IM channels (Slack, Telegram, Feishu) (#1040)
This commit is contained in:
@@ -273,7 +273,7 @@ FEISHU_APP_SECRET=your_app_secret
|
||||
**Slack Setup**
|
||||
|
||||
1. Create a Slack App at [api.slack.com/apps](https://api.slack.com/apps) → Create New App → From scratch.
|
||||
2. Under **OAuth & Permissions**, add Bot Token Scopes: `app_mentions:read`, `chat:write`, `im:history`, `im:read`, `im:write`.
|
||||
2. Under **OAuth & Permissions**, add Bot Token Scopes: `app_mentions:read`, `chat:write`, `im:history`, `im:read`, `im:write`, `files:write`.
|
||||
3. Enable **Socket Mode** → generate an App-Level Token (`xapp-…`) with `connections:write` scope.
|
||||
4. Under **Event Subscriptions**, subscribe to bot events: `app_mention`, `message.im`.
|
||||
5. Set `SLACK_BOT_TOKEN` and `SLACK_APP_TOKEN` in `.env` and enable the channel in `config.yaml`.
|
||||
@@ -281,7 +281,7 @@ FEISHU_APP_SECRET=your_app_secret
|
||||
**Feishu / Lark Setup**
|
||||
|
||||
1. Create an app on [Feishu Open Platform](https://open.feishu.cn/) → enable **Bot** capability.
|
||||
2. Add permissions: `im:message`, `im:resource`.
|
||||
2. Add permissions: `im:message`, `im:message.p2p_msg:readonly`, `im:resource`.
|
||||
3. Under **Events**, subscribe to `im.message.receive_v1` and select **Long Connection** mode.
|
||||
4. Copy the App ID and App Secret. Set `FEISHU_APP_ID` and `FEISHU_APP_SECRET` in `.env` and enable the channel in `config.yaml`.
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from src.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage
|
||||
from src.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,6 +51,14 @@ class Channel(ABC):
|
||||
to route the reply to the correct conversation/thread.
|
||||
"""
|
||||
|
||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
    """Push one resolved attachment to the platform behind this channel.

    This base implementation performs no upload and always reports failure;
    channels that can deliver files override it.

    Args:
        msg: Outbound message the attachment belongs to.
        attachment: The resolved file to upload.

    Returns:
        ``True`` when the upload succeeded, ``False`` otherwise. The default
        implementation always returns ``False``.
    """
    uploaded = False  # no file-upload support at this level
    return uploaded
|
||||
|
||||
# -- helpers -----------------------------------------------------------
|
||||
|
||||
def _make_inbound(
|
||||
@@ -80,9 +88,21 @@ class Channel(ABC):
|
||||
"""Outbound callback registered with the bus.
|
||||
|
||||
Only forwards messages targeted at this channel.
|
||||
Sends the text message first, then uploads any file attachments.
|
||||
File uploads are skipped entirely when the text send fails to avoid
|
||||
partial deliveries (files without accompanying text).
|
||||
"""
|
||||
if msg.channel_name == self.name:
|
||||
try:
|
||||
await self.send(msg)
|
||||
except Exception:
|
||||
logger.exception("Failed to send outbound message on channel %s", self.name)
|
||||
return # Do not attempt file uploads when the text message failed
|
||||
|
||||
for attachment in msg.attachments:
|
||||
try:
|
||||
success = await self.send_file(msg, attachment)
|
||||
if not success:
|
||||
logger.warning("[%s] file upload skipped for %s", self.name, attachment.filename)
|
||||
except Exception:
|
||||
logger.exception("[%s] failed to upload file %s", self.name, attachment.filename)
|
||||
|
||||
@@ -9,7 +9,7 @@ import threading
|
||||
from typing import Any
|
||||
|
||||
from src.channels.base import Channel
|
||||
from src.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage
|
||||
from src.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -40,6 +40,10 @@ class FeishuChannel(Channel):
|
||||
self._CreateMessageReactionRequest = None
|
||||
self._CreateMessageReactionRequestBody = None
|
||||
self._Emoji = None
|
||||
self._CreateFileRequest = None
|
||||
self._CreateFileRequestBody = None
|
||||
self._CreateImageRequest = None
|
||||
self._CreateImageRequestBody = None
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._running:
|
||||
@@ -48,6 +52,10 @@ class FeishuChannel(Channel):
|
||||
try:
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.api.im.v1 import (
|
||||
CreateFileRequest,
|
||||
CreateFileRequestBody,
|
||||
CreateImageRequest,
|
||||
CreateImageRequestBody,
|
||||
CreateMessageReactionRequest,
|
||||
CreateMessageReactionRequestBody,
|
||||
CreateMessageRequest,
|
||||
@@ -68,6 +76,10 @@ class FeishuChannel(Channel):
|
||||
self._CreateMessageReactionRequest = CreateMessageReactionRequest
|
||||
self._CreateMessageReactionRequestBody = CreateMessageReactionRequestBody
|
||||
self._Emoji = Emoji
|
||||
self._CreateFileRequest = CreateFileRequest
|
||||
self._CreateFileRequestBody = CreateFileRequestBody
|
||||
self._CreateImageRequest = CreateImageRequest
|
||||
self._CreateImageRequestBody = CreateImageRequestBody
|
||||
|
||||
app_id = self.config.get("app_id", "")
|
||||
app_secret = self.config.get("app_secret", "")
|
||||
@@ -184,6 +196,71 @@ class FeishuChannel(Channel):
|
||||
logger.error("[Feishu] send failed after %d attempts: %s", _max_retries, last_exc)
|
||||
raise last_exc # type: ignore[misc]
|
||||
|
||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
    """Upload one attachment to Feishu and post it into the conversation.

    Images are uploaded through ``_upload_image`` and posted as ``image``
    messages; everything else goes through ``_upload_file`` and is posted as
    a ``file`` message. When ``msg.thread_ts`` is set the attachment is sent
    as an in-thread reply to that message id, otherwise as a new message to
    ``msg.chat_id``.

    Args:
        msg: Outbound message supplying ``chat_id`` and optional ``thread_ts``.
        attachment: Resolved file to deliver.

    Returns:
        ``True`` only when both the upload and the follow-up message were
        accepted. ``False`` when the API client is not initialised, the file
        exceeds the size limit, the API rejects the message, or any exception
        is raised along the way (exceptions are logged, never propagated).
    """
    if not self._api_client:
        return False

    # Check size limits (image: 10MB, file: 30MB)
    if attachment.is_image and attachment.size > 10 * 1024 * 1024:
        logger.warning("[Feishu] image too large (%d bytes), skipping: %s", attachment.size, attachment.filename)
        return False
    if not attachment.is_image and attachment.size > 30 * 1024 * 1024:
        logger.warning("[Feishu] file too large (%d bytes), skipping: %s", attachment.size, attachment.filename)
        return False

    try:
        if attachment.is_image:
            file_key = await self._upload_image(attachment.actual_path)
            msg_type = "image"
            content = json.dumps({"image_key": file_key})
        else:
            file_key = await self._upload_file(attachment.actual_path, attachment.filename)
            msg_type = "file"
            content = json.dumps({"file_key": file_key})

        if msg.thread_ts:
            request = self._ReplyMessageRequest.builder().message_id(msg.thread_ts).request_body(self._ReplyMessageRequestBody.builder().msg_type(msg_type).content(content).reply_in_thread(True).build()).build()
            response = await asyncio.to_thread(self._api_client.im.v1.message.reply, request)
        else:
            request = self._CreateMessageRequest.builder().receive_id_type("chat_id").request_body(self._CreateMessageRequestBody.builder().receive_id(msg.chat_id).msg_type(msg_type).content(content).build()).build()
            response = await asyncio.to_thread(self._api_client.im.v1.message.create, request)

        # Like the upload calls, the message call reports rejection on the
        # response object; without this check a rejected message would be
        # logged and returned as a successful delivery.
        if not response.success():
            logger.error("[Feishu] file message rejected: code=%s, msg=%s, file=%s", response.code, response.msg, attachment.filename)
            return False

        logger.info("[Feishu] file sent: %s (type=%s)", attachment.filename, msg_type)
        return True
    except Exception:
        logger.exception("[Feishu] failed to upload/send file: %s", attachment.filename)
        return False
|
||||
|
||||
async def _upload_image(self, path) -> str:
    """Send the image at *path* to Feishu and hand back its ``image_key``.

    The file stays open for the duration of the upload call, which runs in a
    worker thread so the event loop is not blocked.

    Raises:
        RuntimeError: If the upload response reports failure.
    """
    with open(str(path), "rb") as image_fp:
        request_builder = self._CreateImageRequest.builder()
        body = self._CreateImageRequestBody.builder().image_type("message").image(image_fp).build()
        request = request_builder.request_body(body).build()
        response = await asyncio.to_thread(self._api_client.im.v1.image.create, request)

    if response.success():
        return response.data.image_key
    raise RuntimeError(f"Feishu image upload failed: code={response.code}, msg={response.msg}")
|
||||
|
||||
async def _upload_file(self, path, filename: str) -> str:
    """Upload a file to Feishu and return the file_key.

    The Feishu ``file_type`` is chosen from the extension of *path*
    (case-insensitive); unrecognised or missing extensions are uploaded as
    ``"stream"``. *path* may be a ``pathlib.Path`` or a plain string.

    Args:
        path: Location of the file on the host filesystem.
        filename: Name to present for the uploaded file.

    Returns:
        The ``file_key`` from the upload response.

    Raises:
        RuntimeError: If the upload response reports failure.
    """
    from pathlib import PurePath

    # Extension -> Feishu file_type. Anything not listed falls back to "stream".
    type_by_suffix = {
        ".xls": "xls",
        ".xlsx": "xls",
        ".csv": "xls",
        ".ppt": "ppt",
        ".pptx": "ppt",
        ".pdf": "pdf",
        ".doc": "doc",
        ".docx": "doc",
    }
    # Going through str() means string paths get a real extension lookup
    # instead of silently being classified as "stream".
    suffix = PurePath(str(path)).suffix.lower()
    file_type = type_by_suffix.get(suffix, "stream")

    with open(str(path), "rb") as f:
        request = self._CreateFileRequest.builder().request_body(self._CreateFileRequestBody.builder().file_type(file_type).file_name(filename).file(f).build()).build()
        response = await asyncio.to_thread(self._api_client.im.v1.file.create, request)

    if not response.success():
        raise RuntimeError(f"Feishu file upload failed: code={response.code}, msg={response.msg}")
    return response.data.file_key
|
||||
|
||||
# -- message formatting ------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -4,10 +4,11 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from src.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage
|
||||
from src.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
from src.channels.store import ChannelStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -54,13 +55,18 @@ def _extract_response_text(result: dict | list) -> str:
|
||||
else:
|
||||
return ""
|
||||
|
||||
# Walk backwards to find usable response text
|
||||
# Walk backwards to find usable response text, but stop at the last
|
||||
# human message to avoid returning text from a previous turn.
|
||||
for msg in reversed(messages):
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
|
||||
msg_type = msg.get("type")
|
||||
|
||||
# Stop at the last human message — anything before it is a previous turn
|
||||
if msg_type == "human":
|
||||
break
|
||||
|
||||
# Check for tool messages from ask_clarification (interrupt case)
|
||||
if msg_type == "tool" and msg.get("name") == "ask_clarification":
|
||||
content = msg.get("content", "")
|
||||
@@ -129,6 +135,56 @@ def _format_artifact_text(artifacts: list[str]) -> str:
|
||||
return "Created Files: 📎 " + "、".join(filenames)
|
||||
|
||||
|
||||
_OUTPUTS_VIRTUAL_PREFIX = "/mnt/user-data/outputs/"
|
||||
|
||||
|
||||
def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedAttachment]:
    """Map virtual artifact paths onto host files and collect upload metadata.

    A virtual path is accepted only when it starts with
    ``/mnt/user-data/outputs/`` and its resolved host path lies inside the
    thread's outputs directory, so uploads and workspace files cannot be
    sent out through an IM channel.

    Artifacts that are rejected, do not exist as regular files, or raise
    ``ValueError`` / ``OSError`` while being resolved are skipped with a
    warning instead of failing the whole call.

    Args:
        thread_id: Thread whose outputs directory the artifacts belong to.
        artifacts: Virtual paths reported for the thread.

    Returns:
        One ``ResolvedAttachment`` per artifact that passed every check, in
        input order.
    """
    from src.config.paths import get_paths

    path_cfg = get_paths()
    allowed_root = path_cfg.sandbox_outputs_dir(thread_id).resolve()
    resolved: list[ResolvedAttachment] = []

    for virtual_path in artifacts:
        # Security gate 1: the virtual path itself must point into outputs.
        if not virtual_path.startswith(_OUTPUTS_VIRTUAL_PREFIX):
            logger.warning("[Manager] rejected non-outputs artifact path: %s", virtual_path)
            continue

        try:
            host_path = path_cfg.resolve_virtual_path(thread_id, virtual_path)

            # Security gate 2: after resolution the real path must still sit
            # under the outputs directory (catches ``..`` style traversal).
            if not host_path.resolve().is_relative_to(allowed_root):
                logger.warning("[Manager] artifact path escapes outputs dir: %s -> %s", virtual_path, host_path)
                continue

            if not host_path.is_file():
                logger.warning("[Manager] artifact not found on disk: %s -> %s", virtual_path, host_path)
                continue

            mime_type = mimetypes.guess_type(str(host_path))[0] or "application/octet-stream"
            entry = ResolvedAttachment(
                virtual_path=virtual_path,
                actual_path=host_path,
                filename=host_path.name,
                mime_type=mime_type,
                size=host_path.stat().st_size,
                is_image=mime_type.startswith("image/"),
            )
            resolved.append(entry)
        except (ValueError, OSError) as exc:
            logger.warning("[Manager] failed to resolve artifact %s: %s", virtual_path, exc)

    return resolved
|
||||
|
||||
|
||||
class ChannelManager:
|
||||
"""Core dispatcher that bridges IM channels to the DeerFlow agent.
|
||||
|
||||
@@ -326,16 +382,26 @@ class ChannelManager:
|
||||
len(artifacts),
|
||||
)
|
||||
|
||||
# Append artifact filenames when present
|
||||
# Resolve artifact virtual paths to actual files for channel upload
|
||||
attachments: list[ResolvedAttachment] = []
|
||||
if artifacts:
|
||||
artifact_text = _format_artifact_text(artifacts)
|
||||
if response_text:
|
||||
response_text = response_text + "\n\n" + artifact_text
|
||||
else:
|
||||
response_text = artifact_text
|
||||
attachments = _resolve_attachments(thread_id, artifacts)
|
||||
resolved_virtuals = {a.virtual_path for a in attachments}
|
||||
unresolved = [p for p in artifacts if p not in resolved_virtuals]
|
||||
if unresolved:
|
||||
artifact_text = _format_artifact_text(unresolved)
|
||||
response_text = (response_text + "\n\n" + artifact_text) if response_text else artifact_text
|
||||
# Always include resolved attachment filenames as a text fallback so
|
||||
# files remain discoverable even when the upload is skipped or fails.
|
||||
if attachments:
|
||||
resolved_text = _format_artifact_text([a.virtual_path for a in attachments])
|
||||
response_text = (response_text + "\n\n" + resolved_text) if response_text else resolved_text
|
||||
|
||||
if not response_text:
|
||||
response_text = "(No response from agent)"
|
||||
if attachments:
|
||||
response_text = _format_artifact_text([a.virtual_path for a in attachments])
|
||||
else:
|
||||
response_text = "(No response from agent)"
|
||||
|
||||
outbound = OutboundMessage(
|
||||
channel_name=msg.channel_name,
|
||||
@@ -343,6 +409,7 @@ class ChannelManager:
|
||||
thread_id=thread_id,
|
||||
text=response_text,
|
||||
artifacts=artifacts,
|
||||
attachments=attachments,
|
||||
thread_ts=msg.thread_ts,
|
||||
)
|
||||
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
|
||||
|
||||
@@ -8,6 +8,7 @@ import time
|
||||
from collections.abc import Callable, Coroutine
|
||||
from dataclasses import dataclass, field
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -57,6 +58,27 @@ class InboundMessage:
|
||||
created_at: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@dataclass
class ResolvedAttachment:
    """An outbound file whose virtual path has been mapped to a host file.

    Instances carry everything a channel needs to upload the file without
    touching the path configuration again.

    Attributes:
        virtual_path: The path as originally reported, e.g.
            ``/mnt/user-data/outputs/report.pdf``.
        actual_path: Where the file lives on the host filesystem.
        filename: Base name of the file.
        mime_type: MIME type string such as ``"application/pdf"``.
        size: Size of the file in bytes.
        is_image: ``True`` for ``image/*`` MIME types; platforms may treat
            images differently from other files.
    """

    virtual_path: str
    actual_path: Path
    filename: str
    mime_type: str
    size: int
    is_image: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutboundMessage:
|
||||
"""A message from the agent dispatcher back to a channel.
|
||||
@@ -78,6 +100,7 @@ class OutboundMessage:
|
||||
thread_id: str
|
||||
text: str
|
||||
artifacts: list[str] = field(default_factory=list)
|
||||
attachments: list[ResolvedAttachment] = field(default_factory=list)
|
||||
is_final: bool = True
|
||||
thread_ts: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Any
|
||||
from markdown_to_mrkdwn import SlackMarkdownConverter
|
||||
|
||||
from src.channels.base import Channel
|
||||
from src.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage
|
||||
from src.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -128,6 +128,27 @@ class SlackChannel(Channel):
|
||||
pass
|
||||
raise last_exc # type: ignore[misc]
|
||||
|
||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
    """Upload one attachment to Slack via ``files_upload_v2``.

    The file is posted to ``msg.chat_id`` and, when ``msg.thread_ts`` is
    set, into that thread. The blocking client call runs in a worker thread.

    Returns:
        ``True`` once the upload call returns, ``False`` when no web client
        is available or the upload raises (the exception is logged).
    """
    client = self._web_client
    if not client:
        return False

    try:
        upload_args: dict[str, Any] = dict(
            channel=msg.chat_id,
            file=str(attachment.actual_path),
            filename=attachment.filename,
            title=attachment.filename,
        )
        thread_ts = msg.thread_ts
        if thread_ts:
            upload_args["thread_ts"] = thread_ts

        await asyncio.to_thread(client.files_upload_v2, **upload_args)
        logger.info("[Slack] file uploaded: %s to channel=%s", attachment.filename, msg.chat_id)
        return True
    except Exception:
        logger.exception("[Slack] failed to upload file: %s", attachment.filename)
        return False
|
||||
|
||||
# -- internal ----------------------------------------------------------
|
||||
|
||||
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
|
||||
|
||||
@@ -8,7 +8,7 @@ import threading
|
||||
from typing import Any
|
||||
|
||||
from src.channels.base import Channel
|
||||
from src.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage
|
||||
from src.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -127,6 +127,48 @@ class TelegramChannel(Channel):
|
||||
logger.error("[Telegram] send failed after %d attempts: %s", _max_retries, last_exc)
|
||||
raise last_exc # type: ignore[misc]
|
||||
|
||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
    """Send one attachment to a Telegram chat as a photo or a document.

    Images of at most 10MB go out through ``send_photo``; everything else,
    including larger images, goes out through ``send_document``. Files above
    50MB are skipped. The upload replies to the last bot message recorded
    for the chat when there is one, and the id of the newly sent message is
    recorded afterwards.

    Returns:
        ``True`` on success; ``False`` when the application is not running,
        ``msg.chat_id`` is not an integer, the file is too large, or sending
        raises (the exception is logged).
    """
    if not self._application:
        return False

    try:
        chat_id = int(msg.chat_id)
    except (ValueError, TypeError):
        logger.error("[Telegram] Invalid chat_id: %s", msg.chat_id)
        return False

    # Telegram limits: 10MB for photos, 50MB for documents
    photo_limit = 10 * 1024 * 1024
    document_limit = 50 * 1024 * 1024
    if attachment.size > document_limit:
        logger.warning("[Telegram] file too large (%d bytes), skipping: %s", attachment.size, attachment.filename)
        return False

    bot = self._application.bot
    reply_to = self._last_bot_message.get(msg.chat_id)

    # Arguments shared by both send calls.
    send_args: dict[str, Any] = {"chat_id": chat_id}
    if reply_to:
        send_args["reply_to_message_id"] = reply_to

    try:
        as_photo = attachment.is_image and attachment.size <= photo_limit
        if as_photo:
            with open(attachment.actual_path, "rb") as fh:
                sent = await bot.send_photo(photo=fh, **send_args)
        else:
            from telegram import InputFile

            with open(attachment.actual_path, "rb") as fh:
                document = InputFile(fh, filename=attachment.filename)
                sent = await bot.send_document(document=document, **send_args)

        self._last_bot_message[msg.chat_id] = sent.message_id
        logger.info("[Telegram] file sent: %s to chat=%s", attachment.filename, msg.chat_id)
        return True
    except Exception:
        logger.exception("[Telegram] failed to send file: %s", attachment.filename)
        return False
|
||||
|
||||
# -- helpers -----------------------------------------------------------
|
||||
|
||||
async def _send_running_reply(self, chat_id: str, reply_to_message_id: int) -> None:
|
||||
|
||||
435
backend/tests/test_channel_file_attachments.py
Normal file
435
backend/tests/test_channel_file_attachments.py
Normal file
@@ -0,0 +1,435 @@
|
||||
"""Tests for channel file attachment support (ResolvedAttachment, resolution, send_file)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from src.channels.base import Channel
|
||||
from src.channels.message_bus import MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
|
||||
def _run(coro):
    """Drive *coro* to completion on a throwaway event loop and return its result.

    The loop is closed afterwards whether the coroutine returns or raises.
    """
    event_loop = asyncio.new_event_loop()
    try:
        outcome = event_loop.run_until_complete(coro)
    finally:
        event_loop.close()
    return outcome
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ResolvedAttachment tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolvedAttachment:
    """Construction checks for the ``ResolvedAttachment`` dataclass."""

    def test_basic_construction(self, tmp_path):
        """A non-image attachment keeps the filename, size and flag it was given."""
        pdf_path = tmp_path / "test.pdf"
        pdf_path.write_bytes(b"PDF content")

        fields = {
            "virtual_path": "/mnt/user-data/outputs/test.pdf",
            "actual_path": pdf_path,
            "filename": "test.pdf",
            "mime_type": "application/pdf",
            "size": 11,
            "is_image": False,
        }
        att = ResolvedAttachment(**fields)

        assert att.filename == "test.pdf"
        assert att.is_image is False
        assert att.size == 11

    def test_image_detection(self, tmp_path):
        """An attachment built with ``is_image=True`` reports itself as an image."""
        png_path = tmp_path / "photo.png"
        png_path.write_bytes(b"\x89PNG")

        fields = {
            "virtual_path": "/mnt/user-data/outputs/photo.png",
            "actual_path": png_path,
            "filename": "photo.png",
            "mime_type": "image/png",
            "size": 4,
            "is_image": True,
        }
        att = ResolvedAttachment(**fields)

        assert att.is_image is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OutboundMessage.attachments field tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOutboundMessageAttachments:
    """Checks for the ``attachments`` field on ``OutboundMessage``."""

    def test_default_empty_attachments(self):
        """A message built without attachments exposes an empty list."""
        base_fields = {"channel_name": "test", "chat_id": "c1", "thread_id": "t1", "text": "hello"}
        msg = OutboundMessage(**base_fields)

        assert msg.attachments == []

    def test_attachments_populated(self, tmp_path):
        """Attachments passed to the constructor are stored on the message."""
        text_file = tmp_path / "file.txt"
        text_file.write_text("content")

        att = ResolvedAttachment(
            virtual_path="/mnt/user-data/outputs/file.txt",
            actual_path=text_file,
            filename="file.txt",
            mime_type="text/plain",
            size=7,
            is_image=False,
        )
        base_fields = {"channel_name": "test", "chat_id": "c1", "thread_id": "t1", "text": "hello"}
        msg = OutboundMessage(attachments=[att], **base_fields)

        assert len(msg.attachments) == 1
        first = msg.attachments[0]
        assert first.filename == "file.txt"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _resolve_attachments tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveAttachments:
    """Behaviour of ``_resolve_attachments`` against a stubbed path configuration."""

    @staticmethod
    def _make_outputs_dir(tmp_path, thread_id):
        """Create and return ``threads/<thread_id>/user-data/outputs`` under *tmp_path*."""
        outputs_dir = tmp_path / "threads" / thread_id / "user-data" / "outputs"
        outputs_dir.mkdir(parents=True)
        return outputs_dir

    @staticmethod
    def _resolve(mock_paths, thread_id, artifacts):
        """Call ``_resolve_attachments`` with ``get_paths`` patched to return *mock_paths*."""
        from src.channels.manager import _resolve_attachments

        with patch("src.config.paths.get_paths", return_value=mock_paths):
            return _resolve_attachments(thread_id, artifacts)

    def test_resolves_existing_file(self, tmp_path):
        """Successfully resolves a virtual path to an existing file."""
        thread_id = "test-thread-123"
        outputs_dir = self._make_outputs_dir(tmp_path, thread_id)
        test_file = outputs_dir / "report.pdf"
        test_file.write_bytes(b"%PDF-1.4 fake content")

        mock_paths = MagicMock()
        mock_paths.resolve_virtual_path.return_value = test_file
        mock_paths.sandbox_outputs_dir.return_value = outputs_dir

        result = self._resolve(mock_paths, thread_id, ["/mnt/user-data/outputs/report.pdf"])

        assert len(result) == 1
        assert result[0].filename == "report.pdf"
        assert result[0].mime_type == "application/pdf"
        assert result[0].is_image is False
        assert result[0].size == len(b"%PDF-1.4 fake content")

    def test_resolves_image_file(self, tmp_path):
        """Images are detected by MIME type."""
        thread_id = "test-thread"
        outputs_dir = self._make_outputs_dir(tmp_path, thread_id)
        img = outputs_dir / "chart.png"
        img.write_bytes(b"\x89PNG fake image")

        mock_paths = MagicMock()
        mock_paths.resolve_virtual_path.return_value = img
        mock_paths.sandbox_outputs_dir.return_value = outputs_dir

        result = self._resolve(mock_paths, thread_id, ["/mnt/user-data/outputs/chart.png"])

        assert len(result) == 1
        assert result[0].is_image is True
        assert result[0].mime_type == "image/png"

    def test_skips_missing_file(self, tmp_path):
        """Missing files are skipped with a warning."""
        outputs_dir = tmp_path / "outputs"
        outputs_dir.mkdir()

        mock_paths = MagicMock()
        mock_paths.resolve_virtual_path.return_value = outputs_dir / "nonexistent.txt"
        mock_paths.sandbox_outputs_dir.return_value = outputs_dir

        result = self._resolve(mock_paths, "t1", ["/mnt/user-data/outputs/nonexistent.txt"])

        assert result == []

    def test_skips_invalid_path(self):
        """Invalid paths (ValueError from resolve) are skipped."""
        mock_paths = MagicMock()
        mock_paths.resolve_virtual_path.side_effect = ValueError("bad path")

        # The path must carry the outputs prefix; otherwise the prefix check
        # rejects it first and the ValueError branch is never reached.
        result = self._resolve(mock_paths, "t1", ["/mnt/user-data/outputs/bad-path.txt"])

        assert result == []
        mock_paths.resolve_virtual_path.assert_called_once()

    def test_rejects_uploads_path(self):
        """Paths under /mnt/user-data/uploads/ are rejected (security)."""
        mock_paths = MagicMock()

        result = self._resolve(mock_paths, "t1", ["/mnt/user-data/uploads/secret.pdf"])

        assert result == []
        mock_paths.resolve_virtual_path.assert_not_called()

    def test_rejects_workspace_path(self):
        """Paths under /mnt/user-data/workspace/ are rejected (security)."""
        mock_paths = MagicMock()

        result = self._resolve(mock_paths, "t1", ["/mnt/user-data/workspace/config.py"])

        assert result == []
        mock_paths.resolve_virtual_path.assert_not_called()

    def test_rejects_path_traversal_escape(self, tmp_path):
        """Paths that escape the outputs directory after resolution are rejected."""
        thread_id = "t1"
        outputs_dir = self._make_outputs_dir(tmp_path, thread_id)
        # Simulate a resolved path that escapes outside the outputs directory
        escaped_file = tmp_path / "threads" / thread_id / "user-data" / "uploads" / "stolen.txt"
        escaped_file.parent.mkdir(parents=True, exist_ok=True)
        escaped_file.write_text("sensitive")

        mock_paths = MagicMock()
        mock_paths.resolve_virtual_path.return_value = escaped_file
        mock_paths.sandbox_outputs_dir.return_value = outputs_dir

        result = self._resolve(mock_paths, thread_id, ["/mnt/user-data/outputs/../uploads/stolen.txt"])

        assert result == []

    def test_multiple_artifacts_partial_resolution(self, tmp_path):
        """Mixed valid/invalid artifacts: only valid ones are returned."""
        thread_id = "t1"
        outputs_dir = tmp_path / "outputs"
        outputs_dir.mkdir()
        good_file = outputs_dir / "data.csv"
        good_file.write_text("a,b,c")

        mock_paths = MagicMock()
        mock_paths.sandbox_outputs_dir.return_value = outputs_dir

        def resolve_side_effect(tid, vpath):
            if "data.csv" in vpath:
                return good_file
            return tmp_path / "missing.txt"

        mock_paths.resolve_virtual_path.side_effect = resolve_side_effect

        result = self._resolve(
            mock_paths,
            thread_id,
            ["/mnt/user-data/outputs/data.csv", "/mnt/user-data/outputs/missing.txt"],
        )

        assert len(result) == 1
        assert result[0].filename == "data.csv"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Channel base class _on_outbound with attachments
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _DummyChannel(Channel):
    """Minimal concrete ``Channel`` that records what it is asked to send."""

    def __init__(self, bus):
        super().__init__(name="dummy", bus=bus, config={})
        # Text messages handed to send(), in call order.
        self.sent_messages: list[OutboundMessage] = []
        # (message, attachment) pairs handed to send_file(), in call order.
        self.sent_files: list[tuple[OutboundMessage, ResolvedAttachment]] = []

    async def start(self):
        """Nothing to start for the dummy channel."""
        return None

    async def stop(self):
        """Nothing to stop for the dummy channel."""
        return None

    async def send(self, msg: OutboundMessage) -> None:
        """Record *msg* instead of delivering it anywhere."""
        self.sent_messages.append(msg)

    async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
        """Record the upload request and report success."""
        record = (msg, attachment)
        self.sent_files.append(record)
        return True
|
||||
|
||||
|
||||
class TestBaseChannelOnOutbound:
    """Exercise ``_on_outbound`` and the default ``send_file`` using small channel stubs."""

    def test_send_file_called_for_each_attachment(self, tmp_path):
        """Two attachments produce one ``send`` call and two ``send_file`` calls in list order."""
        channel = _DummyChannel(MessageBus())

        text_path = tmp_path / "a.txt"
        text_path.write_text("aaa")
        image_path = tmp_path / "b.png"
        image_path.write_bytes(b"\x89PNG")

        attachments = [
            ResolvedAttachment("/mnt/user-data/outputs/a.txt", text_path, "a.txt", "text/plain", 3, False),
            ResolvedAttachment("/mnt/user-data/outputs/b.png", image_path, "b.png", "image/png", 4, True),
        ]
        outbound = OutboundMessage(
            channel_name="dummy",
            chat_id="c1",
            thread_id="t1",
            text="Here are your files",
            attachments=attachments,
        )

        _run(channel._on_outbound(outbound))

        assert len(channel.sent_messages) == 1
        assert len(channel.sent_files) == 2
        # Upload order must follow the order of the attachments list.
        uploaded_names = [uploaded.filename for _, uploaded in channel.sent_files]
        assert uploaded_names == ["a.txt", "b.png"]

    def test_no_attachments_no_send_file(self):
        """A message without attachments is sent once and records no file uploads."""
        channel = _DummyChannel(MessageBus())
        outbound = OutboundMessage(
            channel_name="dummy",
            chat_id="c1",
            thread_id="t1",
            text="No files here",
        )

        _run(channel._on_outbound(outbound))

        assert len(channel.sent_messages) == 1
        assert len(channel.sent_files) == 0

    def test_send_file_failure_does_not_block_others(self, tmp_path):
        """An exception from the first upload does not prevent the second upload."""
        channel = _DummyChannel(MessageBus())

        # Wrap the real send_file so the first attempt raises and later ones delegate.
        real_send_file = channel.send_file
        attempts = []

        async def flaky_send_file(msg, att):
            attempts.append(att)
            if len(attempts) == 1:
                raise RuntimeError("upload failed")
            return await real_send_file(msg, att)

        channel.send_file = flaky_send_file  # type: ignore

        failing_path = tmp_path / "fail.txt"
        failing_path.write_text("x")
        passing_path = tmp_path / "ok.txt"
        passing_path.write_text("y")

        attachments = [
            ResolvedAttachment("/mnt/user-data/outputs/fail.txt", failing_path, "fail.txt", "text/plain", 1, False),
            ResolvedAttachment("/mnt/user-data/outputs/ok.txt", passing_path, "ok.txt", "text/plain", 1, False),
        ]
        outbound = OutboundMessage(
            channel_name="dummy",
            chat_id="c1",
            thread_id="t1",
            text="files",
            attachments=attachments,
        )

        _run(channel._on_outbound(outbound))

        # Only the second upload reached the recording send_file.
        assert len(channel.sent_files) == 1
        _, uploaded = channel.sent_files[0]
        assert uploaded.filename == "ok.txt"

    def test_send_raises_skips_file_uploads(self, tmp_path):
        """When ``send`` raises, ``_on_outbound`` records no file uploads."""
        channel = _DummyChannel(MessageBus())

        async def failing_send(msg):
            raise RuntimeError("network error")

        channel.send = failing_send  # type: ignore

        pdf_path = tmp_path / "a.pdf"
        pdf_path.write_bytes(b"%PDF")
        attachment = ResolvedAttachment(
            "/mnt/user-data/outputs/a.pdf", pdf_path, "a.pdf", "application/pdf", 4, False
        )
        outbound = OutboundMessage(
            channel_name="dummy",
            chat_id="c1",
            thread_id="t1",
            text="Here is the file",
            attachments=[attachment],
        )

        _run(channel._on_outbound(outbound))

        # send() raised, so nothing may have been handed to send_file.
        assert len(channel.sent_files) == 0

    def test_default_send_file_returns_false(self):
        """A ``Channel`` subclass that does not override ``send_file`` gets ``False`` back."""

        class MinimalChannel(Channel):
            """Implements only start/stop/send, leaving ``send_file`` inherited."""

            async def start(self):
                return None

            async def stop(self):
                return None

            async def send(self, msg):
                return None

        channel = MinimalChannel(name="minimal", bus=MessageBus(), config={})
        attachment = ResolvedAttachment("/x", Path("/x"), "x", "text/plain", 0, False)
        outbound = OutboundMessage(channel_name="minimal", chat_id="c", thread_id="t", text="t")

        assert _run(channel.send_file(outbound, attachment)) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChannelManager artifact resolution integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestManagerArtifactResolution:
    """Smoke tests for the artifact helpers imported from ``src.channels.manager``."""

    def test_handle_chat_populates_attachments(self):
        """``_resolve_attachments`` is importable and returns ``[]`` for an empty artifact list."""
        from src.channels.manager import _resolve_attachments

        # get_paths is patched with a mock for the duration of the call.
        with patch("src.config.paths.get_paths", return_value=MagicMock()):
            resolved = _resolve_attachments("t1", [])

        assert resolved == []

    def test_format_artifact_text_for_unresolved(self):
        """``_format_artifact_text`` output contains the file name of every given path."""
        from src.channels.manager import _format_artifact_text

        single = _format_artifact_text(["/mnt/user-data/outputs/report.pdf"])
        assert "report.pdf" in single

        combined = _format_artifact_text(
            [
                "/mnt/user-data/outputs/a.txt",
                "/mnt/user-data/outputs/b.txt",
            ]
        )
        for expected_name in ("a.txt", "b.txt"):
            assert expected_name in combined
|
||||
@@ -350,6 +350,26 @@ class TestExtractResponseText:
|
||||
}
|
||||
assert _extract_response_text(result) == "Could you clarify?"
|
||||
|
||||
def test_does_not_leak_previous_turn_text(self):
    """An AI message with empty content and only tool calls yields "", not the earlier turn's reply."""
    from src.channels.manager import _extract_response_text

    previous_turn = [
        {"type": "human", "content": "hello"},
        {"type": "ai", "content": "Hi there!"},
    ]
    present_files_call = {
        "name": "present_files",
        "args": {"filepaths": ["/mnt/user-data/outputs/data.csv"]},
    }
    current_turn = [
        {"type": "human", "content": "export data"},
        {"type": "ai", "content": "", "tool_calls": [present_files_call]},
        {"type": "tool", "name": "present_files", "content": "ok"},
    ]
    state = {"messages": previous_turn + current_turn}

    # The only AI text in the history ("Hi there!") belongs to the earlier turn,
    # so the expected result is the empty string.
    assert _extract_response_text(state) == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChannelManager tests
|
||||
|
||||
Reference in New Issue
Block a user