diff --git a/src/podcast/graph/script_writer_node.py b/src/podcast/graph/script_writer_node.py index 9ae1477..d5c0f08 100644 --- a/src/podcast/graph/script_writer_node.py +++ b/src/podcast/graph/script_writer_node.py @@ -1,13 +1,16 @@ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates # SPDX-License-Identifier: MIT +import json import logging +import openai from langchain_core.messages import HumanMessage, SystemMessage from src.config.agents import AGENT_LLM_MAP from src.llms.llm import get_llm_by_type from src.prompts.template import get_prompt_template +from src.utils.json_utils import repair_json_output from ..types import Script from .state import PodcastState @@ -17,14 +20,39 @@ logger = logging.getLogger(__name__) def script_writer_node(state: PodcastState): logger.info("Generating script for podcast...") - model = get_llm_by_type( - AGENT_LLM_MAP["podcast_script_writer"] - ).with_structured_output(Script, method="json_mode") - script = model.invoke( - [ - SystemMessage(content=get_prompt_template("podcast/podcast_script_writer")), - HumanMessage(content=state["input"]), - ], - ) - print(script) + base_model = get_llm_by_type(AGENT_LLM_MAP["podcast_script_writer"]) + + messages = [ + SystemMessage(content=get_prompt_template("podcast/podcast_script_writer")), + HumanMessage(content=state["input"]), + ] + + try: + # Try structured output with json_mode first + model = base_model.with_structured_output(Script, method="json_mode") + script = model.invoke(messages) + except openai.BadRequestError as e: + # Fall back for models that don't support json_object (e.g., Kimi K2) + if "json_object" in str(e).lower(): + logger.warning( + f"Model doesn't support json_mode, falling back to prompting: {e}" + ) + response = base_model.invoke(messages) + content = response.content if hasattr(response, "content") else str(response) + try: + repaired = repair_json_output(content) + script_dict = json.loads(repaired) + except json.JSONDecodeError as json_err: + logger.error( + "Failed to parse JSON from podcast script writer fallback " + "response: %s; content: %r", + json_err, + content, + ) + raise + script = Script.model_validate(script_dict) + else: + raise + + logger.debug("Generated podcast script: %s", script) return {"script": script, "audio_chunks": []} diff --git a/tests/unit/podcast/__init__.py b/tests/unit/podcast/__init__.py new file mode 100644 index 0000000..58bc29b --- /dev/null +++ b/tests/unit/podcast/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT diff --git a/tests/unit/podcast/test_script_writer_node.py b/tests/unit/podcast/test_script_writer_node.py new file mode 100644 index 0000000..df8eecf --- /dev/null +++ b/tests/unit/podcast/test_script_writer_node.py @@ -0,0 +1,214 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: MIT + +import json +from unittest.mock import MagicMock, patch + +import openai +import pytest + +from src.podcast.graph.script_writer_node import script_writer_node +from src.podcast.types import Script, ScriptLine + + +class TestScriptWriterNode: + """Tests for script_writer_node function.""" + + @pytest.fixture + def sample_state(self): + """Create a sample podcast state.""" + return {"input": "Test content for podcast generation"} + + @pytest.fixture + def sample_script(self): + """Create a sample Script object.""" + return Script( + locale="en", + lines=[ + ScriptLine(speaker="male", paragraph="Hello, welcome to our podcast."), + ScriptLine(speaker="female", paragraph="Today we discuss testing."), + ], + ) + + @pytest.fixture + def sample_script_json(self, sample_script): + """Create JSON representation of sample script.""" + return sample_script.model_dump_json() + + @patch("src.podcast.graph.script_writer_node.get_prompt_template") + @patch("src.podcast.graph.script_writer_node.get_llm_by_type") + def test_script_writer_with_json_mode_success( + self, mock_get_llm, mock_get_template, sample_state, sample_script + ): + """Test successful script generation using json_mode.""" + mock_get_template.return_value = "Generate a podcast script." + + mock_model = MagicMock() + mock_structured_model = MagicMock() + mock_model.with_structured_output.return_value = mock_structured_model + mock_structured_model.invoke.return_value = sample_script + mock_get_llm.return_value = mock_model + + result = script_writer_node(sample_state) + + assert result["script"] == sample_script + assert result["audio_chunks"] == [] + mock_model.with_structured_output.assert_called_once_with( + Script, method="json_mode" + ) + + @patch("src.podcast.graph.script_writer_node.get_prompt_template") + @patch("src.podcast.graph.script_writer_node.get_llm_by_type") + def test_script_writer_fallback_on_json_object_not_supported( + self, mock_get_llm, mock_get_template, sample_state, sample_script_json + ): + """Test fallback to prompting when model doesn't support json_object.""" + mock_get_template.return_value = "Generate a podcast script." + + mock_model = MagicMock() + mock_structured_model = MagicMock() + mock_model.with_structured_output.return_value = mock_structured_model + + # Simulate json_object not supported error + mock_structured_model.invoke.side_effect = openai.BadRequestError( + message="json_object is not supported by this model", + response=MagicMock(status_code=400), + body={"error": {"message": "json_object is not supported"}}, + ) + + # Mock the fallback response + mock_response = MagicMock() + mock_response.content = sample_script_json + mock_model.invoke.return_value = mock_response + + mock_get_llm.return_value = mock_model + + result = script_writer_node(sample_state) + + assert result["script"].locale == "en" + assert len(result["script"].lines) == 2 + assert result["audio_chunks"] == [] + # Verify fallback was used + mock_model.invoke.assert_called_once() + + @patch("src.podcast.graph.script_writer_node.get_prompt_template") + @patch("src.podcast.graph.script_writer_node.get_llm_by_type") + def test_script_writer_reraises_other_bad_request_errors( + self, mock_get_llm, mock_get_template, sample_state + ): + """Test that other BadRequestError types are re-raised.""" + mock_get_template.return_value = "Generate a podcast script." + + mock_model = MagicMock() + mock_structured_model = MagicMock() + mock_model.with_structured_output.return_value = mock_structured_model + + # Simulate a different BadRequestError (not json_object related) + mock_structured_model.invoke.side_effect = openai.BadRequestError( + message="Invalid model parameter", + response=MagicMock(status_code=400), + body={"error": {"message": "Invalid model parameter"}}, + ) + + mock_get_llm.return_value = mock_model + + with pytest.raises(openai.BadRequestError) as exc_info: + script_writer_node(sample_state) + + assert "Invalid model parameter" in str(exc_info.value) + + @patch("src.podcast.graph.script_writer_node.get_prompt_template") + @patch("src.podcast.graph.script_writer_node.get_llm_by_type") + def test_script_writer_fallback_with_markdown_wrapped_json( + self, mock_get_llm, mock_get_template, sample_state + ): + """Test fallback handles JSON wrapped in markdown code blocks.""" + mock_get_template.return_value = "Generate a podcast script." + + mock_model = MagicMock() + mock_structured_model = MagicMock() + mock_model.with_structured_output.return_value = mock_structured_model + + mock_structured_model.invoke.side_effect = openai.BadRequestError( + message="json_object is not supported", + response=MagicMock(status_code=400), + body={}, + ) + + # Mock response with markdown-wrapped JSON (common LLM output) + mock_response = MagicMock() + mock_response.content = """```json +{ + "locale": "zh", + "lines": [ + {"speaker": "male", "paragraph": "欢迎收听播客。"} + ] +} +```""" + mock_model.invoke.return_value = mock_response + + mock_get_llm.return_value = mock_model + + result = script_writer_node(sample_state) + + assert result["script"].locale == "zh" + assert len(result["script"].lines) == 1 + assert result["script"].lines[0].speaker == "male" + + @patch("src.podcast.graph.script_writer_node.get_prompt_template") + @patch("src.podcast.graph.script_writer_node.get_llm_by_type") + def test_script_writer_fallback_raises_on_invalid_json( + self, mock_get_llm, mock_get_template, sample_state + ): + """Test that fallback raises JSONDecodeError when response is not valid JSON.""" + mock_get_template.return_value = "Generate a podcast script." + + mock_model = MagicMock() + mock_structured_model = MagicMock() + mock_model.with_structured_output.return_value = mock_structured_model + + mock_structured_model.invoke.side_effect = openai.BadRequestError( + message="json_object is not supported", + response=MagicMock(status_code=400), + body={}, + ) + + # Mock response with completely invalid JSON + mock_response = MagicMock() + mock_response.content = "This is not JSON at all, just plain text response." + mock_model.invoke.return_value = mock_response + + mock_get_llm.return_value = mock_model + + with pytest.raises(json.JSONDecodeError): + script_writer_node(sample_state) + + @patch("src.podcast.graph.script_writer_node.get_prompt_template") + @patch("src.podcast.graph.script_writer_node.get_llm_by_type") + def test_script_writer_fallback_raises_on_invalid_schema( + self, mock_get_llm, mock_get_template, sample_state + ): + """Test that fallback raises ValidationError when JSON doesn't match Script schema.""" + mock_get_template.return_value = "Generate a podcast script." + + mock_model = MagicMock() + mock_structured_model = MagicMock() + mock_model.with_structured_output.return_value = mock_structured_model + + mock_structured_model.invoke.side_effect = openai.BadRequestError( + message="json_object is not supported", + response=MagicMock(status_code=400), + body={}, + ) + + # Mock response with valid JSON but invalid schema (missing required fields, wrong types) + mock_response = MagicMock() + mock_response.content = '{"locale": "invalid_locale", "lines": "not_a_list"}' + mock_model.invoke.return_value = mock_response + + mock_get_llm.return_value = mock_model + + # Pydantic ValidationError is raised when schema validation fails + from pydantic import ValidationError + with pytest.raises(ValidationError): + script_writer_node(sample_state)