mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
feat: add real-time streaming of subagent AI messages
Enable task tool to capture and stream AI messages as they are generated by subagents. This replaces simple polling status updates with detailed message-level progress updates. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -44,6 +44,7 @@ class SubagentResult:
|
||||
error: Error message (if failed).
|
||||
started_at: When execution started.
|
||||
completed_at: When execution completed.
|
||||
ai_messages: List of complete AI messages (as dicts) generated during execution.
|
||||
"""
|
||||
|
||||
task_id: str
|
||||
@@ -53,6 +54,12 @@ class SubagentResult:
|
||||
error: str | None = None
|
||||
started_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
ai_messages: list[dict[str, Any]] | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Initialize mutable defaults."""
|
||||
if self.ai_messages is None:
|
||||
self.ai_messages = []
|
||||
|
||||
|
||||
# Global storage for background task results
|
||||
@@ -197,22 +204,28 @@ class SubagentExecutor:
|
||||
|
||||
return state
|
||||
|
||||
def execute(self, task: str) -> SubagentResult:
|
||||
def execute(self, task: str, result_holder: SubagentResult | None = None) -> SubagentResult:
|
||||
"""Execute a task synchronously.
|
||||
|
||||
Args:
|
||||
task: The task description for the subagent.
|
||||
result_holder: Optional pre-created result object to update during execution.
|
||||
|
||||
Returns:
|
||||
SubagentResult with the execution result.
|
||||
"""
|
||||
task_id = str(uuid.uuid4())[:8]
|
||||
result = SubagentResult(
|
||||
task_id=task_id,
|
||||
trace_id=self.trace_id,
|
||||
status=SubagentStatus.RUNNING,
|
||||
started_at=datetime.now(),
|
||||
)
|
||||
if result_holder is not None:
|
||||
# Use the provided result holder (for async execution with real-time updates)
|
||||
result = result_holder
|
||||
else:
|
||||
# Create a new result for synchronous execution
|
||||
task_id = str(uuid.uuid4())[:8]
|
||||
result = SubagentResult(
|
||||
task_id=task_id,
|
||||
trace_id=self.trace_id,
|
||||
status=SubagentStatus.RUNNING,
|
||||
started_at=datetime.now(),
|
||||
)
|
||||
|
||||
try:
|
||||
agent = self._create_agent()
|
||||
@@ -229,50 +242,74 @@ class SubagentExecutor:
|
||||
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} starting execution with max_turns={self.config.max_turns}")
|
||||
|
||||
# Run the agent using invoke for complete result
|
||||
# Note: invoke() runs until completion or interruption
|
||||
# Timeout is handled at the execute_async level, not here
|
||||
final_state = agent.invoke(state, config=run_config, context=context) # type: ignore[arg-type]
|
||||
# Use stream instead of invoke to get real-time updates
|
||||
# This allows us to collect AI messages as they are generated
|
||||
final_state = None
|
||||
for chunk in agent.stream(state, config=run_config, context=context, stream_mode="values"): # type: ignore[arg-type]
|
||||
final_state = chunk
|
||||
|
||||
# Extract AI messages from the current state
|
||||
messages = chunk.get("messages", [])
|
||||
if messages:
|
||||
last_message = messages[-1]
|
||||
# Check if this is a new AI message
|
||||
if isinstance(last_message, AIMessage):
|
||||
# Convert message to dict for serialization
|
||||
message_dict = last_message.model_dump()
|
||||
# Only add if it's not already in the list (avoid duplicates)
|
||||
# Check by comparing message IDs if available, otherwise compare full dict
|
||||
message_id = message_dict.get("id")
|
||||
is_duplicate = False
|
||||
if message_id:
|
||||
is_duplicate = any(msg.get("id") == message_id for msg in result.ai_messages)
|
||||
else:
|
||||
is_duplicate = message_dict in result.ai_messages
|
||||
|
||||
if not is_duplicate:
|
||||
result.ai_messages.append(message_dict)
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(result.ai_messages)}")
|
||||
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} completed execution")
|
||||
|
||||
# Extract the final message - find the last AIMessage
|
||||
messages = final_state.get("messages", [])
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} final messages count: {len(messages)}")
|
||||
|
||||
# Find the last AIMessage in the conversation
|
||||
last_ai_message = None
|
||||
for msg in reversed(messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
last_ai_message = msg
|
||||
break
|
||||
|
||||
if last_ai_message is not None:
|
||||
content = last_ai_message.content
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} last AI message content type: {type(content)}")
|
||||
|
||||
# Handle both str and list content types
|
||||
if isinstance(content, str):
|
||||
result.result = content
|
||||
elif isinstance(content, list):
|
||||
# Extract text from list of content blocks
|
||||
text_parts = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
text_parts.append(block)
|
||||
elif isinstance(block, dict) and "text" in block:
|
||||
text_parts.append(block["text"])
|
||||
result.result = "\n".join(text_parts) if text_parts else "No text content in response"
|
||||
else:
|
||||
result.result = str(content)
|
||||
elif messages:
|
||||
# Fallback: use the last message if no AIMessage found
|
||||
last_message = messages[-1]
|
||||
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no AIMessage found, using last message: {type(last_message)}")
|
||||
result.result = str(last_message.content) if hasattr(last_message, "content") else str(last_message)
|
||||
else:
|
||||
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no messages in final state")
|
||||
if final_state is None:
|
||||
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no final state")
|
||||
result.result = "No response generated"
|
||||
else:
|
||||
# Extract the final message - find the last AIMessage
|
||||
messages = final_state.get("messages", [])
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} final messages count: {len(messages)}")
|
||||
|
||||
# Find the last AIMessage in the conversation
|
||||
last_ai_message = None
|
||||
for msg in reversed(messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
last_ai_message = msg
|
||||
break
|
||||
|
||||
if last_ai_message is not None:
|
||||
content = last_ai_message.content
|
||||
# Handle both str and list content types for the final result
|
||||
if isinstance(content, str):
|
||||
result.result = content
|
||||
elif isinstance(content, list):
|
||||
# Extract text from list of content blocks for final result only
|
||||
text_parts = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
text_parts.append(block)
|
||||
elif isinstance(block, dict) and "text" in block:
|
||||
text_parts.append(block["text"])
|
||||
result.result = "\n".join(text_parts) if text_parts else "No text content in response"
|
||||
else:
|
||||
result.result = str(content)
|
||||
elif messages:
|
||||
# Fallback: use the last message if no AIMessage found
|
||||
last_message = messages[-1]
|
||||
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no AIMessage found, using last message: {type(last_message)}")
|
||||
result.result = str(last_message.content) if hasattr(last_message, "content") else str(last_message)
|
||||
else:
|
||||
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no messages in final state")
|
||||
result.result = "No response generated"
|
||||
|
||||
result.status = SubagentStatus.COMPLETED
|
||||
result.completed_at = datetime.now()
|
||||
@@ -316,10 +353,12 @@ class SubagentExecutor:
|
||||
with _background_tasks_lock:
|
||||
_background_tasks[task_id].status = SubagentStatus.RUNNING
|
||||
_background_tasks[task_id].started_at = datetime.now()
|
||||
result_holder = _background_tasks[task_id]
|
||||
|
||||
try:
|
||||
# Submit execution to execution pool with timeout
|
||||
execution_future: Future = _execution_pool.submit(self.execute, task)
|
||||
# Pass result_holder so execute() can update it in real-time
|
||||
execution_future: Future = _execution_pool.submit(self.execute, task, result_holder)
|
||||
try:
|
||||
# Wait for execution with timeout
|
||||
exec_result = execution_future.result(timeout=self.config.timeout_seconds)
|
||||
@@ -328,6 +367,7 @@ class SubagentExecutor:
|
||||
_background_tasks[task_id].result = exec_result.result
|
||||
_background_tasks[task_id].error = exec_result.error
|
||||
_background_tasks[task_id].completed_at = datetime.now()
|
||||
_background_tasks[task_id].ai_messages = exec_result.ai_messages
|
||||
except FuturesTimeoutError:
|
||||
logger.error(
|
||||
f"[trace={self.trace_id}] Subagent {self.config.name} execution timed out after {self.config.timeout_seconds}s"
|
||||
|
||||
@@ -6,8 +6,8 @@ import uuid
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
|
||||
from langgraph.typing import ContextT
|
||||
from langgraph.config import get_stream_writer
|
||||
from langgraph.typing import ContextT
|
||||
|
||||
from src.agents.thread_state import ThreadState
|
||||
from src.subagents import SubagentExecutor, get_subagent_config
|
||||
@@ -112,6 +112,7 @@ def task_tool(
|
||||
# Poll for task completion in backend (removes need for LLM to poll)
|
||||
poll_count = 0
|
||||
last_status = None
|
||||
last_message_count = 0 # Track how many AI messages we've already sent
|
||||
|
||||
writer = get_stream_writer()
|
||||
# Send Task Started message'
|
||||
@@ -131,6 +132,22 @@ def task_tool(
|
||||
logger.info(f"[trace={trace_id}] Task {task_id} status: {result.status.value}")
|
||||
last_status = result.status
|
||||
|
||||
# Check for new AI messages and send task_running events
|
||||
current_message_count = len(result.ai_messages)
|
||||
if current_message_count > last_message_count:
|
||||
# Send task_running event for each new message
|
||||
for i in range(last_message_count, current_message_count):
|
||||
message = result.ai_messages[i]
|
||||
writer({
|
||||
"type": "task_running",
|
||||
"task_id": task_id,
|
||||
"message": message,
|
||||
"message_index": i + 1, # 1-based index for display
|
||||
"total_messages": current_message_count
|
||||
})
|
||||
logger.info(f"[trace={trace_id}] Task {task_id} sent message #{i + 1}/{current_message_count}")
|
||||
last_message_count = current_message_count
|
||||
|
||||
# Check if task completed, failed, or timed out
|
||||
if result.status == SubagentStatus.COMPLETED:
|
||||
writer({"type": "task_completed", "task_id": task_id, "result": result.result})
|
||||
@@ -146,7 +163,6 @@ def task_tool(
|
||||
return f"Task timed out. Error: {result.error}"
|
||||
|
||||
# Still running, wait before next poll
|
||||
writer({"type": "task_running", "task_id": task_id, "poll_count": poll_count})
|
||||
time.sleep(5) # Poll every 5 seconds
|
||||
poll_count += 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user