diff --git a/src/server/app.py b/src/server/app.py index 2d24a2a..8fb908d 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -4,7 +4,7 @@ import base64 import json import logging -from typing import Annotated, List, cast +from typing import Annotated, Any, List, cast from uuid import uuid4 from fastapi import FastAPI, HTTPException, Query @@ -244,25 +244,36 @@ async def _stream_graph_events( graph_instance, workflow_input, workflow_config, thread_id ): """Stream events from the graph and process them.""" - async for agent, _, event_data in graph_instance.astream( - workflow_input, - config=workflow_config, - stream_mode=["messages", "updates"], - subgraphs=True, - ): - if isinstance(event_data, dict): - if "__interrupt__" in event_data: - yield _create_interrupt_event(thread_id, event_data) - continue + try: + async for agent, _, event_data in graph_instance.astream( + workflow_input, + config=workflow_config, + stream_mode=["messages", "updates"], + subgraphs=True, + ): + if isinstance(event_data, dict): + if "__interrupt__" in event_data: + yield _create_interrupt_event(thread_id, event_data) + continue - message_chunk, message_metadata = cast( - tuple[BaseMessage, dict[str, any]], event_data + message_chunk, message_metadata = cast( + tuple[BaseMessage, dict[str, Any]], event_data + ) + + async for event in _process_message_chunk( + message_chunk, message_metadata, thread_id, agent + ): + yield event + except Exception as e: + logger.exception("Error during graph execution") + yield _make_event( + "error", + { + "thread_id": thread_id, + "error": str(e), + }, ) - async for event in _process_message_chunk( - message_chunk, message_metadata, thread_id, agent - ): - yield event async def _astream_workflow_generator(