# -- src/config/agents.py --------------------------------------------------
# Map each agent role to the LLM tier it should use. "basic" is the only
# tier wired up today; LLMType also allows "reasoning" and "vision".
AGENT_LLM_MAP: dict[str, LLMType] = {
    "coordinator": "basic",
    "planner": "basic",
    "researcher": "basic",
    "coder": "basic",
    "reporter": "basic",
    "podcast_script_writer": "basic",
}


# -- src/podcast/graph/audio_mixer_node.py ---------------------------------
import logging

from src.podcast.graph.state import PodcastState

logger = logging.getLogger(__name__)


def audio_mixer_node(state: PodcastState):
    """Concatenate the per-line audio chunks into the final podcast audio.

    Reads ``state["audio_chunks"]`` (list of MP3 byte strings produced by
    the TTS node) and returns the combined bytes under the ``"output"`` key.
    An empty chunk list yields ``b""``.
    """
    logger.info("Mixing audio chunks for podcast...")
    # MP3 frames are self-delimiting, so plain byte concatenation produces a
    # playable stream — no re-encoding needed.
    combined_audio = b"".join(state["audio_chunks"])
    logger.info("The podcast audio is now ready.")
    return {"output": combined_audio}


# -- src/podcast/graph/builder.py ------------------------------------------
from langgraph.graph import END, START, StateGraph

from .audio_mixer_node import audio_mixer_node
from .script_writer_node import script_writer_node
from .state import PodcastState
from .tts_node import tts_node


def build_graph():
    """Build and compile the podcast workflow graph.

    Linear pipeline: script_writer -> tts -> audio_mixer.
    """
    builder = StateGraph(PodcastState)
    builder.add_node("script_writer", script_writer_node)
    builder.add_node("tts", tts_node)
    builder.add_node("audio_mixer", audio_mixer_node)
    builder.add_edge(START, "script_writer")
    builder.add_edge("script_writer", "tts")
    builder.add_edge("tts", "audio_mixer")
    builder.add_edge("audio_mixer", END)
    return builder.compile()


if __name__ == "__main__":
    from dotenv import load_dotenv

    load_dotenv()

    # Fix: close the file handle deterministically and read as UTF-8 instead
    # of leaking an open() handle with the platform-default encoding.
    with open("examples/nanjing_tangbao.md", encoding="utf-8") as f:
        report_content = f.read()
    workflow = build_graph()
    final_state = workflow.invoke({"input": report_content})
    for line in final_state["script"].lines:
        # Fix: the original printed an empty string for BOTH speakers (the
        # speaker markers were lost in an encoding mangle); tag each line
        # with its speaker explicitly so the demo output is readable.
        print("[male]" if line.speaker == "male" else "[female]", line.text)

    with open("final.mp3", "wb") as f:
        f.write(final_state["output"])
# -- src/podcast/graph/script_writer_node.py -------------------------------
import logging

from langchain.schema import HumanMessage, SystemMessage

from src.config.agents import AGENT_LLM_MAP
from src.llms.llm import get_llm_by_type
from src.prompts.template import get_prompt_template

from ..types import Script
from .state import PodcastState

logger = logging.getLogger(__name__)


def script_writer_node(state: PodcastState):
    """Turn the raw report in ``state["input"]`` into a two-host Script.

    Uses the LLM configured for ``podcast_script_writer`` with structured
    output so the response parses directly into :class:`Script`.
    """
    logger.info("Generating script for podcast...")
    model = get_llm_by_type(
        AGENT_LLM_MAP["podcast_script_writer"]
    ).with_structured_output(Script)
    script = model.invoke(
        [
            SystemMessage(content=get_prompt_template("podcast_script_writer")),
            HumanMessage(content=state["input"]),
        ],
    )
    # Fix: use the module logger, not the root logger via ``logging.info``,
    # so the record carries this module's name and honors its configuration.
    logger.info(script)
    # Seed audio_chunks so the downstream TTS node appends to a fresh list.
    return {"script": script, "audio_chunks": []}


# -- src/podcast/graph/state.py --------------------------------------------
from typing import Optional

from langgraph.graph import MessagesState

from ..types import Script


class PodcastState(MessagesState):
    """State for the podcast generation workflow."""

    # NOTE(review): MessagesState is TypedDict-based, so these class-level
    # defaults document intent but are not applied at runtime; nodes must
    # populate the keys themselves (script_writer_node seeds audio_chunks).

    # Input: the raw report/markdown to turn into a podcast.
    input: str = ""

    # Output: the final mixed MP3 bytes.
    output: Optional[bytes] = None

    # Assets produced along the way.
    script: Optional[Script] = None
    audio_chunks: list[bytes] = []


# -- src/podcast/graph/tts_node.py -----------------------------------------
import base64
import os

from src.tools.tts import VolcengineTTS


def tts_node(state: PodcastState):
    """Synthesize one audio chunk per script line via Volcengine TTS.

    Male lines use voice ``BV002_streaming``, female lines
    ``BV001_streaming``. A failed line is logged and skipped so one bad line
    does not abort the whole podcast.
    """
    logger.info("Generating audio chunks for podcast...")
    tts_client = _create_tts_client()
    audio_chunks = state["audio_chunks"]
    for line in state["script"].lines:
        tts_client.voice_type = (
            "BV002_streaming" if line.speaker == "male" else "BV001_streaming"
        )
        result = tts_client.text_to_speech(line.text, speed_ratio=1.1)
        if result["success"]:
            # The TTS API returns base64-encoded audio; store decoded bytes.
            audio_chunks.append(base64.b64decode(result["audio_data"]))
        else:
            logger.error(result["error"])
    return {
        "audio_chunks": audio_chunks,
    }


def _create_tts_client():
    """Build a VolcengineTTS client from environment configuration.

    Raises when VOLCENGINE_TTS_APPID or VOLCENGINE_TTS_ACCESS_TOKEN is
    missing; the cluster falls back to "volcano_tts".
    """
    app_id = os.getenv("VOLCENGINE_TTS_APPID", "")
    if not app_id:
        raise Exception("VOLCENGINE_TTS_APPID is not set")
    access_token = os.getenv("VOLCENGINE_TTS_ACCESS_TOKEN", "")
    if not access_token:
        raise Exception("VOLCENGINE_TTS_ACCESS_TOKEN is not set")
    cluster = os.getenv("VOLCENGINE_TTS_CLUSTER", "volcano_tts")
    return VolcengineTTS(
        appid=app_id,
        access_token=access_token,
        cluster=cluster,
        voice_type="BV001_streaming",  # default; overridden per line in tts_node
    )


# -- src/podcast/types.py ---------------------------------------------------
from typing import Literal

from pydantic import BaseModel, Field


class ScriptLine(BaseModel):
    """One spoken line of the podcast, attributed to a host."""

    speaker: Literal["male", "female"] = Field(default="male")
    text: str = Field(default="")


class Script(BaseModel):
    """A full podcast script: a locale plus alternating host lines."""

    locale: Literal["en", "zh"] = Field(default="en")
    # Fix: default_factory gives each instance its own list instead of a
    # single shared mutable default.
    lines: list[ScriptLine] = Field(default_factory=list)
+ +# Output Format + +The output should be formatted as a JSON object of `Script`: + +```ts +interface ScriptLine { + speaker: 'male' | 'female'; + text: string; +} + +interface Script { + locale: "en" | "zh"; + lines: ScriptLine[]; +} +``` + +# Settings + +locale_of_script: zh + +# Examples + + +{ + "locale": "en", + "lines": [ + { + "speaker": "male", + "text": "Hey everyone, welcome to the podcast Hello Deer!" + }, + { + "speaker": "female", + "text": "Hi there! Today, we’re diving into something super interesting." + }, + { + "speaker": "male", + "text": "Yeah, we’re talking about [topic]. You know, I’ve been thinking about this a lot lately." + }, + { + "speaker": "female", + "text": "Oh, me too! It’s such a fascinating subject. So, let’s start with [specific detail or question]." + }, + { + "speaker": "male", + "text": "Sure! Did you know that [fact or insight]? It’s kind of mind-blowing, right?" + }, + { + "speaker": "female", + "text": "Totally! And it makes me wonder, what about [related question or thought]?" + }, + { + "speaker": "male", + "text": "Great point! Actually, [additional detail or answer]." + }, + { + "speaker": "female", + "text": "Wow, that’s so cool. I didn’t know that! Okay, so what about [next topic or transition]?" + }, + ... + ] +} + + +> Real examples should be **MUCH MUCH LONGER** and more detailed, with placeholders replaced by actual content. +> You should adjust your language according to the `Settings` section. + +# Notes + +- It should always start with "Hello Deer" podcast greetings and followed by topic introduction. +- Ensure the dialogue flows naturally and feels engaging for listeners. +- Alternate between the male and female hosts frequently to maintain interaction. +- Avoid overly formal language; keep it casual and conversational. +- Generate content with the locale mentioned in the `Settings` section. 
@app.get("/api/podcast/generate")
async def generate_podcast():
    """Generate a podcast MP3 from the bundled example report.

    Runs the podcast workflow (script writing -> TTS -> mixing) over the
    hard-coded example document and returns the mixed audio bytes.

    NOTE(review): the source document is fixed and the route is a GET with
    no parameters — a production endpoint should accept the content (likely
    as a POST body). Kept as-is to preserve the current API surface.
    """
    try:
        # Fix: close the file handle deterministically and read as UTF-8
        # instead of leaking an open() handle.
        with open("examples/nanjing_tangbao.md", encoding="utf-8") as f:
            report_content = f.read()
        workflow = build_podcast_graph()
        final_state = workflow.invoke({"input": report_content})
        audio_bytes = final_state["output"]
        # Fix: "audio/mpeg" is the registered MIME type for MP3 streams;
        # "audio/mp3" is nonstandard.
        return Response(content=audio_bytes, media_type="audio/mpeg")
    except Exception as e:
        logger.exception(f"Error occurred during podcast generation: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))