diff --git a/backend/packages/harness/deerflow/tools/builtins/task_tool.py b/backend/packages/harness/deerflow/tools/builtins/task_tool.py index 71b348c..9f3f13b 100644 --- a/backend/packages/harness/deerflow/tools/builtins/task_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/task_tool.py @@ -1,7 +1,7 @@ """Task tool for delegating work to subagents.""" +import asyncio import logging -import time import uuid from dataclasses import replace from typing import Annotated, Literal @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) @tool("task", parse_docstring=True) -def task_tool( +async def task_tool( runtime: ToolRuntime[ContextT, ThreadState], description: str, prompt: str, @@ -129,67 +129,102 @@ def task_tool( # Send Task Started message' writer({"type": "task_started", "task_id": task_id, "description": description}) - while True: - result = get_background_task_result(task_id) + try: + while True: + result = get_background_task_result(task_id) - if result is None: - logger.error(f"[trace={trace_id}] Task {task_id} not found in background tasks") - writer({"type": "task_failed", "task_id": task_id, "error": "Task disappeared from background tasks"}) - cleanup_background_task(task_id) - return f"Error: Task {task_id} disappeared from background tasks" + if result is None: + logger.error(f"[trace={trace_id}] Task {task_id} not found in background tasks") + writer({"type": "task_failed", "task_id": task_id, "error": "Task disappeared from background tasks"}) + cleanup_background_task(task_id) + return f"Error: Task {task_id} disappeared from background tasks" - # Log status changes for debugging - if result.status != last_status: - logger.info(f"[trace={trace_id}] Task {task_id} status: {result.status.value}") - last_status = result.status + # Log status changes for debugging + if result.status != last_status: + logger.info(f"[trace={trace_id}] Task {task_id} status: {result.status.value}") + last_status = result.status - # Check for new AI messages and send task_running events - current_message_count = len(result.ai_messages) - if current_message_count > last_message_count: - # Send task_running event for each new message - for i in range(last_message_count, current_message_count): - message = result.ai_messages[i] - writer( - { - "type": "task_running", - "task_id": task_id, - "message": message, - "message_index": i + 1, # 1-based index for display - "total_messages": current_message_count, - } - ) - logger.info(f"[trace={trace_id}] Task {task_id} sent message #{i + 1}/{current_message_count}") - last_message_count = current_message_count + # Check for new AI messages and send task_running events + current_message_count = len(result.ai_messages) + if current_message_count > last_message_count: + # Send task_running event for each new message + for i in range(last_message_count, current_message_count): + message = result.ai_messages[i] + writer( + { + "type": "task_running", + "task_id": task_id, + "message": message, + "message_index": i + 1, # 1-based index for display + "total_messages": current_message_count, + } + ) + logger.info(f"[trace={trace_id}] Task {task_id} sent message #{i + 1}/{current_message_count}") + last_message_count = current_message_count - # Check if task completed, failed, or timed out - if result.status == SubagentStatus.COMPLETED: - writer({"type": "task_completed", "task_id": task_id, "result": result.result}) - logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls") - cleanup_background_task(task_id) - return f"Task Succeeded. Result: {result.result}" - elif result.status == SubagentStatus.FAILED: - writer({"type": "task_failed", "task_id": task_id, "error": result.error}) - logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}") - cleanup_background_task(task_id) - return f"Task failed. Error: {result.error}" - elif result.status == SubagentStatus.TIMED_OUT: - writer({"type": "task_timed_out", "task_id": task_id, "error": result.error}) - logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}") - cleanup_background_task(task_id) - return f"Task timed out. Error: {result.error}" + # Check if task completed, failed, or timed out + if result.status == SubagentStatus.COMPLETED: + writer({"type": "task_completed", "task_id": task_id, "result": result.result}) + logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls") + cleanup_background_task(task_id) + return f"Task Succeeded. Result: {result.result}" + elif result.status == SubagentStatus.FAILED: + writer({"type": "task_failed", "task_id": task_id, "error": result.error}) + logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}") + cleanup_background_task(task_id) + return f"Task failed. Error: {result.error}" + elif result.status == SubagentStatus.TIMED_OUT: + writer({"type": "task_timed_out", "task_id": task_id, "error": result.error}) + logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}") + cleanup_background_task(task_id) + return f"Task timed out. Error: {result.error}" - # Still running, wait before next poll - time.sleep(5) # Poll every 5 seconds - poll_count += 1 + # Still running, wait before next poll + await asyncio.sleep(5) + poll_count += 1 - # Polling timeout as a safety net (in case thread pool timeout doesn't work) - # Set to execution timeout + 60s buffer, in 5s poll intervals - # This catches edge cases where the background task gets stuck - # Note: We don't call cleanup_background_task here because the task may - # still be running in the background. The cleanup will happen when the - # executor completes and sets a terminal status. - if poll_count > max_poll_count: - timeout_minutes = config.timeout_seconds // 60 - logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)") - writer({"type": "task_timed_out", "task_id": task_id}) - return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}" + # Polling timeout as a safety net (in case thread pool timeout doesn't work) + # Set to execution timeout + 60s buffer, in 5s poll intervals + # This catches edge cases where the background task gets stuck + # Note: We don't call cleanup_background_task here because the task may + # still be running in the background. The cleanup will happen when the + # executor completes and sets a terminal status. + if poll_count > max_poll_count: + timeout_minutes = config.timeout_seconds // 60 + logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)") + writer({"type": "task_timed_out", "task_id": task_id}) + return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}" + except asyncio.CancelledError: + async def cleanup_when_done() -> None: + max_cleanup_polls = max_poll_count + cleanup_poll_count = 0 + + while True: + result = get_background_task_result(task_id) + if result is None: + return + + if result.status in {SubagentStatus.COMPLETED, SubagentStatus.FAILED, SubagentStatus.TIMED_OUT} or getattr(result, "completed_at", None) is not None: + cleanup_background_task(task_id) + return + + if cleanup_poll_count > max_cleanup_polls: + logger.warning( + f"[trace={trace_id}] Deferred cleanup for task {task_id} timed out after {cleanup_poll_count} polls" + ) + return + + await asyncio.sleep(5) + cleanup_poll_count += 1 + + def log_cleanup_failure(cleanup_task: asyncio.Task[None]) -> None: + if cleanup_task.cancelled(): + return + + exc = cleanup_task.exception() + if exc is not None: + logger.error(f"[trace={trace_id}] Deferred cleanup failed for task {task_id}: {exc}") + + logger.debug(f"[trace={trace_id}] Scheduling deferred cleanup for cancelled task {task_id}") + asyncio.create_task(cleanup_when_done()).add_done_callback(log_cleanup_failure) + raise diff --git a/backend/tests/test_task_tool_core_logic.py b/backend/tests/test_task_tool_core_logic.py index 19d1b82..bae7da1 100644 --- a/backend/tests/test_task_tool_core_logic.py +++ b/backend/tests/test_task_tool_core_logic.py @@ -1,10 +1,13 @@ """Core behavior tests for task tool orchestration.""" +import asyncio import importlib from enum import Enum from types import SimpleNamespace from unittest.mock import MagicMock +import pytest + from deerflow.subagents.config import SubagentConfig # Use module import so tests can patch the exact symbols referenced inside task_tool(). @@ -61,10 +64,23 @@ def _make_result( ) +def _run_task_tool(**kwargs): + return asyncio.run(task_tool_module.task_tool.coroutine(**kwargs)) + + +async def _no_sleep(_: float) -> None: + return None + + +class _DummyScheduledTask: + def add_done_callback(self, _callback): + return None + + def test_task_tool_returns_error_for_unknown_subagent(monkeypatch): monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: None) - result = task_tool_module.task_tool.func( + result = _run_task_tool( runtime=None, description="执行任务", prompt="do work", @@ -109,11 +125,11 @@ def test_task_tool_emits_running_and_completed_events(monkeypatch): monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "Skills Appendix") monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: next(responses)) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) - monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep) # task_tool lazily imports from deerflow.tools at call time, so patch that module-level function. monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools) - output = task_tool_module.task_tool.func( + output = _run_task_tool( runtime=runtime, description="运行子任务", prompt="collect diagnostics", @@ -155,10 +171,10 @@ def test_task_tool_returns_failed_message(monkeypatch): lambda _: _make_result(FakeSubagentStatus.FAILED, error="subagent crashed"), ) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) - monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep) monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) - output = task_tool_module.task_tool.func( + output = _run_task_tool( runtime=_make_runtime(), description="执行任务", prompt="do fail", @@ -189,10 +205,10 @@ def test_task_tool_returns_timed_out_message(monkeypatch): lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"), ) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) - monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep) monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) - output = task_tool_module.task_tool.func( + output = _run_task_tool( runtime=_make_runtime(), description="执行任务", prompt="do timeout", @@ -225,10 +241,10 @@ def test_task_tool_polling_safety_timeout(monkeypatch): lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]), ) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) - monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep) monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) - output = task_tool_module.task_tool.func( + output = _run_task_tool( runtime=_make_runtime(), description="执行任务", prompt="never finish", @@ -261,7 +277,7 @@ def test_cleanup_called_on_completed(monkeypatch): lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"), ) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) - monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep) monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) monkeypatch.setattr( task_tool_module, @@ -269,7 +285,7 @@ def test_cleanup_called_on_completed(monkeypatch): lambda task_id: cleanup_calls.append(task_id), ) - output = task_tool_module.task_tool.func( + output = _run_task_tool( runtime=_make_runtime(), description="执行任务", prompt="complete task", @@ -301,7 +317,7 @@ def test_cleanup_called_on_failed(monkeypatch): lambda _: _make_result(FakeSubagentStatus.FAILED, error="error"), ) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) - monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep) monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) monkeypatch.setattr( task_tool_module, @@ -309,7 +325,7 @@ def test_cleanup_called_on_failed(monkeypatch): lambda task_id: cleanup_calls.append(task_id), ) - output = task_tool_module.task_tool.func( + output = _run_task_tool( runtime=_make_runtime(), description="执行任务", prompt="fail task", @@ -341,7 +357,7 @@ def test_cleanup_called_on_timed_out(monkeypatch): lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"), ) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) - monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep) monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) monkeypatch.setattr( task_tool_module, @@ -349,7 +365,7 @@ def test_cleanup_called_on_timed_out(monkeypatch): lambda task_id: cleanup_calls.append(task_id), ) - output = task_tool_module.task_tool.func( + output = _run_task_tool( runtime=_make_runtime(), description="执行任务", prompt="timeout task", @@ -388,7 +404,7 @@ def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch): lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]), ) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) - monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep) monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) monkeypatch.setattr( task_tool_module, @@ -396,7 +412,7 @@ def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch): lambda task_id: cleanup_calls.append(task_id), ) - output = task_tool_module.task_tool.func( + output = _run_task_tool( runtime=_make_runtime(), description="执行任务", prompt="never finish", @@ -407,3 +423,117 @@ def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch): assert output.startswith("Task polling timed out after 0 minutes") # cleanup should NOT be called because the task is still RUNNING assert cleanup_calls == [] + + +def test_cleanup_scheduled_on_cancellation(monkeypatch): + """Verify cancellation schedules deferred cleanup for the background task.""" + config = _make_subagent_config() + events = [] + cleanup_calls = [] + scheduled_cleanup_coros = [] + poll_count = 0 + + def get_result(_: str): + nonlocal poll_count + poll_count += 1 + if poll_count == 1: + return _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]) + return _make_result(FakeSubagentStatus.COMPLETED, result="done") + + async def cancel_on_first_sleep(_: float) -> None: + raise asyncio.CancelledError + + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr( + task_tool_module, + "SubagentExecutor", + type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}), + ) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config) + monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "") + monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep) + monkeypatch.setattr( + task_tool_module.asyncio, + "create_task", + lambda coro: scheduled_cleanup_coros.append(coro) or _DummyScheduledTask(), + ) + monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) + monkeypatch.setattr( + task_tool_module, + "cleanup_background_task", + lambda task_id: cleanup_calls.append(task_id), + ) + + with pytest.raises(asyncio.CancelledError): + _run_task_tool( + runtime=_make_runtime(), + description="执行任务", + prompt="cancel task", + subagent_type="general-purpose", + tool_call_id="tc-cancelled-cleanup", + ) + + assert cleanup_calls == [] + assert len(scheduled_cleanup_coros) == 1 + + asyncio.run(scheduled_cleanup_coros.pop()) + + assert cleanup_calls == ["tc-cancelled-cleanup"] + + +def test_cancelled_cleanup_stops_after_timeout(monkeypatch): + """Verify deferred cleanup gives up after a bounded number of polls.""" + config = _make_subagent_config() + config.timeout_seconds = 1 + events = [] + cleanup_calls = [] + scheduled_cleanup_coros = [] + + async def cancel_on_first_sleep(_: float) -> None: + raise asyncio.CancelledError + + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr( + task_tool_module, + "SubagentExecutor", + type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}), + ) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config) + monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "") + monkeypatch.setattr( + task_tool_module, + "get_background_task_result", + lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]), + ) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep) + monkeypatch.setattr( + task_tool_module.asyncio, + "create_task", + lambda coro: scheduled_cleanup_coros.append(coro) or _DummyScheduledTask(), + ) + monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) + monkeypatch.setattr( + task_tool_module, + "cleanup_background_task", + lambda task_id: cleanup_calls.append(task_id), + ) + + with pytest.raises(asyncio.CancelledError): + _run_task_tool( + runtime=_make_runtime(), + description="执行任务", + prompt="cancel task", + subagent_type="general-purpose", + tool_call_id="tc-cancelled-timeout", + ) + + async def bounded_sleep(_seconds: float) -> None: + return None + + monkeypatch.setattr(task_tool_module.asyncio, "sleep", bounded_sleep) + asyncio.run(scheduled_cleanup_coros.pop()) + + assert cleanup_calls == []