diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index bd8617c..5d5e000 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -251,19 +251,22 @@ Bridges external messaging platforms (Feishu, Slack, Telegram) to the DeerFlow a **Architecture**: Channels communicate with the LangGraph Server through `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. **Components**: -- `message_bus.py` - Async pub/sub hub (`InboundMessage` -> queue -> dispatcher; `OutboundMessage` -> callbacks -> channels) -- `store.py` - JSON-file persistence mapping `channel_name:chat_id[:topic_id]` -> `thread_id` (keys are `channel:chat` for root conversations and `channel:chat:topic` for threaded conversations) -- `manager.py` - Core dispatcher: creates threads via `client.threads.create()`, sends messages via `client.runs.wait()`, routes commands +- `message_bus.py` - Async pub/sub hub (`InboundMessage` → queue → dispatcher; `OutboundMessage` → callbacks → channels) +- `store.py` - JSON-file persistence mapping `channel_name:chat_id[:topic_id]` → `thread_id` (keys are `channel:chat` for root conversations and `channel:chat:topic` for threaded conversations) +- `manager.py` - Core dispatcher: creates threads via `client.threads.create()`, routes commands, keeps Slack/Telegram on `client.runs.wait()`, and uses `client.runs.stream(["messages-tuple", "values"])` for Feishu incremental outbound updates - `base.py` - Abstract `Channel` base class (start/stop/send lifecycle) - `service.py` - Manages lifecycle of all configured channels from `config.yaml` -- `slack.py` / `feishu.py` / `telegram.py` - Platform-specific implementations +- `slack.py` / `feishu.py` / `telegram.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place) **Message Flow**: 1. External platform -> Channel impl -> `MessageBus.publish_inbound()` 2. `ChannelManager._dispatch_loop()` consumes from queue -3. For chat: look up/create thread on LangGraph Server -> `runs.wait()` -> extract response -> publish outbound -4. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API -5. Outbound -> channel callbacks -> platform reply +3. For chat: look up/create thread on LangGraph Server +4. Feishu chat: `runs.stream()` → accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`) +5. Slack/Telegram chat: `runs.wait()` → extract final response → publish outbound +6. Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement) +7. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API +8. Outbound → channel callbacks → platform reply **Configuration** (`config.yaml` -> `channels`): - `langgraph_url` - LangGraph Server URL (default: `http://localhost:2024`) diff --git a/backend/README.md b/backend/README.md index 8ae02f3..df38e14 100644 --- a/backend/README.md +++ b/backend/README.md @@ -127,6 +127,12 @@ FastAPI application providing REST endpoints for frontend integration: | `GET /api/threads/{id}/uploads/list` | List uploaded files | | `GET /api/threads/{id}/artifacts/{path}` | Serve generated artifacts | +### IM Channels + +The IM bridge supports Feishu, Slack, and Telegram. Slack and Telegram still use the final `runs.wait()` response path, while Feishu now streams through `runs.stream(["messages-tuple", "values"])` and updates a single in-thread card in place. + +For Feishu card updates, DeerFlow stores the running card's `message_id` per inbound message and patches that same card until the run finishes, preserving the existing `OK` / `DONE` reaction flow. + --- ## Quick Start diff --git a/backend/src/channels/feishu.py b/backend/src/channels/feishu.py index 585e312..0e5b1a5 100644 --- a/backend/src/channels/feishu.py +++ b/backend/src/channels/feishu.py @@ -40,6 +40,11 @@ class FeishuChannel(Channel): self._CreateMessageReactionRequest = None self._CreateMessageReactionRequestBody = None self._Emoji = None + self._PatchMessageRequest = None + self._PatchMessageRequestBody = None + self._background_tasks: set[asyncio.Task] = set() + self._running_card_ids: dict[str, str] = {} + self._running_card_tasks: dict[str, asyncio.Task] = {} self._CreateFileRequest = None self._CreateFileRequestBody = None self._CreateImageRequest = None @@ -61,6 +66,8 @@ class FeishuChannel(Channel): CreateMessageRequest, CreateMessageRequestBody, Emoji, + PatchMessageRequest, + PatchMessageRequestBody, ReplyMessageRequest, ReplyMessageRequestBody, ) @@ -76,6 +83,8 @@ class FeishuChannel(Channel): self._CreateMessageReactionRequest = CreateMessageReactionRequest self._CreateMessageReactionRequestBody = CreateMessageReactionRequestBody self._Emoji = Emoji + self._PatchMessageRequest = PatchMessageRequest + self._PatchMessageRequestBody = PatchMessageRequestBody self._CreateFileRequest = CreateFileRequest self._CreateFileRequestBody = CreateFileRequestBody self._CreateImageRequest = CreateImageRequest @@ -145,6 +154,12 @@ class FeishuChannel(Channel): async def stop(self) -> None: self._running = False self.bus.unsubscribe_outbound(self._on_outbound) + for task in list(self._background_tasks): + task.cancel() + self._background_tasks.clear() + for task in list(self._running_card_tasks.values()): + task.cancel() + self._running_card_tasks.clear() if self._thread: self._thread.join(timeout=5) self._thread = None @@ -161,24 +176,11 @@ class FeishuChannel(Channel): msg.thread_ts, len(msg.text), ) - content = self._build_card_content(msg.text) last_exc: Exception | None = None for attempt in range(_max_retries): try: - if msg.thread_ts: - # Reply in thread (话题) - request = self._ReplyMessageRequest.builder().message_id(msg.thread_ts).request_body(self._ReplyMessageRequestBody.builder().msg_type("interactive").content(content).reply_in_thread(True).build()).build() - await asyncio.to_thread(self._api_client.im.v1.message.reply, request) - else: - # Send new message - request = self._CreateMessageRequest.builder().receive_id_type("chat_id").request_body(self._CreateMessageRequestBody.builder().receive_id(msg.chat_id).msg_type("interactive").content(content).build()).build() - await asyncio.to_thread(self._api_client.im.v1.message.create, request) - - # Add "DONE" reaction to the original message on final reply - if msg.is_final and msg.thread_ts: - await self._add_reaction(msg.thread_ts, "DONE") - + await self._send_card_message(msg) return # success except Exception as exc: last_exc = exc @@ -271,7 +273,7 @@ class FeishuChannel(Channel): headers, bold/italic, code blocks, lists, and links. """ card = { - "config": {"wide_screen_mode": True}, + "config": {"wide_screen_mode": True, "update_multi": True}, "elements": [{"tag": "markdown", "content": text}], } return json.dumps(card) @@ -289,18 +291,135 @@ class FeishuChannel(Channel): except Exception: logger.exception("[Feishu] failed to add reaction '%s' to message %s", emoji_type, message_id) - async def _send_running_reply(self, message_id: str) -> None: - """Reply to a message in-thread with a 'Working on it...' hint.""" + async def _reply_card(self, message_id: str, text: str) -> str | None: + """Reply with an interactive card and return the created card message ID.""" + if not self._api_client: + return None + + content = self._build_card_content(text) + request = self._ReplyMessageRequest.builder().message_id(message_id).request_body(self._ReplyMessageRequestBody.builder().msg_type("interactive").content(content).reply_in_thread(True).build()).build() + response = await asyncio.to_thread(self._api_client.im.v1.message.reply, request) + response_data = getattr(response, "data", None) + return getattr(response_data, "message_id", None) + + async def _create_card(self, chat_id: str, text: str) -> None: + """Create a new card message in the target chat.""" if not self._api_client: return + + content = self._build_card_content(text) + request = self._CreateMessageRequest.builder().receive_id_type("chat_id").request_body(self._CreateMessageRequestBody.builder().receive_id(chat_id).msg_type("interactive").content(content).build()).build() + await asyncio.to_thread(self._api_client.im.v1.message.create, request) + + async def _update_card(self, message_id: str, text: str) -> None: + """Patch an existing card message in place.""" + if not self._api_client or not self._PatchMessageRequest: + return + + content = self._build_card_content(text) + request = self._PatchMessageRequest.builder().message_id(message_id).request_body(self._PatchMessageRequestBody.builder().content(content).build()).build() + await asyncio.to_thread(self._api_client.im.v1.message.patch, request) + + def _track_background_task(self, task: asyncio.Task, *, name: str, msg_id: str) -> None: + """Keep a strong reference to fire-and-forget tasks and surface errors.""" + self._background_tasks.add(task) + task.add_done_callback(lambda done_task, task_name=name, mid=msg_id: self._finalize_background_task(done_task, task_name, mid)) + + def _finalize_background_task(self, task: asyncio.Task, name: str, msg_id: str) -> None: + self._background_tasks.discard(task) + self._log_task_error(task, name, msg_id) + + async def _create_running_card(self, source_message_id: str, text: str) -> str | None: + """Create the running card and cache its message ID when available.""" + running_card_id = await self._reply_card(source_message_id, text) + if running_card_id: + self._running_card_ids[source_message_id] = running_card_id + logger.info("[Feishu] running card created: source=%s card=%s", source_message_id, running_card_id) + else: + logger.warning("[Feishu] running card creation returned no message_id for source=%s, subsequent updates will fall back to new replies", source_message_id) + return running_card_id + + def _ensure_running_card_started(self, source_message_id: str, text: str = "Working on it...") -> asyncio.Task | None: + """Start running-card creation once per source message.""" + running_card_id = self._running_card_ids.get(source_message_id) + if running_card_id: + return None + + running_card_task = self._running_card_tasks.get(source_message_id) + if running_card_task: + return running_card_task + + running_card_task = asyncio.create_task(self._create_running_card(source_message_id, text)) + self._running_card_tasks[source_message_id] = running_card_task + running_card_task.add_done_callback(lambda done_task, mid=source_message_id: self._finalize_running_card_task(mid, done_task)) + return running_card_task + + def _finalize_running_card_task(self, source_message_id: str, task: asyncio.Task) -> None: + if self._running_card_tasks.get(source_message_id) is task: + self._running_card_tasks.pop(source_message_id, None) + self._log_task_error(task, "create_running_card", source_message_id) + + async def _ensure_running_card(self, source_message_id: str, text: str = "Working on it...") -> str | None: + """Ensure the in-thread running card exists and track its message ID.""" + running_card_id = self._running_card_ids.get(source_message_id) + if running_card_id: + return running_card_id + + running_card_task = self._ensure_running_card_started(source_message_id, text) + if running_card_task is None: + return self._running_card_ids.get(source_message_id) + return await running_card_task + + async def _send_running_reply(self, message_id: str) -> None: + """Reply to a message in-thread with a running card.""" try: - content = self._build_card_content("Working on it...") - request = self._ReplyMessageRequest.builder().message_id(message_id).request_body(self._ReplyMessageRequestBody.builder().msg_type("interactive").content(content).reply_in_thread(True).build()).build() - await asyncio.to_thread(self._api_client.im.v1.message.reply, request) - logger.info("[Feishu] 'Working on it......' reply sent for message %s", message_id) + await self._ensure_running_card(message_id) except Exception: logger.exception("[Feishu] failed to send running reply for message %s", message_id) + async def _send_card_message(self, msg: OutboundMessage) -> None: + """Send or update the Feishu card tied to the current request.""" + source_message_id = msg.thread_ts + if source_message_id: + running_card_id = self._running_card_ids.get(source_message_id) + awaited_running_card_task = False + + if not running_card_id: + running_card_task = self._running_card_tasks.get(source_message_id) + if running_card_task: + awaited_running_card_task = True + running_card_id = await running_card_task + + if running_card_id: + try: + await self._update_card(running_card_id, msg.text) + except Exception: + if not msg.is_final: + raise + logger.exception( + "[Feishu] failed to patch running card %s, falling back to final reply", + running_card_id, + ) + await self._reply_card(source_message_id, msg.text) + else: + logger.info("[Feishu] running card updated: source=%s card=%s", source_message_id, running_card_id) + elif msg.is_final: + await self._reply_card(source_message_id, msg.text) + elif awaited_running_card_task: + logger.warning( + "[Feishu] running card task finished without message_id for source=%s, skipping duplicate non-final creation", + source_message_id, + ) + else: + await self._ensure_running_card(source_message_id, msg.text) + + if msg.is_final: + self._running_card_ids.pop(source_message_id, None) + await self._add_reaction(source_message_id, "DONE") + return + + await self._create_card(msg.chat_id, msg.text) + # -- internal ---------------------------------------------------------- @staticmethod @@ -313,6 +432,25 @@ class FeishuChannel(Channel): except Exception: pass + @staticmethod + def _log_task_error(task: asyncio.Task, name: str, msg_id: str) -> None: + """Callback for background asyncio tasks to surface errors.""" + try: + exc = task.exception() + if exc: + logger.error("[Feishu] %s failed for msg_id=%s: %s", name, msg_id, exc) + except asyncio.CancelledError: + logger.info("[Feishu] %s cancelled for msg_id=%s", name, msg_id) + except Exception: + pass + + async def _prepare_inbound(self, msg_id: str, inbound) -> None: + """Kick off Feishu side effects without delaying inbound dispatch.""" + reaction_task = asyncio.create_task(self._add_reaction(msg_id, "OK")) + self._track_background_task(reaction_task, name="add_reaction", msg_id=msg_id) + self._ensure_running_card_started(msg_id) + await self.bus.publish_inbound(inbound) + def _on_message(self, event) -> None: """Called by lark-oapi when a message is received (runs in lark thread).""" try: @@ -364,14 +502,8 @@ class FeishuChannel(Channel): # Schedule on the async event loop if self._main_loop and self._main_loop.is_running(): logger.info("[Feishu] publishing inbound message to bus (type=%s, msg_id=%s)", msg_type.value, msg_id) - # Schedule all coroutines and attach error logging to futures - for name, coro in [ - ("add_reaction", self._add_reaction(msg_id, "OK")), - ("send_running_reply", self._send_running_reply(msg_id)), - ("publish_inbound", self.bus.publish_inbound(inbound)), - ]: - fut = asyncio.run_coroutine_threadsafe(coro, self._main_loop) - fut.add_done_callback(lambda f, n=name, mid=msg_id: self._log_future_error(f, n, mid)) + fut = asyncio.run_coroutine_threadsafe(self._prepare_inbound(msg_id, inbound), self._main_loop) + fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "prepare_inbound", mid)) else: logger.warning("[Feishu] main loop not running, cannot publish inbound message") except Exception: diff --git a/backend/src/channels/manager.py b/backend/src/channels/manager.py index a7e3840..614a091 100644 --- a/backend/src/channels/manager.py +++ b/backend/src/channels/manager.py @@ -5,6 +5,7 @@ from __future__ import annotations import asyncio import logging import mimetypes +import time from collections.abc import Mapping from typing import Any @@ -23,6 +24,7 @@ DEFAULT_RUN_CONTEXT: dict[str, Any] = { "is_plan_mode": False, "subagent_enabled": False, } +STREAM_UPDATE_MIN_INTERVAL_SECONDS = 0.35 def _as_dict(value: Any) -> dict[str, Any]: @@ -92,6 +94,98 @@ def _extract_response_text(result: dict | list) -> str: return "" +def _extract_text_content(content: Any) -> str: + """Extract text from a streaming payload content field.""" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for block in content: + if isinstance(block, str): + parts.append(block) + elif isinstance(block, Mapping): + text = block.get("text") + if isinstance(text, str): + parts.append(text) + else: + nested = block.get("content") + if isinstance(nested, str): + parts.append(nested) + return "".join(parts) + if isinstance(content, Mapping): + for key in ("text", "content"): + value = content.get(key) + if isinstance(value, str): + return value + return "" + + +def _merge_stream_text(existing: str, chunk: str) -> str: + """Merge either delta text or cumulative text into a single snapshot.""" + if not chunk: + return existing + if not existing or chunk == existing: + return chunk or existing + if chunk.startswith(existing): + return chunk + if existing.endswith(chunk): + return existing + return existing + chunk + + +def _extract_stream_message_id(payload: Any, metadata: Any) -> str | None: + """Best-effort extraction of the streamed AI message identifier.""" + candidates = [payload, metadata] + if isinstance(payload, Mapping): + candidates.append(payload.get("kwargs")) + + for candidate in candidates: + if not isinstance(candidate, Mapping): + continue + for key in ("id", "message_id"): + value = candidate.get(key) + if isinstance(value, str) and value: + return value + return None + + +def _accumulate_stream_text( + buffers: dict[str, str], + current_message_id: str | None, + event_data: Any, +) -> tuple[str | None, str | None]: + """Convert a ``messages-tuple`` event into the latest displayable AI text.""" + payload = event_data + metadata: Any = None + if isinstance(event_data, (list, tuple)): + if event_data: + payload = event_data[0] + if len(event_data) > 1: + metadata = event_data[1] + + if isinstance(payload, str): + message_id = current_message_id or "__default__" + buffers[message_id] = _merge_stream_text(buffers.get(message_id, ""), payload) + return buffers[message_id], message_id + + if not isinstance(payload, Mapping): + return None, current_message_id + + payload_type = str(payload.get("type", "")).lower() + if "tool" in payload_type: + return None, current_message_id + + text = _extract_text_content(payload.get("content")) + if not text and isinstance(payload.get("kwargs"), Mapping): + text = _extract_text_content(payload["kwargs"].get("content")) + if not text: + return None, current_message_id + + message_id = _extract_stream_message_id(payload, metadata) or current_message_id or "__default__" + buffers[message_id] = _merge_stream_text(buffers.get(message_id, ""), text) + return buffers[message_id], message_id + + def _extract_artifacts(result: dict | list) -> list[str]: """Extract artifact paths from the last AI response cycle only. @@ -185,6 +279,33 @@ def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedA return attachments +def _prepare_artifact_delivery( + thread_id: str, + response_text: str, + artifacts: list[str], +) -> tuple[str, list[ResolvedAttachment]]: + """Resolve attachments and append filename fallbacks to the text response.""" + attachments: list[ResolvedAttachment] = [] + if not artifacts: + return response_text, attachments + + attachments = _resolve_attachments(thread_id, artifacts) + resolved_virtuals = {attachment.virtual_path for attachment in attachments} + unresolved = [path for path in artifacts if path not in resolved_virtuals] + + if unresolved: + artifact_text = _format_artifact_text(unresolved) + response_text = (response_text + "\n\n" + artifact_text) if response_text else artifact_text + + # Always include resolved attachment filenames as a text fallback so files + # remain discoverable even when the upload is skipped or fails. + if attachments: + resolved_text = _format_artifact_text([attachment.virtual_path for attachment in attachments]) + response_text = (response_text + "\n\n" + resolved_text) if response_text else resolved_text + + return response_text, attachments + + class ChannelManager: """Core dispatcher that bridges IM channels to the DeerFlow agent. @@ -363,6 +484,17 @@ class ChannelManager: thread_id = await self._create_thread(client, msg) assistant_id, run_config, run_context = self._resolve_run_params(msg, thread_id) + if msg.channel_name == "feishu": + await self._handle_streaming_chat( + client, + msg, + thread_id, + assistant_id, + run_config, + run_context, + ) + return + logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100]) result = await client.runs.wait( thread_id, @@ -382,20 +514,7 @@ class ChannelManager: len(artifacts), ) - # Resolve artifact virtual paths to actual files for channel upload - attachments: list[ResolvedAttachment] = [] - if artifacts: - attachments = _resolve_attachments(thread_id, artifacts) - resolved_virtuals = {a.virtual_path for a in attachments} - unresolved = [p for p in artifacts if p not in resolved_virtuals] - if unresolved: - artifact_text = _format_artifact_text(unresolved) - response_text = (response_text + "\n\n" + artifact_text) if response_text else artifact_text - # Always include resolved attachment filenames as a text fallback so - # files remain discoverable even when the upload is skipped or fails. - if attachments: - resolved_text = _format_artifact_text([a.virtual_path for a in attachments]) - response_text = (response_text + "\n\n" + resolved_text) if response_text else resolved_text + response_text, attachments = _prepare_artifact_delivery(thread_id, response_text, artifacts) if not response_text: if attachments: @@ -415,6 +534,103 @@ class ChannelManager: logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id) await self.bus.publish_outbound(outbound) + async def _handle_streaming_chat( + self, + client, + msg: InboundMessage, + thread_id: str, + assistant_id: str, + run_config: dict[str, Any], + run_context: dict[str, Any], + ) -> None: + logger.info("[Manager] invoking runs.stream(thread_id=%s, text=%r)", thread_id, msg.text[:100]) + + last_values: dict[str, Any] | list | None = None + streamed_buffers: dict[str, str] = {} + current_message_id: str | None = None + latest_text = "" + last_published_text = "" + last_publish_at = 0.0 + stream_error: BaseException | None = None + + try: + async for chunk in client.runs.stream( + thread_id, + assistant_id, + input={"messages": [{"role": "human", "content": msg.text}]}, + config=run_config, + context=run_context, + stream_mode=["messages-tuple", "values"], + ): + event = getattr(chunk, "event", "") + data = getattr(chunk, "data", None) + + if event == "messages-tuple": + accumulated_text, current_message_id = _accumulate_stream_text(streamed_buffers, current_message_id, data) + if accumulated_text: + latest_text = accumulated_text + elif event == "values" and isinstance(data, (dict, list)): + last_values = data + snapshot_text = _extract_response_text(data) + if snapshot_text: + latest_text = snapshot_text + + if not latest_text or latest_text == last_published_text: + continue + + now = time.monotonic() + if last_published_text and now - last_publish_at < STREAM_UPDATE_MIN_INTERVAL_SECONDS: + continue + + await self.bus.publish_outbound( + OutboundMessage( + channel_name=msg.channel_name, + chat_id=msg.chat_id, + thread_id=thread_id, + text=latest_text, + is_final=False, + thread_ts=msg.thread_ts, + ) + ) + last_published_text = latest_text + last_publish_at = now + except Exception as exc: + stream_error = exc + logger.exception("[Manager] streaming error: thread_id=%s", thread_id) + finally: + result = last_values if last_values is not None else {"messages": [{"type": "ai", "content": latest_text}]} + response_text = _extract_response_text(result) + artifacts = _extract_artifacts(result) + response_text, attachments = _prepare_artifact_delivery(thread_id, response_text, artifacts) + + if not response_text: + if attachments: + response_text = _format_artifact_text([attachment.virtual_path for attachment in attachments]) + elif stream_error: + response_text = "An error occurred while processing your request. Please try again." + else: + response_text = latest_text or "(No response from agent)" + + logger.info( + "[Manager] streaming response completed: thread_id=%s, response_len=%d, artifacts=%d, error=%s", + thread_id, + len(response_text), + len(artifacts), + stream_error, + ) + await self.bus.publish_outbound( + OutboundMessage( + channel_name=msg.channel_name, + chat_id=msg.chat_id, + thread_id=thread_id, + text=response_text, + artifacts=artifacts, + attachments=attachments, + is_final=True, + thread_ts=msg.thread_ts, + ) + ) + # -- command handling -------------------------------------------------- async def _handle_command(self, msg: InboundMessage) -> None: diff --git a/backend/src/channels/service.py b/backend/src/channels/service.py index 72fa3ec..1ff1de6 100644 --- a/backend/src/channels/service.py +++ b/backend/src/channels/service.py @@ -33,11 +33,7 @@ class ChannelService: langgraph_url = config.pop("langgraph_url", None) or "http://localhost:2024" gateway_url = config.pop("gateway_url", None) or "http://localhost:8001" default_session = config.pop("session", None) - channel_sessions = { - name: channel_config.get("session") - for name, channel_config in config.items() - if isinstance(channel_config, dict) - } + channel_sessions = {name: channel_config.get("session") for name, channel_config in config.items() if isinstance(channel_config, dict)} self.manager = ChannelManager( bus=self.bus, store=self.store, diff --git a/backend/tests/test_channels.py b/backend/tests/test_channels.py index ca37fe8..b985932 100644 --- a/backend/tests/test_channels.py +++ b/backend/tests/test_channels.py @@ -3,8 +3,10 @@ from __future__ import annotations import asyncio +import json import tempfile from pathlib import Path +from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock import pytest @@ -399,6 +401,18 @@ def _make_mock_langgraph_client(thread_id="test-thread-123", run_result=None): return mock_client +def _make_stream_part(event: str, data): + return SimpleNamespace(event=event, data=data) + + +def _make_async_iterator(items): + async def iterator(): + for item in items: + yield item + + return iterator() + + class TestChannelManager: def test_handle_chat_creates_thread(self): from src.channels.manager import ChannelManager @@ -550,6 +564,126 @@ class TestChannelManager: _run(go()) + def test_handle_feishu_chat_streams_multiple_outbound_updates(self, monkeypatch): + from src.channels.manager import ChannelManager + + monkeypatch.setattr("src.channels.manager.STREAM_UPDATE_MIN_INTERVAL_SECONDS", 0.0) + + async def go(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + manager = ChannelManager(bus=bus, store=store) + + outbound_received = [] + + async def capture_outbound(msg): + outbound_received.append(msg) + + bus.subscribe_outbound(capture_outbound) + + stream_events = [ + _make_stream_part( + "messages-tuple", + [ + {"id": "ai-1", "content": "Hello", "type": "AIMessageChunk"}, + {"langgraph_node": "agent"}, + ], + ), + _make_stream_part( + "messages-tuple", + [ + {"id": "ai-1", "content": " world", "type": "AIMessageChunk"}, + {"langgraph_node": "agent"}, + ], + ), + _make_stream_part( + "values", + { + "messages": [ + {"type": "human", "content": "hi"}, + {"type": "ai", "content": "Hello world"}, + ], + "artifacts": [], + }, + ), + ] + + mock_client = _make_mock_langgraph_client() + mock_client.runs.stream = MagicMock(return_value=_make_async_iterator(stream_events)) + manager._client = mock_client + + await manager.start() + + inbound = InboundMessage( + channel_name="feishu", + chat_id="chat1", + user_id="user1", + text="hi", + thread_ts="om-source-1", + ) + await bus.publish_inbound(inbound) + await _wait_for(lambda: len(outbound_received) >= 3) + await manager.stop() + + mock_client.runs.stream.assert_called_once() + assert [msg.text for msg in outbound_received] == ["Hello", "Hello world", "Hello world"] + assert [msg.is_final for msg in outbound_received] == [False, False, True] + assert all(msg.thread_ts == "om-source-1" for msg in outbound_received) + + _run(go()) + + def test_handle_feishu_stream_error_still_sends_final(self, monkeypatch): + """When the stream raises mid-way, a final outbound with is_final=True must still be published.""" + from src.channels.manager import ChannelManager + + monkeypatch.setattr("src.channels.manager.STREAM_UPDATE_MIN_INTERVAL_SECONDS", 0.0) + + async def go(): + bus = MessageBus() + store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") + manager = ChannelManager(bus=bus, store=store) + + outbound_received = [] + + async def capture_outbound(msg): + outbound_received.append(msg) + + bus.subscribe_outbound(capture_outbound) + + async def _failing_stream(): + yield _make_stream_part( + "messages-tuple", + [ + {"id": "ai-1", "content": "Partial", "type": "AIMessageChunk"}, + {"langgraph_node": "agent"}, + ], + ) + raise ConnectionError("stream broken") + + mock_client = _make_mock_langgraph_client() + mock_client.runs.stream = MagicMock(return_value=_failing_stream()) + manager._client = mock_client + + await manager.start() + + inbound = InboundMessage( + channel_name="feishu", + chat_id="chat1", + user_id="user1", + text="hi", + thread_ts="om-source-1", + ) + await bus.publish_inbound(inbound) + await _wait_for(lambda: any(m.is_final for m in outbound_received)) + await manager.stop() + + # Should have at least one intermediate and one final message + final_msgs = [m for m in outbound_received if m.is_final] + assert len(final_msgs) == 1 + assert final_msgs[0].thread_ts == "om-source-1" + + _run(go()) + def test_handle_command_help(self): from src.channels.manager import ChannelManager @@ -1092,6 +1226,180 @@ class TestHandleChatWithArtifacts: _run(go()) +class TestFeishuChannel: + def test_prepare_inbound_publishes_without_waiting_for_running_card(self): + from src.channels.feishu import FeishuChannel + + async def go(): + bus = MessageBus() + bus.publish_inbound = AsyncMock() + channel = FeishuChannel(bus, config={}) + + reply_started = asyncio.Event() + release_reply = asyncio.Event() + + async def slow_reply(message_id: str, text: str) -> str: + reply_started.set() + await release_reply.wait() + return "om-running-card" + + channel._add_reaction = AsyncMock() + channel._reply_card = AsyncMock(side_effect=slow_reply) + + inbound = InboundMessage( + channel_name="feishu", + chat_id="chat-1", + user_id="user-1", + text="hello", + thread_ts="om-source-msg", + ) + + prepare_task = asyncio.create_task(channel._prepare_inbound("om-source-msg", inbound)) + + await _wait_for(lambda: bus.publish_inbound.await_count == 1) + await prepare_task + + assert reply_started.is_set() + assert "om-source-msg" in channel._running_card_tasks + assert channel._reply_card.await_count == 1 + + release_reply.set() + await _wait_for(lambda: channel._running_card_ids.get("om-source-msg") == "om-running-card") + await _wait_for(lambda: "om-source-msg" not in channel._running_card_tasks) + + _run(go()) + + def test_prepare_inbound_and_send_share_running_card_task(self): + from src.channels.feishu import FeishuChannel + + async def go(): + bus = MessageBus() + bus.publish_inbound = AsyncMock() + channel = FeishuChannel(bus, config={}) + channel._api_client = MagicMock() + + reply_started = asyncio.Event() + release_reply = asyncio.Event() + + async def slow_reply(message_id: str, text: str) -> str: + reply_started.set() + await release_reply.wait() + return "om-running-card" + + channel._add_reaction = AsyncMock() + channel._reply_card = AsyncMock(side_effect=slow_reply) + channel._update_card = AsyncMock() + + inbound = InboundMessage( + channel_name="feishu", + chat_id="chat-1", + user_id="user-1", + text="hello", + thread_ts="om-source-msg", + ) + + prepare_task = asyncio.create_task(channel._prepare_inbound("om-source-msg", inbound)) + await _wait_for(lambda: bus.publish_inbound.await_count == 1) + await _wait_for(reply_started.is_set) + + send_task = asyncio.create_task( + channel.send( + OutboundMessage( + channel_name="feishu", + chat_id="chat-1", + thread_id="thread-1", + text="Hello", + is_final=False, + thread_ts="om-source-msg", + ) + ) + ) + + await asyncio.sleep(0) + assert channel._reply_card.await_count == 1 + + release_reply.set() + await prepare_task + await send_task + + assert channel._reply_card.await_count == 1 + channel._update_card.assert_awaited_once_with("om-running-card", "Hello") + assert "om-source-msg" not in channel._running_card_tasks + + _run(go()) + + def test_streaming_reuses_single_running_card(self): + from lark_oapi.api.im.v1 import ( + CreateMessageReactionRequest, + CreateMessageReactionRequestBody, + Emoji, + PatchMessageRequest, + PatchMessageRequestBody, + ReplyMessageRequest, + ReplyMessageRequestBody, + ) + + from src.channels.feishu import FeishuChannel + + async def go(): + bus = MessageBus() + channel = FeishuChannel(bus, config={}) + + channel._api_client = MagicMock() + channel._ReplyMessageRequest = ReplyMessageRequest + channel._ReplyMessageRequestBody = ReplyMessageRequestBody + channel._PatchMessageRequest = PatchMessageRequest + channel._PatchMessageRequestBody = PatchMessageRequestBody + channel._CreateMessageReactionRequest = CreateMessageReactionRequest + channel._CreateMessageReactionRequestBody = CreateMessageReactionRequestBody + channel._Emoji = Emoji + + reply_response = MagicMock() + reply_response.data.message_id = "om-running-card" + channel._api_client.im.v1.message.reply = MagicMock(return_value=reply_response) + channel._api_client.im.v1.message.patch = MagicMock() + channel._api_client.im.v1.message_reaction.create = MagicMock() + + await channel._send_running_reply("om-source-msg") + + await channel.send( + OutboundMessage( + channel_name="feishu", + chat_id="chat-1", + thread_id="thread-1", + text="Hello", + is_final=False, + thread_ts="om-source-msg", + ) + ) + await channel.send( + OutboundMessage( + channel_name="feishu", + chat_id="chat-1", + thread_id="thread-1", + text="Hello world", + is_final=True, + thread_ts="om-source-msg", + ) + ) + + assert channel._api_client.im.v1.message.reply.call_count == 1 + assert channel._api_client.im.v1.message.patch.call_count == 2 + assert channel._api_client.im.v1.message_reaction.create.call_count == 1 + assert "om-source-msg" not in channel._running_card_ids + assert "om-source-msg" not in channel._running_card_tasks + + first_patch_request = channel._api_client.im.v1.message.patch.call_args_list[0].args[0] + final_patch_request = channel._api_client.im.v1.message.patch.call_args_list[1].args[0] + assert first_patch_request.message_id == "om-running-card" + assert final_patch_request.message_id == "om-running-card" + assert json.loads(first_patch_request.body.content)["elements"][0]["content"] == "Hello" + assert json.loads(final_patch_request.body.content)["elements"][0]["content"] == "Hello world" + assert json.loads(final_patch_request.body.content)["config"]["update_multi"] is True + + _run(go()) + + class TestChannelService: def test_get_status_no_channels(self): from src.channels.service import ChannelService