mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-25 15:04:46 +08:00
feat: support Podcast generation
This commit is contained in:
13
src/podcast/graph/audio_mixer_node.py
Normal file
13
src/podcast/graph/audio_mixer_node.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import logging
|
||||
|
||||
from src.podcast.graph.state import PodcastState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def audio_mixer_node(state: PodcastState):
|
||||
logger.info("Mixing audio chunks for podcast...")
|
||||
audio_chunks = state["audio_chunks"]
|
||||
combined_audio = b"".join(audio_chunks)
|
||||
logger.info("The podcast audio is now ready.")
|
||||
return {"output": combined_audio}
|
||||
35
src/podcast/graph/builder.py
Normal file
35
src/podcast/graph/builder.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
|
||||
from .audio_mixer_node import audio_mixer_node
|
||||
from .script_writer_node import script_writer_node
|
||||
from .state import PodcastState
|
||||
from .tts_node import tts_node
|
||||
|
||||
|
||||
def build_graph():
|
||||
"""Build and return the podcast workflow graph."""
|
||||
# build state graph
|
||||
builder = StateGraph(PodcastState)
|
||||
builder.add_node("script_writer", script_writer_node)
|
||||
builder.add_node("tts", tts_node)
|
||||
builder.add_node("audio_mixer", audio_mixer_node)
|
||||
builder.add_edge(START, "script_writer")
|
||||
builder.add_edge("script_writer", "tts")
|
||||
builder.add_edge("tts", "audio_mixer")
|
||||
builder.add_edge("audio_mixer", END)
|
||||
return builder.compile()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
report_content = open("examples/nanjing_tangbao.md").read()
|
||||
workflow = build_graph()
|
||||
final_state = workflow.invoke({"input": report_content})
|
||||
for line in final_state["script"].lines:
|
||||
print("<M>" if line.speaker == "male" else "<F>", line.text)
|
||||
|
||||
with open("final.mp3", "wb") as f:
|
||||
f.write(final_state["output"])
|
||||
27
src/podcast/graph/script_writer_node.py
Normal file
27
src/podcast/graph/script_writer_node.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import logging
|
||||
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from src.config.agents import AGENT_LLM_MAP
|
||||
from src.llms.llm import get_llm_by_type
|
||||
from src.prompts.template import get_prompt_template
|
||||
|
||||
from ..types import Script
|
||||
from .state import PodcastState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def script_writer_node(state: PodcastState):
|
||||
logger.info("Generating script for podcast...")
|
||||
model = get_llm_by_type(
|
||||
AGENT_LLM_MAP["podcast_script_writer"]
|
||||
).with_structured_output(Script)
|
||||
script = model.invoke(
|
||||
[
|
||||
SystemMessage(content=get_prompt_template("podcast_script_writer")),
|
||||
HumanMessage(content=state["input"]),
|
||||
],
|
||||
)
|
||||
logging.info(script)
|
||||
return {"script": script, "audio_chunks": []}
|
||||
19
src/podcast/graph/state.py
Normal file
19
src/podcast/graph/state.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from typing import Optional
|
||||
|
||||
from langgraph.graph import MessagesState
|
||||
|
||||
from ..types import Script
|
||||
|
||||
|
||||
class PodcastState(MessagesState):
|
||||
"""State for the podcast generation."""
|
||||
|
||||
# Input
|
||||
input: str = ""
|
||||
|
||||
# Output
|
||||
output: Optional[bytes] = None
|
||||
|
||||
# Assets
|
||||
script: Optional[Script] = None
|
||||
audio_chunks: list[bytes] = []
|
||||
44
src/podcast/graph/tts_node.py
Normal file
44
src/podcast/graph/tts_node.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
|
||||
from src.podcast.graph.state import PodcastState
|
||||
from src.tools.tts import VolcengineTTS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def tts_node(state: PodcastState):
|
||||
logger.info("Generating audio chunks for podcast...")
|
||||
tts_client = _create_tts_client()
|
||||
for line in state["script"].lines:
|
||||
tts_client.voice_type = (
|
||||
"BV002_streaming" if line.speaker == "male" else "BV001_streaming"
|
||||
)
|
||||
result = tts_client.text_to_speech(line.text, speed_ratio=1.1)
|
||||
if result["success"]:
|
||||
audio_data = result["audio_data"]
|
||||
audio_chunk = base64.b64decode(audio_data)
|
||||
state["audio_chunks"].append(audio_chunk)
|
||||
else:
|
||||
logger.error(result["error"])
|
||||
return {
|
||||
"audio_chunks": state["audio_chunks"],
|
||||
}
|
||||
|
||||
|
||||
def _create_tts_client():
|
||||
app_id = os.getenv("VOLCENGINE_TTS_APPID", "")
|
||||
if not app_id:
|
||||
raise Exception("VOLCENGINE_TTS_APPID is not set")
|
||||
access_token = os.getenv("VOLCENGINE_TTS_ACCESS_TOKEN", "")
|
||||
if not access_token:
|
||||
raise Exception("VOLCENGINE_TTS_ACCESS_TOKEN is not set")
|
||||
cluster = os.getenv("VOLCENGINE_TTS_CLUSTER", "volcano_tts")
|
||||
voice_type = "BV001_streaming"
|
||||
return VolcengineTTS(
|
||||
appid=app_id,
|
||||
access_token=access_token,
|
||||
cluster=cluster,
|
||||
voice_type=voice_type,
|
||||
)
|
||||
Reference in New Issue
Block a user