mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-02 22:02:13 +08:00
fix(task): avoid blocking in task tool polling (#1320)
* fix: avoid blocking in task tool polling * test: adapt task tool polling tests for async tool * fix: clean up cancelled task tool polling --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
"""Task tool for delegating work to subagents."""
|
"""Task tool for delegating work to subagents."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
@@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
@tool("task", parse_docstring=True)
|
@tool("task", parse_docstring=True)
|
||||||
def task_tool(
|
async def task_tool(
|
||||||
runtime: ToolRuntime[ContextT, ThreadState],
|
runtime: ToolRuntime[ContextT, ThreadState],
|
||||||
description: str,
|
description: str,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@@ -129,67 +129,102 @@ def task_tool(
|
|||||||
# Send Task Started message'
|
# Send Task Started message'
|
||||||
writer({"type": "task_started", "task_id": task_id, "description": description})
|
writer({"type": "task_started", "task_id": task_id, "description": description})
|
||||||
|
|
||||||
while True:
|
try:
|
||||||
result = get_background_task_result(task_id)
|
while True:
|
||||||
|
result = get_background_task_result(task_id)
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
logger.error(f"[trace={trace_id}] Task {task_id} not found in background tasks")
|
logger.error(f"[trace={trace_id}] Task {task_id} not found in background tasks")
|
||||||
writer({"type": "task_failed", "task_id": task_id, "error": "Task disappeared from background tasks"})
|
writer({"type": "task_failed", "task_id": task_id, "error": "Task disappeared from background tasks"})
|
||||||
cleanup_background_task(task_id)
|
cleanup_background_task(task_id)
|
||||||
return f"Error: Task {task_id} disappeared from background tasks"
|
return f"Error: Task {task_id} disappeared from background tasks"
|
||||||
|
|
||||||
# Log status changes for debugging
|
# Log status changes for debugging
|
||||||
if result.status != last_status:
|
if result.status != last_status:
|
||||||
logger.info(f"[trace={trace_id}] Task {task_id} status: {result.status.value}")
|
logger.info(f"[trace={trace_id}] Task {task_id} status: {result.status.value}")
|
||||||
last_status = result.status
|
last_status = result.status
|
||||||
|
|
||||||
# Check for new AI messages and send task_running events
|
# Check for new AI messages and send task_running events
|
||||||
current_message_count = len(result.ai_messages)
|
current_message_count = len(result.ai_messages)
|
||||||
if current_message_count > last_message_count:
|
if current_message_count > last_message_count:
|
||||||
# Send task_running event for each new message
|
# Send task_running event for each new message
|
||||||
for i in range(last_message_count, current_message_count):
|
for i in range(last_message_count, current_message_count):
|
||||||
message = result.ai_messages[i]
|
message = result.ai_messages[i]
|
||||||
writer(
|
writer(
|
||||||
{
|
{
|
||||||
"type": "task_running",
|
"type": "task_running",
|
||||||
"task_id": task_id,
|
"task_id": task_id,
|
||||||
"message": message,
|
"message": message,
|
||||||
"message_index": i + 1, # 1-based index for display
|
"message_index": i + 1, # 1-based index for display
|
||||||
"total_messages": current_message_count,
|
"total_messages": current_message_count,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
logger.info(f"[trace={trace_id}] Task {task_id} sent message #{i + 1}/{current_message_count}")
|
logger.info(f"[trace={trace_id}] Task {task_id} sent message #{i + 1}/{current_message_count}")
|
||||||
last_message_count = current_message_count
|
last_message_count = current_message_count
|
||||||
|
|
||||||
# Check if task completed, failed, or timed out
|
# Check if task completed, failed, or timed out
|
||||||
if result.status == SubagentStatus.COMPLETED:
|
if result.status == SubagentStatus.COMPLETED:
|
||||||
writer({"type": "task_completed", "task_id": task_id, "result": result.result})
|
writer({"type": "task_completed", "task_id": task_id, "result": result.result})
|
||||||
logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls")
|
logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls")
|
||||||
cleanup_background_task(task_id)
|
cleanup_background_task(task_id)
|
||||||
return f"Task Succeeded. Result: {result.result}"
|
return f"Task Succeeded. Result: {result.result}"
|
||||||
elif result.status == SubagentStatus.FAILED:
|
elif result.status == SubagentStatus.FAILED:
|
||||||
writer({"type": "task_failed", "task_id": task_id, "error": result.error})
|
writer({"type": "task_failed", "task_id": task_id, "error": result.error})
|
||||||
logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
|
logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
|
||||||
cleanup_background_task(task_id)
|
cleanup_background_task(task_id)
|
||||||
return f"Task failed. Error: {result.error}"
|
return f"Task failed. Error: {result.error}"
|
||||||
elif result.status == SubagentStatus.TIMED_OUT:
|
elif result.status == SubagentStatus.TIMED_OUT:
|
||||||
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error})
|
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error})
|
||||||
logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
|
logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
|
||||||
cleanup_background_task(task_id)
|
cleanup_background_task(task_id)
|
||||||
return f"Task timed out. Error: {result.error}"
|
return f"Task timed out. Error: {result.error}"
|
||||||
|
|
||||||
# Still running, wait before next poll
|
# Still running, wait before next poll
|
||||||
time.sleep(5) # Poll every 5 seconds
|
await asyncio.sleep(5)
|
||||||
poll_count += 1
|
poll_count += 1
|
||||||
|
|
||||||
# Polling timeout as a safety net (in case thread pool timeout doesn't work)
|
# Polling timeout as a safety net (in case thread pool timeout doesn't work)
|
||||||
# Set to execution timeout + 60s buffer, in 5s poll intervals
|
# Set to execution timeout + 60s buffer, in 5s poll intervals
|
||||||
# This catches edge cases where the background task gets stuck
|
# This catches edge cases where the background task gets stuck
|
||||||
# Note: We don't call cleanup_background_task here because the task may
|
# Note: We don't call cleanup_background_task here because the task may
|
||||||
# still be running in the background. The cleanup will happen when the
|
# still be running in the background. The cleanup will happen when the
|
||||||
# executor completes and sets a terminal status.
|
# executor completes and sets a terminal status.
|
||||||
if poll_count > max_poll_count:
|
if poll_count > max_poll_count:
|
||||||
timeout_minutes = config.timeout_seconds // 60
|
timeout_minutes = config.timeout_seconds // 60
|
||||||
logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)")
|
logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)")
|
||||||
writer({"type": "task_timed_out", "task_id": task_id})
|
writer({"type": "task_timed_out", "task_id": task_id})
|
||||||
return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
|
return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
async def cleanup_when_done() -> None:
|
||||||
|
max_cleanup_polls = max_poll_count
|
||||||
|
cleanup_poll_count = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
result = get_background_task_result(task_id)
|
||||||
|
if result is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if result.status in {SubagentStatus.COMPLETED, SubagentStatus.FAILED, SubagentStatus.TIMED_OUT} or getattr(result, "completed_at", None) is not None:
|
||||||
|
cleanup_background_task(task_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
if cleanup_poll_count > max_cleanup_polls:
|
||||||
|
logger.warning(
|
||||||
|
f"[trace={trace_id}] Deferred cleanup for task {task_id} timed out after {cleanup_poll_count} polls"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
cleanup_poll_count += 1
|
||||||
|
|
||||||
|
def log_cleanup_failure(cleanup_task: asyncio.Task[None]) -> None:
|
||||||
|
if cleanup_task.cancelled():
|
||||||
|
return
|
||||||
|
|
||||||
|
exc = cleanup_task.exception()
|
||||||
|
if exc is not None:
|
||||||
|
logger.error(f"[trace={trace_id}] Deferred cleanup failed for task {task_id}: {exc}")
|
||||||
|
|
||||||
|
logger.debug(f"[trace={trace_id}] Scheduling deferred cleanup for cancelled task {task_id}")
|
||||||
|
asyncio.create_task(cleanup_when_done()).add_done_callback(log_cleanup_failure)
|
||||||
|
raise
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
"""Core behavior tests for task tool orchestration."""
|
"""Core behavior tests for task tool orchestration."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import importlib
|
import importlib
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from deerflow.subagents.config import SubagentConfig
|
from deerflow.subagents.config import SubagentConfig
|
||||||
|
|
||||||
# Use module import so tests can patch the exact symbols referenced inside task_tool().
|
# Use module import so tests can patch the exact symbols referenced inside task_tool().
|
||||||
@@ -61,10 +64,23 @@ def _make_result(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_task_tool(**kwargs):
|
||||||
|
return asyncio.run(task_tool_module.task_tool.coroutine(**kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
async def _no_sleep(_: float) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyScheduledTask:
|
||||||
|
def add_done_callback(self, _callback):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def test_task_tool_returns_error_for_unknown_subagent(monkeypatch):
|
def test_task_tool_returns_error_for_unknown_subagent(monkeypatch):
|
||||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: None)
|
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: None)
|
||||||
|
|
||||||
result = task_tool_module.task_tool.func(
|
result = _run_task_tool(
|
||||||
runtime=None,
|
runtime=None,
|
||||||
description="执行任务",
|
description="执行任务",
|
||||||
prompt="do work",
|
prompt="do work",
|
||||||
@@ -109,11 +125,11 @@ def test_task_tool_emits_running_and_completed_events(monkeypatch):
|
|||||||
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "Skills Appendix")
|
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "Skills Appendix")
|
||||||
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: next(responses))
|
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: next(responses))
|
||||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||||
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None)
|
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||||
# task_tool lazily imports from deerflow.tools at call time, so patch that module-level function.
|
# task_tool lazily imports from deerflow.tools at call time, so patch that module-level function.
|
||||||
monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools)
|
monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools)
|
||||||
|
|
||||||
output = task_tool_module.task_tool.func(
|
output = _run_task_tool(
|
||||||
runtime=runtime,
|
runtime=runtime,
|
||||||
description="运行子任务",
|
description="运行子任务",
|
||||||
prompt="collect diagnostics",
|
prompt="collect diagnostics",
|
||||||
@@ -155,10 +171,10 @@ def test_task_tool_returns_failed_message(monkeypatch):
|
|||||||
lambda _: _make_result(FakeSubagentStatus.FAILED, error="subagent crashed"),
|
lambda _: _make_result(FakeSubagentStatus.FAILED, error="subagent crashed"),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||||
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None)
|
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
||||||
|
|
||||||
output = task_tool_module.task_tool.func(
|
output = _run_task_tool(
|
||||||
runtime=_make_runtime(),
|
runtime=_make_runtime(),
|
||||||
description="执行任务",
|
description="执行任务",
|
||||||
prompt="do fail",
|
prompt="do fail",
|
||||||
@@ -189,10 +205,10 @@ def test_task_tool_returns_timed_out_message(monkeypatch):
|
|||||||
lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
|
lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||||
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None)
|
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
||||||
|
|
||||||
output = task_tool_module.task_tool.func(
|
output = _run_task_tool(
|
||||||
runtime=_make_runtime(),
|
runtime=_make_runtime(),
|
||||||
description="执行任务",
|
description="执行任务",
|
||||||
prompt="do timeout",
|
prompt="do timeout",
|
||||||
@@ -225,10 +241,10 @@ def test_task_tool_polling_safety_timeout(monkeypatch):
|
|||||||
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||||
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None)
|
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
||||||
|
|
||||||
output = task_tool_module.task_tool.func(
|
output = _run_task_tool(
|
||||||
runtime=_make_runtime(),
|
runtime=_make_runtime(),
|
||||||
description="执行任务",
|
description="执行任务",
|
||||||
prompt="never finish",
|
prompt="never finish",
|
||||||
@@ -261,7 +277,7 @@ def test_cleanup_called_on_completed(monkeypatch):
|
|||||||
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
|
lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||||
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None)
|
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
task_tool_module,
|
task_tool_module,
|
||||||
@@ -269,7 +285,7 @@ def test_cleanup_called_on_completed(monkeypatch):
|
|||||||
lambda task_id: cleanup_calls.append(task_id),
|
lambda task_id: cleanup_calls.append(task_id),
|
||||||
)
|
)
|
||||||
|
|
||||||
output = task_tool_module.task_tool.func(
|
output = _run_task_tool(
|
||||||
runtime=_make_runtime(),
|
runtime=_make_runtime(),
|
||||||
description="执行任务",
|
description="执行任务",
|
||||||
prompt="complete task",
|
prompt="complete task",
|
||||||
@@ -301,7 +317,7 @@ def test_cleanup_called_on_failed(monkeypatch):
|
|||||||
lambda _: _make_result(FakeSubagentStatus.FAILED, error="error"),
|
lambda _: _make_result(FakeSubagentStatus.FAILED, error="error"),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||||
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None)
|
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
task_tool_module,
|
task_tool_module,
|
||||||
@@ -309,7 +325,7 @@ def test_cleanup_called_on_failed(monkeypatch):
|
|||||||
lambda task_id: cleanup_calls.append(task_id),
|
lambda task_id: cleanup_calls.append(task_id),
|
||||||
)
|
)
|
||||||
|
|
||||||
output = task_tool_module.task_tool.func(
|
output = _run_task_tool(
|
||||||
runtime=_make_runtime(),
|
runtime=_make_runtime(),
|
||||||
description="执行任务",
|
description="执行任务",
|
||||||
prompt="fail task",
|
prompt="fail task",
|
||||||
@@ -341,7 +357,7 @@ def test_cleanup_called_on_timed_out(monkeypatch):
|
|||||||
lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
|
lambda _: _make_result(FakeSubagentStatus.TIMED_OUT, error="timeout"),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||||
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None)
|
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
task_tool_module,
|
task_tool_module,
|
||||||
@@ -349,7 +365,7 @@ def test_cleanup_called_on_timed_out(monkeypatch):
|
|||||||
lambda task_id: cleanup_calls.append(task_id),
|
lambda task_id: cleanup_calls.append(task_id),
|
||||||
)
|
)
|
||||||
|
|
||||||
output = task_tool_module.task_tool.func(
|
output = _run_task_tool(
|
||||||
runtime=_make_runtime(),
|
runtime=_make_runtime(),
|
||||||
description="执行任务",
|
description="执行任务",
|
||||||
prompt="timeout task",
|
prompt="timeout task",
|
||||||
@@ -388,7 +404,7 @@ def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch):
|
|||||||
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||||
monkeypatch.setattr(task_tool_module.time, "sleep", lambda _: None)
|
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
|
||||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
task_tool_module,
|
task_tool_module,
|
||||||
@@ -396,7 +412,7 @@ def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch):
|
|||||||
lambda task_id: cleanup_calls.append(task_id),
|
lambda task_id: cleanup_calls.append(task_id),
|
||||||
)
|
)
|
||||||
|
|
||||||
output = task_tool_module.task_tool.func(
|
output = _run_task_tool(
|
||||||
runtime=_make_runtime(),
|
runtime=_make_runtime(),
|
||||||
description="执行任务",
|
description="执行任务",
|
||||||
prompt="never finish",
|
prompt="never finish",
|
||||||
@@ -407,3 +423,117 @@ def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch):
|
|||||||
assert output.startswith("Task polling timed out after 0 minutes")
|
assert output.startswith("Task polling timed out after 0 minutes")
|
||||||
# cleanup should NOT be called because the task is still RUNNING
|
# cleanup should NOT be called because the task is still RUNNING
|
||||||
assert cleanup_calls == []
|
assert cleanup_calls == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_cleanup_scheduled_on_cancellation(monkeypatch):
|
||||||
|
"""Verify cancellation schedules deferred cleanup for the background task."""
|
||||||
|
config = _make_subagent_config()
|
||||||
|
events = []
|
||||||
|
cleanup_calls = []
|
||||||
|
scheduled_cleanup_coros = []
|
||||||
|
poll_count = 0
|
||||||
|
|
||||||
|
def get_result(_: str):
|
||||||
|
nonlocal poll_count
|
||||||
|
poll_count += 1
|
||||||
|
if poll_count == 1:
|
||||||
|
return _make_result(FakeSubagentStatus.RUNNING, ai_messages=[])
|
||||||
|
return _make_result(FakeSubagentStatus.COMPLETED, result="done")
|
||||||
|
|
||||||
|
async def cancel_on_first_sleep(_: float) -> None:
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
|
||||||
|
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
task_tool_module,
|
||||||
|
"SubagentExecutor",
|
||||||
|
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result)
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||||
|
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
task_tool_module.asyncio,
|
||||||
|
"create_task",
|
||||||
|
lambda coro: scheduled_cleanup_coros.append(coro) or _DummyScheduledTask(),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
||||||
|
monkeypatch.setattr(
|
||||||
|
task_tool_module,
|
||||||
|
"cleanup_background_task",
|
||||||
|
lambda task_id: cleanup_calls.append(task_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
_run_task_tool(
|
||||||
|
runtime=_make_runtime(),
|
||||||
|
description="执行任务",
|
||||||
|
prompt="cancel task",
|
||||||
|
subagent_type="general-purpose",
|
||||||
|
tool_call_id="tc-cancelled-cleanup",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert cleanup_calls == []
|
||||||
|
assert len(scheduled_cleanup_coros) == 1
|
||||||
|
|
||||||
|
asyncio.run(scheduled_cleanup_coros.pop())
|
||||||
|
|
||||||
|
assert cleanup_calls == ["tc-cancelled-cleanup"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancelled_cleanup_stops_after_timeout(monkeypatch):
|
||||||
|
"""Verify deferred cleanup gives up after a bounded number of polls."""
|
||||||
|
config = _make_subagent_config()
|
||||||
|
config.timeout_seconds = 1
|
||||||
|
events = []
|
||||||
|
cleanup_calls = []
|
||||||
|
scheduled_cleanup_coros = []
|
||||||
|
|
||||||
|
async def cancel_on_first_sleep(_: float) -> None:
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
|
||||||
|
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
task_tool_module,
|
||||||
|
"SubagentExecutor",
|
||||||
|
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_skills_prompt_section", lambda: "")
|
||||||
|
monkeypatch.setattr(
|
||||||
|
task_tool_module,
|
||||||
|
"get_background_task_result",
|
||||||
|
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||||
|
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
task_tool_module.asyncio,
|
||||||
|
"create_task",
|
||||||
|
lambda coro: scheduled_cleanup_coros.append(coro) or _DummyScheduledTask(),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
||||||
|
monkeypatch.setattr(
|
||||||
|
task_tool_module,
|
||||||
|
"cleanup_background_task",
|
||||||
|
lambda task_id: cleanup_calls.append(task_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
_run_task_tool(
|
||||||
|
runtime=_make_runtime(),
|
||||||
|
description="执行任务",
|
||||||
|
prompt="cancel task",
|
||||||
|
subagent_type="general-purpose",
|
||||||
|
tool_call_id="tc-cancelled-timeout",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def bounded_sleep(_seconds: float) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
monkeypatch.setattr(task_tool_module.asyncio, "sleep", bounded_sleep)
|
||||||
|
asyncio.run(scheduled_cleanup_coros.pop())
|
||||||
|
|
||||||
|
assert cleanup_calls == []
|
||||||
|
|||||||
Reference in New Issue
Block a user