feat: add real-time streaming of subagent AI messages

Enable task tool to capture and stream AI messages as they are generated by subagents. This replaces simple polling status updates with detailed message-level progress updates.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
hetaoBackend
2026-02-08 21:25:54 +08:00
parent 2b3dc96e40
commit 0a27a7561a
2 changed files with 107 additions and 51 deletions

View File

@@ -44,6 +44,7 @@ class SubagentResult:
error: Error message (if failed). error: Error message (if failed).
started_at: When execution started. started_at: When execution started.
completed_at: When execution completed. completed_at: When execution completed.
ai_messages: List of complete AI messages (as dicts) generated during execution.
""" """
task_id: str task_id: str
@@ -53,6 +54,12 @@ class SubagentResult:
error: str | None = None error: str | None = None
started_at: datetime | None = None started_at: datetime | None = None
completed_at: datetime | None = None completed_at: datetime | None = None
ai_messages: list[dict[str, Any]] | None = None
def __post_init__(self):
"""Initialize mutable defaults."""
if self.ai_messages is None:
self.ai_messages = []
# Global storage for background task results # Global storage for background task results
@@ -197,22 +204,28 @@ class SubagentExecutor:
return state return state
def execute(self, task: str) -> SubagentResult: def execute(self, task: str, result_holder: SubagentResult | None = None) -> SubagentResult:
"""Execute a task synchronously. """Execute a task synchronously.
Args: Args:
task: The task description for the subagent. task: The task description for the subagent.
result_holder: Optional pre-created result object to update during execution.
Returns: Returns:
SubagentResult with the execution result. SubagentResult with the execution result.
""" """
task_id = str(uuid.uuid4())[:8] if result_holder is not None:
result = SubagentResult( # Use the provided result holder (for async execution with real-time updates)
task_id=task_id, result = result_holder
trace_id=self.trace_id, else:
status=SubagentStatus.RUNNING, # Create a new result for synchronous execution
started_at=datetime.now(), task_id = str(uuid.uuid4())[:8]
) result = SubagentResult(
task_id=task_id,
trace_id=self.trace_id,
status=SubagentStatus.RUNNING,
started_at=datetime.now(),
)
try: try:
agent = self._create_agent() agent = self._create_agent()
@@ -229,50 +242,74 @@ class SubagentExecutor:
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} starting execution with max_turns={self.config.max_turns}") logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} starting execution with max_turns={self.config.max_turns}")
# Run the agent using invoke for complete result # Use stream instead of invoke to get real-time updates
# Note: invoke() runs until completion or interruption # This allows us to collect AI messages as they are generated
# Timeout is handled at the execute_async level, not here final_state = None
final_state = agent.invoke(state, config=run_config, context=context) # type: ignore[arg-type] for chunk in agent.stream(state, config=run_config, context=context, stream_mode="values"): # type: ignore[arg-type]
final_state = chunk
# Extract AI messages from the current state
messages = chunk.get("messages", [])
if messages:
last_message = messages[-1]
# Check if this is a new AI message
if isinstance(last_message, AIMessage):
# Convert message to dict for serialization
message_dict = last_message.model_dump()
# Only add if it's not already in the list (avoid duplicates)
# Check by comparing message IDs if available, otherwise compare full dict
message_id = message_dict.get("id")
is_duplicate = False
if message_id:
is_duplicate = any(msg.get("id") == message_id for msg in result.ai_messages)
else:
is_duplicate = message_dict in result.ai_messages
if not is_duplicate:
result.ai_messages.append(message_dict)
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(result.ai_messages)}")
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} completed execution") logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} completed execution")
# Extract the final message - find the last AIMessage if final_state is None:
messages = final_state.get("messages", []) logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no final state")
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} final messages count: {len(messages)}")
# Find the last AIMessage in the conversation
last_ai_message = None
for msg in reversed(messages):
if isinstance(msg, AIMessage):
last_ai_message = msg
break
if last_ai_message is not None:
content = last_ai_message.content
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} last AI message content type: {type(content)}")
# Handle both str and list content types
if isinstance(content, str):
result.result = content
elif isinstance(content, list):
# Extract text from list of content blocks
text_parts = []
for block in content:
if isinstance(block, str):
text_parts.append(block)
elif isinstance(block, dict) and "text" in block:
text_parts.append(block["text"])
result.result = "\n".join(text_parts) if text_parts else "No text content in response"
else:
result.result = str(content)
elif messages:
# Fallback: use the last message if no AIMessage found
last_message = messages[-1]
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no AIMessage found, using last message: {type(last_message)}")
result.result = str(last_message.content) if hasattr(last_message, "content") else str(last_message)
else:
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no messages in final state")
result.result = "No response generated" result.result = "No response generated"
else:
# Extract the final message - find the last AIMessage
messages = final_state.get("messages", [])
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} final messages count: {len(messages)}")
# Find the last AIMessage in the conversation
last_ai_message = None
for msg in reversed(messages):
if isinstance(msg, AIMessage):
last_ai_message = msg
break
if last_ai_message is not None:
content = last_ai_message.content
# Handle both str and list content types for the final result
if isinstance(content, str):
result.result = content
elif isinstance(content, list):
# Extract text from list of content blocks for final result only
text_parts = []
for block in content:
if isinstance(block, str):
text_parts.append(block)
elif isinstance(block, dict) and "text" in block:
text_parts.append(block["text"])
result.result = "\n".join(text_parts) if text_parts else "No text content in response"
else:
result.result = str(content)
elif messages:
# Fallback: use the last message if no AIMessage found
last_message = messages[-1]
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no AIMessage found, using last message: {type(last_message)}")
result.result = str(last_message.content) if hasattr(last_message, "content") else str(last_message)
else:
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no messages in final state")
result.result = "No response generated"
result.status = SubagentStatus.COMPLETED result.status = SubagentStatus.COMPLETED
result.completed_at = datetime.now() result.completed_at = datetime.now()
@@ -316,10 +353,12 @@ class SubagentExecutor:
with _background_tasks_lock: with _background_tasks_lock:
_background_tasks[task_id].status = SubagentStatus.RUNNING _background_tasks[task_id].status = SubagentStatus.RUNNING
_background_tasks[task_id].started_at = datetime.now() _background_tasks[task_id].started_at = datetime.now()
result_holder = _background_tasks[task_id]
try: try:
# Submit execution to execution pool with timeout # Submit execution to execution pool with timeout
execution_future: Future = _execution_pool.submit(self.execute, task) # Pass result_holder so execute() can update it in real-time
execution_future: Future = _execution_pool.submit(self.execute, task, result_holder)
try: try:
# Wait for execution with timeout # Wait for execution with timeout
exec_result = execution_future.result(timeout=self.config.timeout_seconds) exec_result = execution_future.result(timeout=self.config.timeout_seconds)
@@ -328,6 +367,7 @@ class SubagentExecutor:
_background_tasks[task_id].result = exec_result.result _background_tasks[task_id].result = exec_result.result
_background_tasks[task_id].error = exec_result.error _background_tasks[task_id].error = exec_result.error
_background_tasks[task_id].completed_at = datetime.now() _background_tasks[task_id].completed_at = datetime.now()
_background_tasks[task_id].ai_messages = exec_result.ai_messages
except FuturesTimeoutError: except FuturesTimeoutError:
logger.error( logger.error(
f"[trace={self.trace_id}] Subagent {self.config.name} execution timed out after {self.config.timeout_seconds}s" f"[trace={self.trace_id}] Subagent {self.config.name} execution timed out after {self.config.timeout_seconds}s"

View File

@@ -6,8 +6,8 @@ import uuid
from typing import Annotated, Literal from typing import Annotated, Literal
from langchain.tools import InjectedToolCallId, ToolRuntime, tool from langchain.tools import InjectedToolCallId, ToolRuntime, tool
from langgraph.typing import ContextT
from langgraph.config import get_stream_writer from langgraph.config import get_stream_writer
from langgraph.typing import ContextT
from src.agents.thread_state import ThreadState from src.agents.thread_state import ThreadState
from src.subagents import SubagentExecutor, get_subagent_config from src.subagents import SubagentExecutor, get_subagent_config
@@ -112,6 +112,7 @@ def task_tool(
# Poll for task completion in backend (removes need for LLM to poll) # Poll for task completion in backend (removes need for LLM to poll)
poll_count = 0 poll_count = 0
last_status = None last_status = None
last_message_count = 0 # Track how many AI messages we've already sent
writer = get_stream_writer() writer = get_stream_writer()
# Send Task Started message # Send Task Started message
@@ -131,6 +132,22 @@ def task_tool(
logger.info(f"[trace={trace_id}] Task {task_id} status: {result.status.value}") logger.info(f"[trace={trace_id}] Task {task_id} status: {result.status.value}")
last_status = result.status last_status = result.status
# Check for new AI messages and send task_running events
current_message_count = len(result.ai_messages)
if current_message_count > last_message_count:
# Send task_running event for each new message
for i in range(last_message_count, current_message_count):
message = result.ai_messages[i]
writer({
"type": "task_running",
"task_id": task_id,
"message": message,
"message_index": i + 1, # 1-based index for display
"total_messages": current_message_count
})
logger.info(f"[trace={trace_id}] Task {task_id} sent message #{i + 1}/{current_message_count}")
last_message_count = current_message_count
# Check if task completed, failed, or timed out # Check if task completed, failed, or timed out
if result.status == SubagentStatus.COMPLETED: if result.status == SubagentStatus.COMPLETED:
writer({"type": "task_completed", "task_id": task_id, "result": result.result}) writer({"type": "task_completed", "task_id": task_id, "result": result.result})
@@ -146,7 +163,6 @@ def task_tool(
return f"Task timed out. Error: {result.error}" return f"Task timed out. Error: {result.error}"
# Still running, wait before next poll # Still running, wait before next poll
writer({"type": "task_running", "task_id": task_id, "poll_count": poll_count})
time.sleep(5) # Poll every 5 seconds time.sleep(5) # Poll every 5 seconds
poll_count += 1 poll_count += 1