From 96bace7ab6d68d148ebb58b85918edcd57790ee9 Mon Sep 17 00:00:00 2001 From: hetaoBackend Date: Sun, 8 Feb 2026 21:25:54 +0800 Subject: [PATCH] feat: add real-time streaming of subagent AI messages Enable task tool to capture and stream AI messages as they are generated by subagents. This replaces simple polling status updates with detailed message-level progress updates. Co-Authored-By: Claude Sonnet 4.5 --- backend/src/subagents/executor.py | 138 +++++++++++++++--------- backend/src/tools/builtins/task_tool.py | 20 +++- 2 files changed, 107 insertions(+), 51 deletions(-) diff --git a/backend/src/subagents/executor.py b/backend/src/subagents/executor.py index b58aa9a..e8532b4 100644 --- a/backend/src/subagents/executor.py +++ b/backend/src/subagents/executor.py @@ -44,6 +44,7 @@ class SubagentResult: error: Error message (if failed). started_at: When execution started. completed_at: When execution completed. + ai_messages: List of complete AI messages (as dicts) generated during execution. """ task_id: str @@ -53,6 +54,12 @@ class SubagentResult: error: str | None = None started_at: datetime | None = None completed_at: datetime | None = None + ai_messages: list[dict[str, Any]] | None = None + + def __post_init__(self): + """Initialize mutable defaults.""" + if self.ai_messages is None: + self.ai_messages = [] # Global storage for background task results @@ -197,22 +204,28 @@ class SubagentExecutor: return state - def execute(self, task: str) -> SubagentResult: + def execute(self, task: str, result_holder: SubagentResult | None = None) -> SubagentResult: """Execute a task synchronously. Args: task: The task description for the subagent. + result_holder: Optional pre-created result object to update during execution. Returns: SubagentResult with the execution result. """ - task_id = str(uuid.uuid4())[:8] - result = SubagentResult( - task_id=task_id, - trace_id=self.trace_id, - status=SubagentStatus.RUNNING, - started_at=datetime.now(), - ) + if result_holder is not None: + # Use the provided result holder (for async execution with real-time updates) + result = result_holder + else: + # Create a new result for synchronous execution + task_id = str(uuid.uuid4())[:8] + result = SubagentResult( + task_id=task_id, + trace_id=self.trace_id, + status=SubagentStatus.RUNNING, + started_at=datetime.now(), + ) try: agent = self._create_agent() @@ -229,50 +242,74 @@ class SubagentExecutor: logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} starting execution with max_turns={self.config.max_turns}") - # Run the agent using invoke for complete result - # Note: invoke() runs until completion or interruption - # Timeout is handled at the execute_async level, not here - final_state = agent.invoke(state, config=run_config, context=context) # type: ignore[arg-type] + # Use stream instead of invoke to get real-time updates + # This allows us to collect AI messages as they are generated + final_state = None + for chunk in agent.stream(state, config=run_config, context=context, stream_mode="values"): # type: ignore[arg-type] + final_state = chunk + + # Extract AI messages from the current state + messages = chunk.get("messages", []) + if messages: + last_message = messages[-1] + # Check if this is a new AI message + if isinstance(last_message, AIMessage): + # Convert message to dict for serialization + message_dict = last_message.model_dump() + # Only add if it's not already in the list (avoid duplicates) + # Check by comparing message IDs if available, otherwise compare full dict + message_id = message_dict.get("id") + is_duplicate = False + if message_id: + is_duplicate = any(msg.get("id") == message_id for msg in result.ai_messages) + else: + is_duplicate = message_dict in result.ai_messages + + if not is_duplicate: + result.ai_messages.append(message_dict) + logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(result.ai_messages)}") logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} completed execution") - # Extract the final message - find the last AIMessage - messages = final_state.get("messages", []) - logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} final messages count: {len(messages)}") - - # Find the last AIMessage in the conversation - last_ai_message = None - for msg in reversed(messages): - if isinstance(msg, AIMessage): - last_ai_message = msg - break - - if last_ai_message is not None: - content = last_ai_message.content - logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} last AI message content type: {type(content)}") - - # Handle both str and list content types - if isinstance(content, str): - result.result = content - elif isinstance(content, list): - # Extract text from list of content blocks - text_parts = [] - for block in content: - if isinstance(block, str): - text_parts.append(block) - elif isinstance(block, dict) and "text" in block: - text_parts.append(block["text"]) - result.result = "\n".join(text_parts) if text_parts else "No text content in response" - else: - result.result = str(content) - elif messages: - # Fallback: use the last message if no AIMessage found - last_message = messages[-1] - logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no AIMessage found, using last message: {type(last_message)}") - result.result = str(last_message.content) if hasattr(last_message, "content") else str(last_message) - else: - logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no messages in final state") + if final_state is None: + logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no final state") result.result = "No response generated" + else: + # Extract the final message - find the last AIMessage + messages = final_state.get("messages", []) + logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} final messages count: {len(messages)}") + + # Find the last AIMessage in the conversation + last_ai_message = None + for msg in reversed(messages): + if isinstance(msg, AIMessage): + last_ai_message = msg + break + + if last_ai_message is not None: + content = last_ai_message.content + # Handle both str and list content types for the final result + if isinstance(content, str): + result.result = content + elif isinstance(content, list): + # Extract text from list of content blocks for final result only + text_parts = [] + for block in content: + if isinstance(block, str): + text_parts.append(block) + elif isinstance(block, dict) and "text" in block: + text_parts.append(block["text"]) + result.result = "\n".join(text_parts) if text_parts else "No text content in response" + else: + result.result = str(content) + elif messages: + # Fallback: use the last message if no AIMessage found + last_message = messages[-1] + logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no AIMessage found, using last message: {type(last_message)}") + result.result = str(last_message.content) if hasattr(last_message, "content") else str(last_message) + else: + logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no messages in final state") + result.result = "No response generated" result.status = SubagentStatus.COMPLETED result.completed_at = datetime.now() @@ -316,10 +353,12 @@ class SubagentExecutor: with _background_tasks_lock: _background_tasks[task_id].status = SubagentStatus.RUNNING _background_tasks[task_id].started_at = datetime.now() + result_holder = _background_tasks[task_id] try: # Submit execution to execution pool with timeout - execution_future: Future = _execution_pool.submit(self.execute, task) + # Pass result_holder so execute() can update it in real-time + execution_future: Future = _execution_pool.submit(self.execute, task, result_holder) try: # Wait for execution with timeout exec_result = execution_future.result(timeout=self.config.timeout_seconds) @@ -328,6 +367,7 @@ class SubagentExecutor: _background_tasks[task_id].result = exec_result.result _background_tasks[task_id].error = exec_result.error _background_tasks[task_id].completed_at = datetime.now() + _background_tasks[task_id].ai_messages = exec_result.ai_messages except FuturesTimeoutError: logger.error( f"[trace={self.trace_id}] Subagent {self.config.name} execution timed out after {self.config.timeout_seconds}s" diff --git a/backend/src/tools/builtins/task_tool.py b/backend/src/tools/builtins/task_tool.py index 36bf2ae..32560ea 100644 --- a/backend/src/tools/builtins/task_tool.py +++ b/backend/src/tools/builtins/task_tool.py @@ -6,8 +6,8 @@ import uuid from typing import Annotated, Literal from langchain.tools import InjectedToolCallId, ToolRuntime, tool -from langgraph.typing import ContextT from langgraph.config import get_stream_writer +from langgraph.typing import ContextT from src.agents.thread_state import ThreadState from src.subagents import SubagentExecutor, get_subagent_config @@ -112,6 +112,7 @@ def task_tool( # Poll for task completion in backend (removes need for LLM to poll) poll_count = 0 last_status = None + last_message_count = 0 # Track how many AI messages we've already sent writer = get_stream_writer() # Send Task Started message' @@ -131,6 +132,22 @@ def task_tool( logger.info(f"[trace={trace_id}] Task {task_id} status: {result.status.value}") last_status = result.status + # Check for new AI messages and send task_running events + current_message_count = len(result.ai_messages) + if current_message_count > last_message_count: + # Send task_running event for each new message + for i in range(last_message_count, current_message_count): + message = result.ai_messages[i] + writer({ + "type": "task_running", + "task_id": task_id, + "message": message, + "message_index": i + 1, # 1-based index for display + "total_messages": current_message_count + }) + logger.info(f"[trace={trace_id}] Task {task_id} sent message #{i + 1}/{current_message_count}") + last_message_count = current_message_count + # Check if task completed, failed, or timed out if result.status == SubagentStatus.COMPLETED: writer({"type": "task_completed", "task_id": task_id, "result": result.result}) @@ -146,7 +163,6 @@ def task_tool( return f"Task timed out. Error: {result.error}" # Still running, wait before next poll - writer({"type": "task_running", "task_id": task_id, "poll_count": poll_count}) time.sleep(5) # Poll every 5 seconds poll_count += 1