mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-19 12:24:46 +08:00
feat: add real-time streaming of subagent AI messages
Enable task tool to capture and stream AI messages as they are generated by subagents. This replaces simple polling status updates with detailed message-level progress updates. Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -44,6 +44,7 @@ class SubagentResult:
|
|||||||
error: Error message (if failed).
|
error: Error message (if failed).
|
||||||
started_at: When execution started.
|
started_at: When execution started.
|
||||||
completed_at: When execution completed.
|
completed_at: When execution completed.
|
||||||
|
ai_messages: List of complete AI messages (as dicts) generated during execution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
task_id: str
|
task_id: str
|
||||||
@@ -53,6 +54,12 @@ class SubagentResult:
|
|||||||
error: str | None = None
|
error: str | None = None
|
||||||
started_at: datetime | None = None
|
started_at: datetime | None = None
|
||||||
completed_at: datetime | None = None
|
completed_at: datetime | None = None
|
||||||
|
ai_messages: list[dict[str, Any]] | None = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""Initialize mutable defaults."""
|
||||||
|
if self.ai_messages is None:
|
||||||
|
self.ai_messages = []
|
||||||
|
|
||||||
|
|
||||||
# Global storage for background task results
|
# Global storage for background task results
|
||||||
@@ -197,22 +204,28 @@ class SubagentExecutor:
|
|||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def execute(self, task: str) -> SubagentResult:
|
def execute(self, task: str, result_holder: SubagentResult | None = None) -> SubagentResult:
|
||||||
"""Execute a task synchronously.
|
"""Execute a task synchronously.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
task: The task description for the subagent.
|
task: The task description for the subagent.
|
||||||
|
result_holder: Optional pre-created result object to update during execution.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SubagentResult with the execution result.
|
SubagentResult with the execution result.
|
||||||
"""
|
"""
|
||||||
task_id = str(uuid.uuid4())[:8]
|
if result_holder is not None:
|
||||||
result = SubagentResult(
|
# Use the provided result holder (for async execution with real-time updates)
|
||||||
task_id=task_id,
|
result = result_holder
|
||||||
trace_id=self.trace_id,
|
else:
|
||||||
status=SubagentStatus.RUNNING,
|
# Create a new result for synchronous execution
|
||||||
started_at=datetime.now(),
|
task_id = str(uuid.uuid4())[:8]
|
||||||
)
|
result = SubagentResult(
|
||||||
|
task_id=task_id,
|
||||||
|
trace_id=self.trace_id,
|
||||||
|
status=SubagentStatus.RUNNING,
|
||||||
|
started_at=datetime.now(),
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
agent = self._create_agent()
|
agent = self._create_agent()
|
||||||
@@ -229,50 +242,74 @@ class SubagentExecutor:
|
|||||||
|
|
||||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} starting execution with max_turns={self.config.max_turns}")
|
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} starting execution with max_turns={self.config.max_turns}")
|
||||||
|
|
||||||
# Run the agent using invoke for complete result
|
# Use stream instead of invoke to get real-time updates
|
||||||
# Note: invoke() runs until completion or interruption
|
# This allows us to collect AI messages as they are generated
|
||||||
# Timeout is handled at the execute_async level, not here
|
final_state = None
|
||||||
final_state = agent.invoke(state, config=run_config, context=context) # type: ignore[arg-type]
|
for chunk in agent.stream(state, config=run_config, context=context, stream_mode="values"): # type: ignore[arg-type]
|
||||||
|
final_state = chunk
|
||||||
|
|
||||||
|
# Extract AI messages from the current state
|
||||||
|
messages = chunk.get("messages", [])
|
||||||
|
if messages:
|
||||||
|
last_message = messages[-1]
|
||||||
|
# Check if this is a new AI message
|
||||||
|
if isinstance(last_message, AIMessage):
|
||||||
|
# Convert message to dict for serialization
|
||||||
|
message_dict = last_message.model_dump()
|
||||||
|
# Only add if it's not already in the list (avoid duplicates)
|
||||||
|
# Check by comparing message IDs if available, otherwise compare full dict
|
||||||
|
message_id = message_dict.get("id")
|
||||||
|
is_duplicate = False
|
||||||
|
if message_id:
|
||||||
|
is_duplicate = any(msg.get("id") == message_id for msg in result.ai_messages)
|
||||||
|
else:
|
||||||
|
is_duplicate = message_dict in result.ai_messages
|
||||||
|
|
||||||
|
if not is_duplicate:
|
||||||
|
result.ai_messages.append(message_dict)
|
||||||
|
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(result.ai_messages)}")
|
||||||
|
|
||||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} completed execution")
|
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} completed execution")
|
||||||
|
|
||||||
# Extract the final message - find the last AIMessage
|
if final_state is None:
|
||||||
messages = final_state.get("messages", [])
|
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no final state")
|
||||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} final messages count: {len(messages)}")
|
|
||||||
|
|
||||||
# Find the last AIMessage in the conversation
|
|
||||||
last_ai_message = None
|
|
||||||
for msg in reversed(messages):
|
|
||||||
if isinstance(msg, AIMessage):
|
|
||||||
last_ai_message = msg
|
|
||||||
break
|
|
||||||
|
|
||||||
if last_ai_message is not None:
|
|
||||||
content = last_ai_message.content
|
|
||||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} last AI message content type: {type(content)}")
|
|
||||||
|
|
||||||
# Handle both str and list content types
|
|
||||||
if isinstance(content, str):
|
|
||||||
result.result = content
|
|
||||||
elif isinstance(content, list):
|
|
||||||
# Extract text from list of content blocks
|
|
||||||
text_parts = []
|
|
||||||
for block in content:
|
|
||||||
if isinstance(block, str):
|
|
||||||
text_parts.append(block)
|
|
||||||
elif isinstance(block, dict) and "text" in block:
|
|
||||||
text_parts.append(block["text"])
|
|
||||||
result.result = "\n".join(text_parts) if text_parts else "No text content in response"
|
|
||||||
else:
|
|
||||||
result.result = str(content)
|
|
||||||
elif messages:
|
|
||||||
# Fallback: use the last message if no AIMessage found
|
|
||||||
last_message = messages[-1]
|
|
||||||
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no AIMessage found, using last message: {type(last_message)}")
|
|
||||||
result.result = str(last_message.content) if hasattr(last_message, "content") else str(last_message)
|
|
||||||
else:
|
|
||||||
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no messages in final state")
|
|
||||||
result.result = "No response generated"
|
result.result = "No response generated"
|
||||||
|
else:
|
||||||
|
# Extract the final message - find the last AIMessage
|
||||||
|
messages = final_state.get("messages", [])
|
||||||
|
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} final messages count: {len(messages)}")
|
||||||
|
|
||||||
|
# Find the last AIMessage in the conversation
|
||||||
|
last_ai_message = None
|
||||||
|
for msg in reversed(messages):
|
||||||
|
if isinstance(msg, AIMessage):
|
||||||
|
last_ai_message = msg
|
||||||
|
break
|
||||||
|
|
||||||
|
if last_ai_message is not None:
|
||||||
|
content = last_ai_message.content
|
||||||
|
# Handle both str and list content types for the final result
|
||||||
|
if isinstance(content, str):
|
||||||
|
result.result = content
|
||||||
|
elif isinstance(content, list):
|
||||||
|
# Extract text from list of content blocks for final result only
|
||||||
|
text_parts = []
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, str):
|
||||||
|
text_parts.append(block)
|
||||||
|
elif isinstance(block, dict) and "text" in block:
|
||||||
|
text_parts.append(block["text"])
|
||||||
|
result.result = "\n".join(text_parts) if text_parts else "No text content in response"
|
||||||
|
else:
|
||||||
|
result.result = str(content)
|
||||||
|
elif messages:
|
||||||
|
# Fallback: use the last message if no AIMessage found
|
||||||
|
last_message = messages[-1]
|
||||||
|
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no AIMessage found, using last message: {type(last_message)}")
|
||||||
|
result.result = str(last_message.content) if hasattr(last_message, "content") else str(last_message)
|
||||||
|
else:
|
||||||
|
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no messages in final state")
|
||||||
|
result.result = "No response generated"
|
||||||
|
|
||||||
result.status = SubagentStatus.COMPLETED
|
result.status = SubagentStatus.COMPLETED
|
||||||
result.completed_at = datetime.now()
|
result.completed_at = datetime.now()
|
||||||
@@ -316,10 +353,12 @@ class SubagentExecutor:
|
|||||||
with _background_tasks_lock:
|
with _background_tasks_lock:
|
||||||
_background_tasks[task_id].status = SubagentStatus.RUNNING
|
_background_tasks[task_id].status = SubagentStatus.RUNNING
|
||||||
_background_tasks[task_id].started_at = datetime.now()
|
_background_tasks[task_id].started_at = datetime.now()
|
||||||
|
result_holder = _background_tasks[task_id]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Submit execution to execution pool with timeout
|
# Submit execution to execution pool with timeout
|
||||||
execution_future: Future = _execution_pool.submit(self.execute, task)
|
# Pass result_holder so execute() can update it in real-time
|
||||||
|
execution_future: Future = _execution_pool.submit(self.execute, task, result_holder)
|
||||||
try:
|
try:
|
||||||
# Wait for execution with timeout
|
# Wait for execution with timeout
|
||||||
exec_result = execution_future.result(timeout=self.config.timeout_seconds)
|
exec_result = execution_future.result(timeout=self.config.timeout_seconds)
|
||||||
@@ -328,6 +367,7 @@ class SubagentExecutor:
|
|||||||
_background_tasks[task_id].result = exec_result.result
|
_background_tasks[task_id].result = exec_result.result
|
||||||
_background_tasks[task_id].error = exec_result.error
|
_background_tasks[task_id].error = exec_result.error
|
||||||
_background_tasks[task_id].completed_at = datetime.now()
|
_background_tasks[task_id].completed_at = datetime.now()
|
||||||
|
_background_tasks[task_id].ai_messages = exec_result.ai_messages
|
||||||
except FuturesTimeoutError:
|
except FuturesTimeoutError:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[trace={self.trace_id}] Subagent {self.config.name} execution timed out after {self.config.timeout_seconds}s"
|
f"[trace={self.trace_id}] Subagent {self.config.name} execution timed out after {self.config.timeout_seconds}s"
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import uuid
|
|||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
|
from langchain.tools import InjectedToolCallId, ToolRuntime, tool
|
||||||
from langgraph.typing import ContextT
|
|
||||||
from langgraph.config import get_stream_writer
|
from langgraph.config import get_stream_writer
|
||||||
|
from langgraph.typing import ContextT
|
||||||
|
|
||||||
from src.agents.thread_state import ThreadState
|
from src.agents.thread_state import ThreadState
|
||||||
from src.subagents import SubagentExecutor, get_subagent_config
|
from src.subagents import SubagentExecutor, get_subagent_config
|
||||||
@@ -112,6 +112,7 @@ def task_tool(
|
|||||||
# Poll for task completion in backend (removes need for LLM to poll)
|
# Poll for task completion in backend (removes need for LLM to poll)
|
||||||
poll_count = 0
|
poll_count = 0
|
||||||
last_status = None
|
last_status = None
|
||||||
|
last_message_count = 0 # Track how many AI messages we've already sent
|
||||||
|
|
||||||
writer = get_stream_writer()
|
writer = get_stream_writer()
|
||||||
# Send Task Started message'
|
# Send Task Started message'
|
||||||
@@ -131,6 +132,22 @@ def task_tool(
|
|||||||
logger.info(f"[trace={trace_id}] Task {task_id} status: {result.status.value}")
|
logger.info(f"[trace={trace_id}] Task {task_id} status: {result.status.value}")
|
||||||
last_status = result.status
|
last_status = result.status
|
||||||
|
|
||||||
|
# Check for new AI messages and send task_running events
|
||||||
|
current_message_count = len(result.ai_messages)
|
||||||
|
if current_message_count > last_message_count:
|
||||||
|
# Send task_running event for each new message
|
||||||
|
for i in range(last_message_count, current_message_count):
|
||||||
|
message = result.ai_messages[i]
|
||||||
|
writer({
|
||||||
|
"type": "task_running",
|
||||||
|
"task_id": task_id,
|
||||||
|
"message": message,
|
||||||
|
"message_index": i + 1, # 1-based index for display
|
||||||
|
"total_messages": current_message_count
|
||||||
|
})
|
||||||
|
logger.info(f"[trace={trace_id}] Task {task_id} sent message #{i + 1}/{current_message_count}")
|
||||||
|
last_message_count = current_message_count
|
||||||
|
|
||||||
# Check if task completed, failed, or timed out
|
# Check if task completed, failed, or timed out
|
||||||
if result.status == SubagentStatus.COMPLETED:
|
if result.status == SubagentStatus.COMPLETED:
|
||||||
writer({"type": "task_completed", "task_id": task_id, "result": result.result})
|
writer({"type": "task_completed", "task_id": task_id, "result": result.result})
|
||||||
@@ -146,7 +163,6 @@ def task_tool(
|
|||||||
return f"Task timed out. Error: {result.error}"
|
return f"Task timed out. Error: {result.error}"
|
||||||
|
|
||||||
# Still running, wait before next poll
|
# Still running, wait before next poll
|
||||||
writer({"type": "task_running", "task_id": task_id, "poll_count": poll_count})
|
|
||||||
time.sleep(5) # Poll every 5 seconds
|
time.sleep(5) # Poll every 5 seconds
|
||||||
poll_count += 1
|
poll_count += 1
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user