mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-19 12:24:46 +08:00
fix(task): avoid blocking in task tool polling (#1320)
* fix: avoid blocking in task tool polling * test: adapt task tool polling tests for async tool * fix: clean up cancelled task tool polling --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""Task tool for delegating work to subagents."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import replace
|
||||
from typing import Annotated, Literal
|
||||
@@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool("task", parse_docstring=True)
|
||||
def task_tool(
|
||||
async def task_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
description: str,
|
||||
prompt: str,
|
||||
@@ -129,67 +129,102 @@ def task_tool(
|
||||
# Send Task Started message'
|
||||
writer({"type": "task_started", "task_id": task_id, "description": description})
|
||||
|
||||
while True:
|
||||
result = get_background_task_result(task_id)
|
||||
try:
|
||||
while True:
|
||||
result = get_background_task_result(task_id)
|
||||
|
||||
if result is None:
|
||||
logger.error(f"[trace={trace_id}] Task {task_id} not found in background tasks")
|
||||
writer({"type": "task_failed", "task_id": task_id, "error": "Task disappeared from background tasks"})
|
||||
cleanup_background_task(task_id)
|
||||
return f"Error: Task {task_id} disappeared from background tasks"
|
||||
if result is None:
|
||||
logger.error(f"[trace={trace_id}] Task {task_id} not found in background tasks")
|
||||
writer({"type": "task_failed", "task_id": task_id, "error": "Task disappeared from background tasks"})
|
||||
cleanup_background_task(task_id)
|
||||
return f"Error: Task {task_id} disappeared from background tasks"
|
||||
|
||||
# Log status changes for debugging
|
||||
if result.status != last_status:
|
||||
logger.info(f"[trace={trace_id}] Task {task_id} status: {result.status.value}")
|
||||
last_status = result.status
|
||||
# Log status changes for debugging
|
||||
if result.status != last_status:
|
||||
logger.info(f"[trace={trace_id}] Task {task_id} status: {result.status.value}")
|
||||
last_status = result.status
|
||||
|
||||
# Check for new AI messages and send task_running events
|
||||
current_message_count = len(result.ai_messages)
|
||||
if current_message_count > last_message_count:
|
||||
# Send task_running event for each new message
|
||||
for i in range(last_message_count, current_message_count):
|
||||
message = result.ai_messages[i]
|
||||
writer(
|
||||
{
|
||||
"type": "task_running",
|
||||
"task_id": task_id,
|
||||
"message": message,
|
||||
"message_index": i + 1, # 1-based index for display
|
||||
"total_messages": current_message_count,
|
||||
}
|
||||
)
|
||||
logger.info(f"[trace={trace_id}] Task {task_id} sent message #{i + 1}/{current_message_count}")
|
||||
last_message_count = current_message_count
|
||||
# Check for new AI messages and send task_running events
|
||||
current_message_count = len(result.ai_messages)
|
||||
if current_message_count > last_message_count:
|
||||
# Send task_running event for each new message
|
||||
for i in range(last_message_count, current_message_count):
|
||||
message = result.ai_messages[i]
|
||||
writer(
|
||||
{
|
||||
"type": "task_running",
|
||||
"task_id": task_id,
|
||||
"message": message,
|
||||
"message_index": i + 1, # 1-based index for display
|
||||
"total_messages": current_message_count,
|
||||
}
|
||||
)
|
||||
logger.info(f"[trace={trace_id}] Task {task_id} sent message #{i + 1}/{current_message_count}")
|
||||
last_message_count = current_message_count
|
||||
|
||||
# Check if task completed, failed, or timed out
|
||||
if result.status == SubagentStatus.COMPLETED:
|
||||
writer({"type": "task_completed", "task_id": task_id, "result": result.result})
|
||||
logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls")
|
||||
cleanup_background_task(task_id)
|
||||
return f"Task Succeeded. Result: {result.result}"
|
||||
elif result.status == SubagentStatus.FAILED:
|
||||
writer({"type": "task_failed", "task_id": task_id, "error": result.error})
|
||||
logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
|
||||
cleanup_background_task(task_id)
|
||||
return f"Task failed. Error: {result.error}"
|
||||
elif result.status == SubagentStatus.TIMED_OUT:
|
||||
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error})
|
||||
logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
|
||||
cleanup_background_task(task_id)
|
||||
return f"Task timed out. Error: {result.error}"
|
||||
# Check if task completed, failed, or timed out
|
||||
if result.status == SubagentStatus.COMPLETED:
|
||||
writer({"type": "task_completed", "task_id": task_id, "result": result.result})
|
||||
logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls")
|
||||
cleanup_background_task(task_id)
|
||||
return f"Task Succeeded. Result: {result.result}"
|
||||
elif result.status == SubagentStatus.FAILED:
|
||||
writer({"type": "task_failed", "task_id": task_id, "error": result.error})
|
||||
logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
|
||||
cleanup_background_task(task_id)
|
||||
return f"Task failed. Error: {result.error}"
|
||||
elif result.status == SubagentStatus.TIMED_OUT:
|
||||
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error})
|
||||
logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
|
||||
cleanup_background_task(task_id)
|
||||
return f"Task timed out. Error: {result.error}"
|
||||
|
||||
# Still running, wait before next poll
|
||||
time.sleep(5) # Poll every 5 seconds
|
||||
poll_count += 1
|
||||
# Still running, wait before next poll
|
||||
await asyncio.sleep(5)
|
||||
poll_count += 1
|
||||
|
||||
# Polling timeout as a safety net (in case thread pool timeout doesn't work)
|
||||
# Set to execution timeout + 60s buffer, in 5s poll intervals
|
||||
# This catches edge cases where the background task gets stuck
|
||||
# Note: We don't call cleanup_background_task here because the task may
|
||||
# still be running in the background. The cleanup will happen when the
|
||||
# executor completes and sets a terminal status.
|
||||
if poll_count > max_poll_count:
|
||||
timeout_minutes = config.timeout_seconds // 60
|
||||
logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)")
|
||||
writer({"type": "task_timed_out", "task_id": task_id})
|
||||
return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
|
||||
# Polling timeout as a safety net (in case thread pool timeout doesn't work)
|
||||
# Set to execution timeout + 60s buffer, in 5s poll intervals
|
||||
# This catches edge cases where the background task gets stuck
|
||||
# Note: We don't call cleanup_background_task here because the task may
|
||||
# still be running in the background. The cleanup will happen when the
|
||||
# executor completes and sets a terminal status.
|
||||
if poll_count > max_poll_count:
|
||||
timeout_minutes = config.timeout_seconds // 60
|
||||
logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)")
|
||||
writer({"type": "task_timed_out", "task_id": task_id})
|
||||
return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
|
||||
except asyncio.CancelledError:
|
||||
async def cleanup_when_done() -> None:
|
||||
max_cleanup_polls = max_poll_count
|
||||
cleanup_poll_count = 0
|
||||
|
||||
while True:
|
||||
result = get_background_task_result(task_id)
|
||||
if result is None:
|
||||
return
|
||||
|
||||
if result.status in {SubagentStatus.COMPLETED, SubagentStatus.FAILED, SubagentStatus.TIMED_OUT} or getattr(result, "completed_at", None) is not None:
|
||||
cleanup_background_task(task_id)
|
||||
return
|
||||
|
||||
if cleanup_poll_count > max_cleanup_polls:
|
||||
logger.warning(
|
||||
f"[trace={trace_id}] Deferred cleanup for task {task_id} timed out after {cleanup_poll_count} polls"
|
||||
)
|
||||
return
|
||||
|
||||
await asyncio.sleep(5)
|
||||
cleanup_poll_count += 1
|
||||
|
||||
def log_cleanup_failure(cleanup_task: asyncio.Task[None]) -> None:
|
||||
if cleanup_task.cancelled():
|
||||
return
|
||||
|
||||
exc = cleanup_task.exception()
|
||||
if exc is not None:
|
||||
logger.error(f"[trace={trace_id}] Deferred cleanup failed for task {task_id}: {exc}")
|
||||
|
||||
logger.debug(f"[trace={trace_id}] Scheduling deferred cleanup for cancelled task {task_id}")
|
||||
asyncio.create_task(cleanup_when_done()).add_done_callback(log_cleanup_failure)
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user