feat: support Podcast generation

This commit is contained in:
Li Xin
2025-04-19 17:37:40 +08:00
parent 6556056df2
commit be5f823395
9 changed files with 255 additions and 6 deletions

View File

@@ -0,0 +1,13 @@
import logging
from src.podcast.graph.state import PodcastState
logger = logging.getLogger(__name__)
def audio_mixer_node(state: PodcastState):
    """Concatenate all generated audio chunks into the final podcast audio.

    Reads the per-line byte chunks accumulated under ``audio_chunks`` and
    returns the joined bytes as the workflow ``output``.
    """
    logger.info("Mixing audio chunks for podcast...")
    chunks = state["audio_chunks"]
    # Single C-level pass over every chunk; avoids quadratic bytes +=.
    podcast_audio = b"".join(chunks)
    logger.info("The podcast audio is now ready.")
    return {"output": podcast_audio}

View File

@@ -0,0 +1,35 @@
from langgraph.graph import END, START, StateGraph
from .audio_mixer_node import audio_mixer_node
from .script_writer_node import script_writer_node
from .state import PodcastState
from .tts_node import tts_node
def build_graph():
    """Build and return the compiled podcast workflow graph.

    The pipeline is a linear chain:
    START -> script_writer -> tts -> audio_mixer -> END.
    """
    graph = StateGraph(PodcastState)

    # Register the three pipeline stages.
    stages = (
        ("script_writer", script_writer_node),
        ("tts", tts_node),
        ("audio_mixer", audio_mixer_node),
    )
    for name, node in stages:
        graph.add_node(name, node)

    # Wire the stages into a strictly sequential chain.
    graph.add_edge(START, "script_writer")
    graph.add_edge("script_writer", "tts")
    graph.add_edge("tts", "audio_mixer")
    graph.add_edge("audio_mixer", END)
    return graph.compile()
if __name__ == "__main__":
    from dotenv import load_dotenv

    # Load API credentials (LLM / TTS) from the local .env file.
    load_dotenv()

    # Read the source report; use a context manager so the handle is
    # closed deterministically, and pin the encoding instead of relying
    # on the locale default.
    with open("examples/nanjing_tangbao.md", encoding="utf-8") as f:
        report_content = f.read()

    workflow = build_graph()
    final_state = workflow.invoke({"input": report_content})

    # Print the transcript with speaker tags, then save the mixed audio.
    for line in final_state["script"].lines:
        print("<M>" if line.speaker == "male" else "<F>", line.text)
    with open("final.mp3", "wb") as f:
        f.write(final_state["output"])

View File

@@ -0,0 +1,27 @@
import logging
from langchain.schema import HumanMessage, SystemMessage
from src.config.agents import AGENT_LLM_MAP
from src.llms.llm import get_llm_by_type
from src.prompts.template import get_prompt_template
from ..types import Script
from .state import PodcastState
logger = logging.getLogger(__name__)
def script_writer_node(state: PodcastState):
    """Generate the podcast script from the input report text.

    Sends the raw report (``state["input"]``) to the LLM configured for
    ``podcast_script_writer`` and parses the response into a ``Script``
    via structured output.

    Returns:
        Partial state update containing the generated ``script`` and an
        empty ``audio_chunks`` list for the TTS stage to fill.
    """
    logger.info("Generating script for podcast...")
    model = get_llm_by_type(
        AGENT_LLM_MAP["podcast_script_writer"]
    ).with_structured_output(Script)
    script = model.invoke(
        [
            SystemMessage(content=get_prompt_template("podcast_script_writer")),
            HumanMessage(content=state["input"]),
        ],
    )
    # Use the module logger (not the root logger via logging.info) so the
    # record carries this module's name and honors its configuration.
    logger.info(script)
    return {"script": script, "audio_chunks": []}

View File

@@ -0,0 +1,19 @@
from typing import Optional
from langgraph.graph import MessagesState
from ..types import Script
class PodcastState(MessagesState):
    """State for the podcast generation.

    NOTE(review): MessagesState is TypedDict-based in langgraph, so the
    ``= ...`` values below are presumably typing metadata rather than
    runtime defaults — confirm callers always seed ``input``.
    """

    # Input: raw report text the script is generated from.
    input: str = ""
    # Output: final mixed podcast audio bytes, set by audio_mixer_node.
    output: Optional[bytes] = None
    # Assets: the generated script (script_writer_node) and the per-line
    # decoded audio chunks appended by tts_node.
    script: Optional[Script] = None
    audio_chunks: list[bytes] = []

View File

@@ -0,0 +1,44 @@
import base64
import logging
import os
from src.podcast.graph.state import PodcastState
from src.tools.tts import VolcengineTTS
logger = logging.getLogger(__name__)
def tts_node(state: PodcastState):
    """Synthesize one audio chunk per script line via Volcengine TTS.

    Selects the voice per line from ``line.speaker`` ("male" vs other),
    requests speech at a 1.1x speed ratio, and base64-decodes the audio
    payload returned by the API. Failed lines are logged and skipped
    (best effort) so one bad line does not abort the whole podcast.

    Returns:
        Partial state update with the accumulated ``audio_chunks``.
    """
    logger.info("Generating audio chunks for podcast...")
    tts_client = _create_tts_client()
    # Accumulate into a copy instead of mutating the input state in place;
    # LangGraph applies the returned dict as the state update.
    audio_chunks = list(state["audio_chunks"])
    for line in state["script"].lines:
        tts_client.voice_type = (
            "BV002_streaming" if line.speaker == "male" else "BV001_streaming"
        )
        result = tts_client.text_to_speech(line.text, speed_ratio=1.1)
        if result["success"]:
            # The API returns the audio payload base64-encoded.
            audio_chunks.append(base64.b64decode(result["audio_data"]))
        else:
            logger.error(result["error"])
    return {
        "audio_chunks": audio_chunks,
    }
def _create_tts_client():
    """Build a VolcengineTTS client from environment configuration.

    Required env vars: ``VOLCENGINE_TTS_APPID`` and
    ``VOLCENGINE_TTS_ACCESS_TOKEN``. Optional: ``VOLCENGINE_TTS_CLUSTER``
    (defaults to ``"volcano_tts"``).

    Raises:
        ValueError: if a required credential is missing or empty.
    """
    app_id = os.getenv("VOLCENGINE_TTS_APPID", "")
    if not app_id:
        # ValueError (not bare Exception) so callers can catch precisely;
        # still compatible with anyone catching Exception.
        raise ValueError("VOLCENGINE_TTS_APPID is not set")
    access_token = os.getenv("VOLCENGINE_TTS_ACCESS_TOKEN", "")
    if not access_token:
        raise ValueError("VOLCENGINE_TTS_ACCESS_TOKEN is not set")
    cluster = os.getenv("VOLCENGINE_TTS_CLUSTER", "volcano_tts")
    # Default voice only; tts_node overrides voice_type per script line.
    return VolcengineTTS(
        appid=app_id,
        access_token=access_token,
        cluster=cluster,
        voice_type="BV001_streaming",
    )