mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-24 22:54:46 +08:00
* fix(server): graceful stream termination on cancellation (issue #847) * Update the code with review suggestion
This commit is contained in:
@@ -741,10 +741,19 @@ async def _stream_graph_events(
|
|||||||
|
|
||||||
logger.debug(f"[{safe_thread_id}] Graph event stream completed. Total events: {event_count}")
|
logger.debug(f"[{safe_thread_id}] Graph event stream completed. Total events: {event_count}")
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
# User cancelled/interrupted the stream - this is normal, not an error
|
# User cancelled/interrupted the stream - this is normal, not an error.
|
||||||
|
# Do not re-raise: ending the generator gracefully lets FastAPI close the
|
||||||
|
# HTTP response properly so the client won't see "error decoding response body".
|
||||||
logger.info(f"[{safe_thread_id}] Graph event stream cancelled by user after {event_count} events")
|
logger.info(f"[{safe_thread_id}] Graph event stream cancelled by user after {event_count} events")
|
||||||
# Re-raise to signal cancellation properly without yielding an error event
|
try:
|
||||||
raise
|
yield _make_event("error", {
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"error": "Stream cancelled",
|
||||||
|
"reason": "cancelled",
|
||||||
|
})
|
||||||
|
except Exception:
|
||||||
|
pass # Client likely already disconnected
|
||||||
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"[{safe_thread_id}] Error during graph execution")
|
logger.exception(f"[{safe_thread_id}] Error during graph execution")
|
||||||
yield _make_event(
|
yield _make_event(
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
# SPDX-License-Identifier: MIT
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import os
|
import os
|
||||||
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
|
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
|
||||||
@@ -17,6 +18,7 @@ from src.server.app import (
|
|||||||
_astream_workflow_generator,
|
_astream_workflow_generator,
|
||||||
_create_interrupt_event,
|
_create_interrupt_event,
|
||||||
_make_event,
|
_make_event,
|
||||||
|
_stream_graph_events,
|
||||||
app,
|
app,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -53,6 +55,57 @@ class TestMakeEvent:
|
|||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
class TestStreamGraphEventsCancellation:
|
||||||
|
"""Tests for graceful handling of asyncio.CancelledError in _stream_graph_events."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancelled_error_does_not_propagate(self):
|
||||||
|
"""When the stream is cancelled, the generator should end gracefully
|
||||||
|
instead of re-raising CancelledError (fixes issue #847)."""
|
||||||
|
|
||||||
|
async def _mock_astream(*args, **kwargs):
|
||||||
|
yield ("agent", None, {"some": "data"})
|
||||||
|
raise asyncio.CancelledError()
|
||||||
|
|
||||||
|
graph = MagicMock()
|
||||||
|
graph.astream = _mock_astream
|
||||||
|
|
||||||
|
events = []
|
||||||
|
# The generator must NOT raise CancelledError
|
||||||
|
async for event in _stream_graph_events(
|
||||||
|
graph, {"input": "test"}, {}, "test-thread-id"
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
# It should have yielded a final error event with reason='cancelled'
|
||||||
|
final_events_with_cancelled = [
|
||||||
|
e for e in events if '"reason": "cancelled"' in e
|
||||||
|
]
|
||||||
|
assert len(final_events_with_cancelled) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancelled_error_yields_cancelled_reason(self):
|
||||||
|
"""The final event should carry reason='cancelled' so the client
|
||||||
|
can distinguish cancellation from real errors."""
|
||||||
|
|
||||||
|
async def _mock_astream(*args, **kwargs):
|
||||||
|
raise asyncio.CancelledError()
|
||||||
|
yield # make this an async generator # noqa: E501
|
||||||
|
|
||||||
|
graph = MagicMock()
|
||||||
|
graph.astream = _mock_astream
|
||||||
|
|
||||||
|
events = []
|
||||||
|
async for event in _stream_graph_events(
|
||||||
|
graph, {"input": "test"}, {}, "test-thread-id"
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
assert len(events) == 1
|
||||||
|
assert '"reason": "cancelled"' in events[0]
|
||||||
|
assert '"error": "Stream cancelled"' in events[0]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_astream_workflow_generator_preserves_clarification_history():
|
async def test_astream_workflow_generator_preserves_clarification_history():
|
||||||
messages = [
|
messages = [
|
||||||
|
|||||||
@@ -84,6 +84,7 @@ export async function* chatStream(
|
|||||||
});
|
});
|
||||||
|
|
||||||
for await (const event of stream) {
|
for await (const event of stream) {
|
||||||
|
if (event.data == null) continue;
|
||||||
yield {
|
yield {
|
||||||
type: event.event,
|
type: event.event,
|
||||||
data: JSON.parse(event.data),
|
data: JSON.parse(event.data),
|
||||||
|
|||||||
@@ -84,10 +84,20 @@ export interface CitationsEvent {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface ErrorEvent {
|
||||||
|
type: "error";
|
||||||
|
data: {
|
||||||
|
thread_id: string;
|
||||||
|
error: string;
|
||||||
|
reason?: "cancelled" | string;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
export type ChatEvent =
|
export type ChatEvent =
|
||||||
| MessageChunkEvent
|
| MessageChunkEvent
|
||||||
| ToolCallsEvent
|
| ToolCallsEvent
|
||||||
| ToolCallChunksEvent
|
| ToolCallChunksEvent
|
||||||
| ToolCallResultEvent
|
| ToolCallResultEvent
|
||||||
| InterruptEvent
|
| InterruptEvent
|
||||||
| CitationsEvent;
|
| CitationsEvent
|
||||||
|
| ErrorEvent;
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ export function mergeMessage(message: Message, event: ChatEvent) {
|
|||||||
} else if (event.type === "interrupt") {
|
} else if (event.type === "interrupt") {
|
||||||
mergeInterruptMessage(message, event);
|
mergeInterruptMessage(message, event);
|
||||||
}
|
}
|
||||||
if (event.type !== "citations" && event.data.finish_reason) {
|
if (event.type !== "citations" && event.type !== "error" && event.data.finish_reason) {
|
||||||
message.finishReason = event.data.finish_reason;
|
message.finishReason = event.data.finish_reason;
|
||||||
message.isStreaming = false;
|
message.isStreaming = false;
|
||||||
if (message.toolCalls) {
|
if (message.toolCalls) {
|
||||||
|
|||||||
@@ -156,6 +156,13 @@ export async function sendMessage(
|
|||||||
const { type, data } = event;
|
const { type, data } = event;
|
||||||
let message: Message | undefined;
|
let message: Message | undefined;
|
||||||
|
|
||||||
|
if (type === "error") {
|
||||||
|
// Server sent an error event - check if it's user cancellation
|
||||||
|
if (data.reason !== "cancelled") {
|
||||||
|
toast(data.error || "An error occurred while generating the response.");
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
// Handle citations event: store citations for the current research
|
// Handle citations event: store citations for the current research
|
||||||
if (type === "citations") {
|
if (type === "citations") {
|
||||||
const ongoingResearchId = useStore.getState().ongoingResearchId;
|
const ongoingResearchId = useStore.getState().ongoingResearchId;
|
||||||
@@ -207,10 +214,12 @@ export async function sendMessage(
|
|||||||
scheduleUpdate();
|
scheduleUpdate();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch {
|
} catch (error) {
|
||||||
toast("An error occurred while generating the response. Please try again.");
|
const isAborted = (error as Error).name === "AbortError";
|
||||||
|
if (!isAborted) {
|
||||||
|
toast("An error occurred while generating the response. Please try again.");
|
||||||
|
}
|
||||||
// Update message status.
|
// Update message status.
|
||||||
// TODO: const isAborted = (error as Error).name === "AbortError";
|
|
||||||
if (messageId != null) {
|
if (messageId != null) {
|
||||||
const message = getMessage(messageId);
|
const message = getMessage(messageId);
|
||||||
if (message?.isStreaming) {
|
if (message?.isStreaming) {
|
||||||
|
|||||||
Reference in New Issue
Block a user