diff --git a/src/server/app.py b/src/server/app.py index 110d5a3..9181079 100644 --- a/src/server/app.py +++ b/src/server/app.py @@ -308,13 +308,16 @@ def _create_event_stream_message( def _create_interrupt_event(thread_id, event_data): """Create interrupt event.""" + interrupt = event_data["__interrupt__"][0] + # Use the 'id' attribute (LangGraph 1.0+) instead of deprecated 'ns[0]' + interrupt_id = getattr(interrupt, "id", None) or thread_id return _make_event( "interrupt", { "thread_id": thread_id, - "id": event_data["__interrupt__"][0].ns[0], + "id": interrupt_id, "role": "assistant", - "content": event_data["__interrupt__"][0].value, + "content": interrupt.value, "finish_reason": "interrupt", "options": [ {"text": "Edit plan", "value": "edit_plan"}, @@ -461,7 +464,7 @@ async def _stream_graph_events( if "__interrupt__" in event_data: logger.debug( f"[{safe_thread_id}] Processing interrupt event: " - f"ns={getattr(event_data['__interrupt__'][0], 'ns', 'unknown') if isinstance(event_data['__interrupt__'], (list, tuple)) and len(event_data['__interrupt__']) > 0 else 'unknown'}, " + f"id={getattr(event_data['__interrupt__'][0], 'id', 'unknown') if isinstance(event_data['__interrupt__'], (list, tuple)) and len(event_data['__interrupt__']) > 0 else 'unknown'}, " f"value_len={len(getattr(event_data['__interrupt__'][0], 'value', '')) if isinstance(event_data['__interrupt__'], (list, tuple)) and len(event_data['__interrupt__']) > 0 and hasattr(event_data['__interrupt__'][0], 'value') and hasattr(event_data['__interrupt__'][0].value, '__len__') else 'unknown'}" ) yield _create_interrupt_event(thread_id, event_data) diff --git a/tests/unit/server/test_app.py b/tests/unit/server/test_app.py index 314de0c..3c63103 100644 --- a/tests/unit/server/test_app.py +++ b/tests/unit/server/test_app.py @@ -12,7 +12,12 @@ from langchain_core.messages import AIMessageChunk, ToolMessage from langgraph.types import Command from src.config.report_style import ReportStyle -from src.server.app import _astream_workflow_generator, _make_event, app +from src.server.app import ( + _astream_workflow_generator, + _create_interrupt_event, + _make_event, + app, +) @pytest.fixture @@ -657,9 +662,9 @@ class TestAstreamWorkflowGenerator: @pytest.mark.asyncio @patch("src.server.app.graph") async def test_astream_workflow_generator_interrupt_event(self, mock_graph): - # Mock interrupt data + # Mock interrupt data with the new 'id' attribute (LangGraph 1.0+) mock_interrupt = MagicMock() - mock_interrupt.ns = ["interrupt_id"] + mock_interrupt.id = "interrupt_id" mock_interrupt.value = "Plan requires approval" interrupt_data = {"__interrupt__": [mock_interrupt]} @@ -920,3 +925,94 @@ class TestGenerateProseEndpoint: response = client.post("/api/prose/generate", json=request_data) assert response.status_code == 500 assert response.json()["detail"] == "Internal Server Error" + + +class TestCreateInterruptEvent: + """Tests for _create_interrupt_event function (Issue #730 fix).""" + + def test_create_interrupt_event_with_id_attribute(self): + """Test that _create_interrupt_event works with LangGraph 1.0+ Interrupt objects that have 'id' attribute.""" + # Create a mock Interrupt object with the new 'id' attribute (LangGraph 1.0+) + mock_interrupt = MagicMock() + mock_interrupt.id = "interrupt-123" + mock_interrupt.value = "Please review the research plan" + + event_data = {"__interrupt__": [mock_interrupt]} + thread_id = "thread-456" + + result = _create_interrupt_event(thread_id, event_data) + + # Verify the result is a properly formatted SSE event + assert "event: interrupt\n" in result + assert '"thread_id": "thread-456"' in result + assert '"id": "interrupt-123"' in result + assert '"content": "Please review the research plan"' in result + assert '"finish_reason": "interrupt"' in result + assert '"role": "assistant"' in result + + def test_create_interrupt_event_fallback_to_thread_id(self): + """Test that _create_interrupt_event falls back to thread_id when 'id' attribute is None.""" + # Create a mock Interrupt object where id is None + mock_interrupt = MagicMock() + mock_interrupt.id = None + mock_interrupt.value = "Plan review needed" + + event_data = {"__interrupt__": [mock_interrupt]} + thread_id = "thread-789" + + result = _create_interrupt_event(thread_id, event_data) + + # Verify it falls back to thread_id + assert '"id": "thread-789"' in result + assert '"thread_id": "thread-789"' in result + assert '"content": "Plan review needed"' in result + + def test_create_interrupt_event_without_id_attribute(self): + """Test that _create_interrupt_event handles objects without 'id' attribute (backward compatibility).""" + # Create a mock object that doesn't have 'id' attribute at all + class MockInterrupt: + pass + mock_interrupt = MockInterrupt() + mock_interrupt.value = "Waiting for approval" + + event_data = {"__interrupt__": [mock_interrupt]} + thread_id = "thread-abc" + + result = _create_interrupt_event(thread_id, event_data) + + # Verify it falls back to thread_id when id attribute doesn't exist + assert '"id": "thread-abc"' in result + assert '"content": "Waiting for approval"' in result + + def test_create_interrupt_event_options(self): + """Test that _create_interrupt_event includes correct options.""" + mock_interrupt = MagicMock() + mock_interrupt.id = "int-001" + mock_interrupt.value = "Review plan" + + event_data = {"__interrupt__": [mock_interrupt]} + thread_id = "thread-xyz" + + result = _create_interrupt_event(thread_id, event_data) + + # Verify options are included + assert '"options":' in result + assert '"text": "Edit plan"' in result + assert '"value": "edit_plan"' in result + assert '"text": "Start research"' in result + assert '"value": "accepted"' in result + + def test_create_interrupt_event_with_complex_value(self): + """Test that _create_interrupt_event handles complex content values.""" + mock_interrupt = MagicMock() + mock_interrupt.id = "int-complex" + mock_interrupt.value = {"plan": "Research AI", "steps": ["step1", "step2"]} + + event_data = {"__interrupt__": [mock_interrupt]} + thread_id = "thread-complex" + + result = _create_interrupt_event(thread_id, event_data) + + # Verify complex value is included (will be serialized as JSON) + assert '"id": "int-complex"' in result + assert "Research AI" in result or "plan" in result