fix(task): avoid blocking in task tool polling (#1320)

* fix: avoid blocking in task tool polling

* test: adapt task tool polling tests for async tool

* fix: clean up cancelled task tool polling

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
luo jiyin
2026-03-27 17:12:40 +08:00
committed by GitHub
parent a4e4bb21e3
commit 43a19f9627
2 changed files with 242 additions and 77 deletions

View File

@@ -1,7 +1,7 @@
"""Task tool for delegating work to subagents.""" """Task tool for delegating work to subagents."""
import asyncio
import logging import logging
import time
import uuid import uuid
from dataclasses import replace from dataclasses import replace
from typing import Annotated, Literal from typing import Annotated, Literal
@@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
@tool("task", parse_docstring=True) @tool("task", parse_docstring=True)
def task_tool( async def task_tool(
runtime: ToolRuntime[ContextT, ThreadState], runtime: ToolRuntime[ContextT, ThreadState],
description: str, description: str,
prompt: str, prompt: str,
@@ -129,67 +129,102 @@ def task_tool(
# Send Task Started message # Send Task Started message
writer({"type": "task_started", "task_id": task_id, "description": description}) writer({"type": "task_started", "task_id": task_id, "description": description})
while True: try:
result = get_background_task_result(task_id) while True:
result = get_background_task_result(task_id)
if result is None: if result is None:
logger.error(f"[trace={trace_id}] Task {task_id} not found in background tasks") logger.error(f"[trace={trace_id}] Task {task_id} not found in background tasks")
writer({"type": "task_failed", "task_id": task_id, "error": "Task disappeared from background tasks"}) writer({"type": "task_failed", "task_id": task_id, "error": "Task disappeared from background tasks"})
cleanup_background_task(task_id) cleanup_background_task(task_id)
return f"Error: Task {task_id} disappeared from background tasks" return f"Error: Task {task_id} disappeared from background tasks"
# Log status changes for debugging # Log status changes for debugging
if result.status != last_status: if result.status != last_status:
logger.info(f"[trace={trace_id}] Task {task_id} status: {result.status.value}") logger.info(f"[trace={trace_id}] Task {task_id} status: {result.status.value}")
last_status = result.status last_status = result.status
# Check for new AI messages and send task_running events # Check for new AI messages and send task_running events
current_message_count = len(result.ai_messages) current_message_count = len(result.ai_messages)
if current_message_count > last_message_count: if current_message_count > last_message_count:
# Send task_running event for each new message # Send task_running event for each new message
for i in range(last_message_count, current_message_count): for i in range(last_message_count, current_message_count):
message = result.ai_messages[i] message = result.ai_messages[i]
writer( writer(
{ {
"type": "task_running", "type": "task_running",
"task_id": task_id, "task_id": task_id,
"message": message, "message": message,
"message_index": i + 1, # 1-based index for display "message_index": i + 1, # 1-based index for display
"total_messages": current_message_count, "total_messages": current_message_count,
} }
) )
logger.info(f"[trace={trace_id}] Task {task_id} sent message #{i + 1}/{current_message_count}") logger.info(f"[trace={trace_id}] Task {task_id} sent message #{i + 1}/{current_message_count}")
last_message_count = current_message_count last_message_count = current_message_count
# Check if task completed, failed, or timed out # Check if task completed, failed, or timed out
if result.status == SubagentStatus.COMPLETED: if result.status == SubagentStatus.COMPLETED:
writer({"type": "task_completed", "task_id": task_id, "result": result.result}) writer({"type": "task_completed", "task_id": task_id, "result": result.result})
logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls") logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls")
cleanup_background_task(task_id) cleanup_background_task(task_id)
return f"Task Succeeded. Result: {result.result}" return f"Task Succeeded. Result: {result.result}"
elif result.status == SubagentStatus.FAILED: elif result.status == SubagentStatus.FAILED:
writer({"type": "task_failed", "task_id": task_id, "error": result.error}) writer({"type": "task_failed", "task_id": task_id, "error": result.error})
logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}") logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
cleanup_background_task(task_id) cleanup_background_task(task_id)
return f"Task failed. Error: {result.error}" return f"Task failed. Error: {result.error}"
elif result.status == SubagentStatus.TIMED_OUT: elif result.status == SubagentStatus.TIMED_OUT:
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error}) writer({"type": "task_timed_out", "task_id": task_id, "error": result.error})
logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}") logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
cleanup_background_task(task_id) cleanup_background_task(task_id)
return f"Task timed out. Error: {result.error}" return f"Task timed out. Error: {result.error}"
# Still running, wait before next poll # Still running, wait before next poll
time.sleep(5) # Poll every 5 seconds await asyncio.sleep(5)
poll_count += 1 poll_count += 1
# Polling timeout as a safety net (in case thread pool timeout doesn't work) # Polling timeout as a safety net (in case thread pool timeout doesn't work)
# Set to execution timeout + 60s buffer, in 5s poll intervals # Set to execution timeout + 60s buffer, in 5s poll intervals
# This catches edge cases where the background task gets stuck # This catches edge cases where the background task gets stuck
# Note: We don't call cleanup_background_task here because the task may # Note: We don't call cleanup_background_task here because the task may
# still be running in the background. The cleanup will happen when the # still be running in the background. The cleanup will happen when the
# executor completes and sets a terminal status. # executor completes and sets a terminal status.
if poll_count > max_poll_count: if poll_count > max_poll_count:
timeout_minutes = config.timeout_seconds // 60 timeout_minutes = config.timeout_seconds // 60
logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)") logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)")
writer({"type": "task_timed_out", "task_id": task_id}) writer({"type": "task_timed_out", "task_id": task_id})
return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}" return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
except asyncio.CancelledError:
async def cleanup_when_done() -> None:
max_cleanup_polls = max_poll_count
cleanup_poll_count = 0
while True:
result = get_background_task_result(task_id)
if result is None:
return
if result.status in {SubagentStatus.COMPLETED, SubagentStatus.FAILED, SubagentStatus.TIMED_OUT} or getattr(result, "completed_at", None) is not None:
cleanup_background_task(task_id)
return
if cleanup_poll_count > max_cleanup_polls:
logger.warning(
f"[trace={trace_id}] Deferred cleanup for task {task_id} timed out after {cleanup_poll_count} polls"
)
return
await asyncio.sleep(5)
cleanup_poll_count += 1
def log_cleanup_failure(cleanup_task: asyncio.Task[None]) -> None:
if cleanup_task.cancelled():
return
exc = cleanup_task.exception()
if exc is not None:
logger.error(f"[trace={trace_id}] Deferred cleanup failed for task {task_id}: {exc}")
logger.debug(f"[trace={trace_id}] Scheduling deferred cleanup for cancelled task {task_id}")
asyncio.create_task(cleanup_when_done()).add_done_callback(log_cleanup_failure)
raise

View File

@@ -1,10 +1,13 @@
"""Core behavior tests for task tool orchestration.""" """Core behavior tests for task tool orchestration."""
import asyncio
import importlib import importlib
from enum import Enum from enum import Enum
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest
from deerflow.subagents.config import SubagentConfig from deerflow.subagents.config import SubagentConfig
# Use module import so tests can patch the exact symbols referenced inside task_tool(). # Use module import so tests can patch the exact symbols referenced inside task_tool().
@@ -61,10 +64,23 @@ def _make_result(
) )
def _run_task_tool(**kwargs):
return asyncio.run(task_tool_module.task_tool.coroutine(**kwargs))
async def _no_sleep(_: float) -> None:
return None
class _DummyScheduledTask:
def add_done_callback(self, _callback):
return None
def test_task_tool_returns_error_for_unknown_subagent(monkeypatch): def test_task_tool_returns_error_for_unknown_subagent(monkeypatch):
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: None) monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: None)
result = task_tool_module.task_tool.func( result = _run_task_tool(
runtime=None, runtime=None,
description="执行任务", description="执行任务",
prompt="do work", prompt="do work",
@@ -109,11 +125,11 @@ def test_task_tool_emits_running_and_completed_events(monkeypatch):
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "Skills Appendix") monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "Skills Appendix")
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: next(responses)) monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: next(responses))
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None) monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
# task_tool lazily imports from deerflow.tools at call time, so patch that module-level function. # task_tool lazily imports from deerflow.tools at call time, so patch that module-level function.
monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools) monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools)
output = task_tool_module.task_tool.func( output = _run_task_tool(
runtime=runtime, runtime=runtime,
description="运行子任务", description="运行子任务",
prompt="collect diagnostics", prompt="collect diagnostics",
@@ -155,10 +171,10 @@ def test_task_tool_returns_failed_message(monkeypatch):
lambda _: _make_result(FakeSubagentStatus.FAILED, error="subagent crashed"), lambda _: _make_result(FakeSubagentStatus.FAILED, error="subagent crashed"),
) )
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None) monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
output = task_tool_module.task_tool.func( output = _run_task_tool(
runtime=_make_runtime(), runtime=_make_runtime(),
description="执行任务", description="执行任务",
prompt="do fail", prompt="do fail",
@@ -189,10 +205,10 @@ def test_task_tool_returns_timed_out_message(monkeypatch):
lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"), lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
) )
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None) monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
output = task_tool_module.task_tool.func( output = _run_task_tool(
runtime=_make_runtime(), runtime=_make_runtime(),
description="执行任务", description="执行任务",
prompt="do timeout", prompt="do timeout",
@@ -225,10 +241,10 @@ def test_task_tool_polling_safety_timeout(monkeypatch):
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]), lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
) )
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None) monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
output = task_tool_module.task_tool.func( output = _run_task_tool(
runtime=_make_runtime(), runtime=_make_runtime(),
description="执行任务", description="执行任务",
prompt="never finish", prompt="never finish",
@@ -261,7 +277,7 @@ def test_cleanup_called_on_completed(monkeypatch):
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"), lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
) )
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None) monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr( monkeypatch.setattr(
task_tool_module, task_tool_module,
@@ -269,7 +285,7 @@ def test_cleanup_called_on_completed(monkeypatch):
lambda task_id: cleanup_calls.append(task_id), lambda task_id: cleanup_calls.append(task_id),
) )
output = task_tool_module.task_tool.func( output = _run_task_tool(
runtime=_make_runtime(), runtime=_make_runtime(),
description="执行任务", description="执行任务",
prompt="complete task", prompt="complete task",
@@ -301,7 +317,7 @@ def test_cleanup_called_on_failed(monkeypatch):
lambda _: _make_result(FakeSubagentStatus.FAILED, error="error"), lambda _: _make_result(FakeSubagentStatus.FAILED, error="error"),
) )
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None) monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr( monkeypatch.setattr(
task_tool_module, task_tool_module,
@@ -309,7 +325,7 @@ def test_cleanup_called_on_failed(monkeypatch):
lambda task_id: cleanup_calls.append(task_id), lambda task_id: cleanup_calls.append(task_id),
) )
output = task_tool_module.task_tool.func( output = _run_task_tool(
runtime=_make_runtime(), runtime=_make_runtime(),
description="执行任务", description="执行任务",
prompt="fail task", prompt="fail task",
@@ -341,7 +357,7 @@ def test_cleanup_called_on_timed_out(monkeypatch):
lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"), lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
) )
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None) monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr( monkeypatch.setattr(
task_tool_module, task_tool_module,
@@ -349,7 +365,7 @@ def test_cleanup_called_on_timed_out(monkeypatch):
lambda task_id: cleanup_calls.append(task_id), lambda task_id: cleanup_calls.append(task_id),
) )
output = task_tool_module.task_tool.func( output = _run_task_tool(
runtime=_make_runtime(), runtime=_make_runtime(),
description="执行任务", description="执行任务",
prompt="timeout task", prompt="timeout task",
@@ -388,7 +404,7 @@ def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch):
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]), lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
) )
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None) monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr( monkeypatch.setattr(
task_tool_module, task_tool_module,
@@ -396,7 +412,7 @@ def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch):
lambda task_id: cleanup_calls.append(task_id), lambda task_id: cleanup_calls.append(task_id),
) )
output = task_tool_module.task_tool.func( output = _run_task_tool(
runtime=_make_runtime(), runtime=_make_runtime(),
description="执行任务", description="执行任务",
prompt="never finish", prompt="never finish",
@@ -407,3 +423,117 @@ def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch):
assert output.startswith("Task polling timed out after 0 minutes") assert output.startswith("Task polling timed out after 0 minutes")
# cleanup should NOT be called because the task is still RUNNING # cleanup should NOT be called because the task is still RUNNING
assert cleanup_calls == [] assert cleanup_calls == []
def test_cleanup_scheduled_on_cancellation(monkeypatch):
"""Verify cancellation schedules deferred cleanup for the background task."""
config = _make_subagent_config()
events = []
cleanup_calls = []
scheduled_cleanup_coros = []
poll_count = 0
def get_result(_: str):
nonlocal poll_count
poll_count += 1
if poll_count == 1:
return _make_result(FakeSubagentStatus.RUNNING, ai_messages=[])
return _make_result(FakeSubagentStatus.COMPLETED, result="done")
async def cancel_on_first_sleep(_: float) -> None:
raise asyncio.CancelledError
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(
task_tool_module,
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
monkeypatch.setattr(
task_tool_module.asyncio,
"create_task",
lambda coro: scheduled_cleanup_coros.append(coro) or _DummyScheduledTask(),
)
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(
task_tool_module,
"cleanup_background_task",
lambda task_id: cleanup_calls.append(task_id),
)
with pytest.raises(asyncio.CancelledError):
_run_task_tool(
runtime=_make_runtime(),
description="执行任务",
prompt="cancel task",
subagent_type="general-purpose",
tool_call_id="tc-cancelled-cleanup",
)
assert cleanup_calls == []
assert len(scheduled_cleanup_coros) == 1
asyncio.run(scheduled_cleanup_coros.pop())
assert cleanup_calls == ["tc-cancelled-cleanup"]
def test_cancelled_cleanup_stops_after_timeout(monkeypatch):
"""Verify deferred cleanup gives up after a bounded number of polls."""
config = _make_subagent_config()
config.timeout_seconds = 1
events = []
cleanup_calls = []
scheduled_cleanup_coros = []
async def cancel_on_first_sleep(_: float) -> None:
raise asyncio.CancelledError
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(
task_tool_module,
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
monkeypatch.setattr(
task_tool_module.asyncio,
"create_task",
lambda coro: scheduled_cleanup_coros.append(coro) or _DummyScheduledTask(),
)
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(
task_tool_module,
"cleanup_background_task",
lambda task_id: cleanup_calls.append(task_id),
)
with pytest.raises(asyncio.CancelledError):
_run_task_tool(
runtime=_make_runtime(),
description="执行任务",
prompt="cancel task",
subagent_type="general-purpose",
tool_call_id="tc-cancelled-timeout",
)
async def bounded_sleep(_seconds: float) -> None:
return None
monkeypatch.setattr(task_tool_module.asyncio, "sleep", bounded_sleep)
asyncio.run(scheduled_cleanup_coros.pop())
assert cleanup_calls == []