feat: support Podcast generation

This commit is contained in:
Li Xin
2025-04-19 17:37:40 +08:00
parent 6556056df2
commit be5f823395
9 changed files with 255 additions and 6 deletions

View File

@@ -8,9 +8,10 @@ LLMType = Literal["basic", "reasoning", "vision"]
# Define agent-LLM mapping # Define agent-LLM mapping
AGENT_LLM_MAP: dict[str, LLMType] = { AGENT_LLM_MAP: dict[str, LLMType] = {
"coordinator": "basic", # 协调默认使用basic llm "coordinator": "basic",
"planner": "basic", # 计划默认使用basic llm "planner": "basic",
"researcher": "basic", # 简单搜索任务使用basic llm "researcher": "basic",
"coder": "basic", # 编程任务使用basic llm "coder": "basic",
"reporter": "basic", # 报告使用basic llm "reporter": "basic",
"podcast_script_writer": "basic",
} }

View File

@@ -0,0 +1,13 @@
import logging
from src.podcast.graph.state import PodcastState
logger = logging.getLogger(__name__)


def audio_mixer_node(state: PodcastState):
    """Concatenate all generated audio chunks into the final podcast audio.

    Reads ``state["audio_chunks"]`` (a list of raw audio byte strings, in
    script order) and returns the combined bytes under the ``output`` key.
    """
    logger.info("Mixing audio chunks for podcast...")
    mixed = bytearray()
    for chunk in state["audio_chunks"]:
        mixed.extend(chunk)
    logger.info("The podcast audio is now ready.")
    return {"output": bytes(mixed)}

View File

@@ -0,0 +1,35 @@
from langgraph.graph import END, START, StateGraph
from .audio_mixer_node import audio_mixer_node
from .script_writer_node import script_writer_node
from .state import PodcastState
from .tts_node import tts_node
def build_graph():
    """Assemble and compile the podcast workflow graph.

    The pipeline is linear: script_writer -> tts -> audio_mixer.
    """
    builder = StateGraph(PodcastState)
    # Register the three processing stages.
    for node_name, node_fn in (
        ("script_writer", script_writer_node),
        ("tts", tts_node),
        ("audio_mixer", audio_mixer_node),
    ):
        builder.add_node(node_name, node_fn)
    # Wire the stages into a straight line from START to END.
    pipeline = [START, "script_writer", "tts", "audio_mixer", END]
    for src, dst in zip(pipeline, pipeline[1:]):
        builder.add_edge(src, dst)
    return builder.compile()
if __name__ == "__main__":
    from dotenv import load_dotenv

    # Load LLM / VOLCENGINE_* credentials from .env before invoking the graph.
    load_dotenv()

    # Fix: read the report via a context manager so the file handle is
    # closed deterministically instead of leaking until GC.
    with open("examples/nanjing_tangbao.md") as f:
        report_content = f.read()

    workflow = build_graph()
    final_state = workflow.invoke({"input": report_content})

    # Print the generated dialogue, tagging each line with its speaker.
    for line in final_state["script"].lines:
        print("<M>" if line.speaker == "male" else "<F>", line.text)

    # Persist the mixed audio produced by the audio_mixer node.
    with open("final.mp3", "wb") as f:
        f.write(final_state["output"])

View File

@@ -0,0 +1,27 @@
import logging
from langchain.schema import HumanMessage, SystemMessage
from src.config.agents import AGENT_LLM_MAP
from src.llms.llm import get_llm_by_type
from src.prompts.template import get_prompt_template
from ..types import Script
from .state import PodcastState
logger = logging.getLogger(__name__)


def script_writer_node(state: PodcastState):
    """Generate a two-host podcast script from the report in ``state["input"]``.

    Uses the LLM configured for ``podcast_script_writer`` with structured
    output so the response parses directly into a ``Script`` model.
    Returns the script plus a fresh, empty ``audio_chunks`` list for the
    downstream TTS stage to fill.
    """
    logger.info("Generating script for podcast...")
    model = get_llm_by_type(
        AGENT_LLM_MAP["podcast_script_writer"]
    ).with_structured_output(Script)
    script = model.invoke(
        [
            SystemMessage(content=get_prompt_template("podcast_script_writer")),
            HumanMessage(content=state["input"]),
        ],
    )
    # Fix: log via the module logger rather than the root logger
    # (`logging.info`), so output honors this module's logging configuration
    # and matches every other node in this package.
    logger.info(script)
    return {"script": script, "audio_chunks": []}

View File

@@ -0,0 +1,19 @@
from typing import Optional
from langgraph.graph import MessagesState
from ..types import Script
class PodcastState(MessagesState):
    """State carried through the podcast generation graph.

    NOTE(review): `MessagesState` in langgraph is TypedDict-based, so the
    assigned values below likely document the expected shape rather than
    act as enforced runtime defaults — confirm against the langgraph docs.
    """
    # Input: raw report/markdown content to be turned into a podcast
    input: str = ""
    # Output: final combined audio bytes produced by audio_mixer_node
    output: Optional[bytes] = None
    # Assets: structured script produced by script_writer_node
    script: Optional[Script] = None
    # Assets: per-line audio chunks appended by tts_node, concatenated later
    audio_chunks: list[bytes] = []

View File

@@ -0,0 +1,44 @@
import base64
import logging
import os
from src.podcast.graph.state import PodcastState
from src.tools.tts import VolcengineTTS
logger = logging.getLogger(__name__)


def tts_node(state: PodcastState):
    """Synthesize one audio chunk per script line via Volcengine TTS.

    Selects a male or female voice per line, decodes the base64 audio
    payload returned by the service, and appends it to the shared
    ``state["audio_chunks"]`` list. Failed lines are logged and skipped.
    """
    logger.info("Generating audio chunks for podcast...")
    tts_client = _create_tts_client()
    chunks = state["audio_chunks"]
    for line in state["script"].lines:
        # Switch the client voice to match the current speaker.
        if line.speaker == "male":
            tts_client.voice_type = "BV002_streaming"
        else:
            tts_client.voice_type = "BV001_streaming"
        result = tts_client.text_to_speech(line.text, speed_ratio=1.1)
        if not result["success"]:
            logger.error(result["error"])
            continue
        chunks.append(base64.b64decode(result["audio_data"]))
    return {
        "audio_chunks": chunks,
    }
def _create_tts_client():
app_id = os.getenv("VOLCENGINE_TTS_APPID", "")
if not app_id:
raise Exception("VOLCENGINE_TTS_APPID is not set")
access_token = os.getenv("VOLCENGINE_TTS_ACCESS_TOKEN", "")
if not access_token:
raise Exception("VOLCENGINE_TTS_ACCESS_TOKEN is not set")
cluster = os.getenv("VOLCENGINE_TTS_CLUSTER", "volcano_tts")
voice_type = "BV001_streaming"
return VolcengineTTS(
appid=app_id,
access_token=access_token,
cluster=cluster,
voice_type=voice_type,
)

13
src/podcast/types.py Normal file
View File

@@ -0,0 +1,13 @@
from typing import Literal
from pydantic import BaseModel, Field
class ScriptLine(BaseModel):
    """A single line of podcast dialogue attributed to one host."""

    speaker: Literal["male", "female"] = Field(default="male")
    text: str = Field(default="")


class Script(BaseModel):
    """A full podcast script: locale plus the ordered dialogue lines."""

    locale: Literal["en", "zh"] = Field(default="en")
    # Fix: use default_factory so each model instance gets its own list
    # (pydantic best practice for mutable defaults).
    lines: list[ScriptLine] = Field(default_factory=list)

View File

@@ -0,0 +1,83 @@
You are a professional podcast editor for a show called "Hello Deer." Transform raw content into a conversational podcast script suitable for two hosts to read aloud.
# Guidelines
- **Tone**: The script should sound natural and conversational, like two people chatting. Include casual expressions, filler words, and interactive dialogue, but avoid regional dialects like "啥."
- **Hosts**: There are only two hosts, one male and one female. Ensure the dialogue alternates between them frequently, with no other characters or voices included.
- **Length**: Keep the script concise, aiming for a runtime of 10 minutes.
- **Structure**: Start with the male host speaking first. Avoid overly long sentences and ensure the hosts interact often.
- **Output**: Provide only the hosts' dialogue. Do not include introductions, dates, or any other meta information.
# Output Format
The output should be formatted as a JSON object of `Script`:
```ts
interface ScriptLine {
speaker: 'male' | 'female';
text: string;
}
interface Script {
locale: "en" | "zh";
lines: ScriptLine[];
}
```
# Settings
locale_of_script: zh
# Examples
<example>
{
"locale": "en",
"lines": [
{
"speaker": "male",
"text": "Hey everyone, welcome to the podcast Hello Deer!"
},
{
"speaker": "female",
"text": "Hi there! Today, were diving into something super interesting."
},
{
"speaker": "male",
"text": "Yeah, were talking about [topic]. You know, Ive been thinking about this a lot lately."
},
{
"speaker": "female",
"text": "Oh, me too! Its such a fascinating subject. So, lets start with [specific detail or question]."
},
{
"speaker": "male",
"text": "Sure! Did you know that [fact or insight]? Its kind of mind-blowing, right?"
},
{
"speaker": "female",
"text": "Totally! And it makes me wonder, what about [related question or thought]?"
},
{
"speaker": "male",
"text": "Great point! Actually, [additional detail or answer]."
},
{
"speaker": "female",
"text": "Wow, thats so cool. I didnt know that! Okay, so what about [next topic or transition]?"
},
...
]
}
</example>
> Real examples should be **MUCH MUCH LONGER** and more detailed, with placeholders replaced by actual content.
> You should adjust your language according to the `Settings` section.
# Notes
- It should always start with "Hello Deer" podcast greetings, followed by a topic introduction.
- Ensure the dialogue flows naturally and feels engaging for listeners.
- Alternate between the male and female hosts frequently to maintain interaction.
- Avoid overly formal language; keep it casual and conversational.
- Generate content with the locale mentioned in the `Settings` section.

View File

@@ -10,11 +10,12 @@ from uuid import uuid4
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, Response from fastapi.responses import Response, StreamingResponse
from langchain_core.messages import AIMessageChunk, ToolMessage from langchain_core.messages import AIMessageChunk, ToolMessage
from langgraph.types import Command from langgraph.types import Command
from src.graph.builder import build_graph from src.graph.builder import build_graph
from src.podcast.graph.builder import build_graph as build_podcast_graph
from src.server.chat_request import ChatMessage, ChatRequest, TTSRequest from src.server.chat_request import ChatMessage, ChatRequest, TTSRequest
from src.tools import VolcengineTTS from src.tools import VolcengineTTS
@@ -196,3 +197,16 @@ async def text_to_speech(request: TTSRequest):
except Exception as e: except Exception as e:
logger.exception(f"Error in TTS endpoint: {str(e)}") logger.exception(f"Error in TTS endpoint: {str(e)}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/podcast/generate")
async def generate_podcast():
    """Generate a podcast MP3 from the bundled example report.

    NOTE(review): the source report path is hard-coded; presumably a real
    request parameter should supply the content — confirm intent. The
    workflow invocation is synchronous and will block the event loop.
    """
    try:
        # Fix: read the report via a context manager so the file handle is
        # closed promptly instead of leaking until GC.
        with open("examples/nanjing_tangbao.md") as f:
            report_content = f.read()
        workflow = build_podcast_graph()
        final_state = workflow.invoke({"input": report_content})
        audio_bytes = final_state["output"]
        return Response(content=audio_bytes, media_type="audio/mp3")
    except Exception as e:
        logger.exception(f"Error occurred during podcast generation: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))