mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-14 02:34:46 +08:00
feat: integrate volcengine tts functionality
This commit is contained in:
@@ -1,19 +1,22 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import List, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import StreamingResponse, Response
|
||||
from langchain_core.messages import AIMessageChunk, ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
from src.graph.builder import build_graph
|
||||
from src.server.chat_request import ChatMessage, ChatRequest
|
||||
from src.server.chat_request import ChatMessage, ChatRequest, TTSRequest
|
||||
from src.tools import VolcengineTTS
|
||||
|
||||
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -137,3 +140,59 @@ def _make_event(event_type: str, data: dict[str, any]):
|
||||
if data.get("content") == "":
|
||||
data.pop("content")
|
||||
return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
@app.post("/api/tts")
async def text_to_speech(request: TTSRequest):
    """Convert text to speech using volcengine TTS API.

    Credentials and voice settings are read from the environment variables
    ``VOLCENGINE_TTS_APPID``, ``VOLCENGINE_TTS_ACCESS_TOKEN``,
    ``VOLCENGINE_TTS_CLUSTER`` (default ``"volcano_tts"``) and
    ``VOLCENGINE_TTS_VOICE_TYPE`` (default ``"BV700_V2_streaming"``).

    Args:
        request: Synthesis parameters. Only the first 1024 characters of
            ``request.text`` are forwarded to the TTS client.

    Returns:
        A ``Response`` whose body is the base64-decoded audio, with media
        type ``audio/<encoding>`` and an attachment ``Content-Disposition``
        header naming the file ``tts_output.<encoding>``.

    Raises:
        HTTPException: 400 when the app id or access token is not set;
            500 when the TTS client reports a failure or any unexpected
            error occurs while handling the request.
    """
    try:
        app_id = os.getenv("VOLCENGINE_TTS_APPID", "")
        if not app_id:
            raise HTTPException(
                status_code=400, detail="VOLCENGINE_TTS_APPID is not set"
            )
        access_token = os.getenv("VOLCENGINE_TTS_ACCESS_TOKEN", "")
        if not access_token:
            raise HTTPException(
                status_code=400, detail="VOLCENGINE_TTS_ACCESS_TOKEN is not set"
            )
        cluster = os.getenv("VOLCENGINE_TTS_CLUSTER", "volcano_tts")
        # NOTE(review): request.voice_type is not consulted here; the voice
        # always comes from the environment. Confirm this is intended.
        voice_type = os.getenv("VOLCENGINE_TTS_VOICE_TYPE", "BV700_V2_streaming")

        tts_client = VolcengineTTS(
            appid=app_id,
            access_token=access_token,
            cluster=cluster,
            voice_type=voice_type,
        )
        # Call the TTS API
        result = tts_client.text_to_speech(
            text=request.text[:1024],
            encoding=request.encoding,
            speed_ratio=request.speed_ratio,
            volume_ratio=request.volume_ratio,
            pitch_ratio=request.pitch_ratio,
            text_type=request.text_type,
            with_frontend=request.with_frontend,
            frontend_type=request.frontend_type,
        )

        if not result["success"]:
            # Logged here because this HTTPException is re-raised as-is below
            # and therefore never reaches the generic logger.exception call.
            logger.error("TTS request failed: %s", result["error"])
            raise HTTPException(status_code=500, detail=str(result["error"]))

        # Decode the base64 audio data
        audio_data = base64.b64decode(result["audio_data"])

        # Return the audio file
        return Response(
            content=audio_data,
            media_type=f"audio/{request.encoding}",
            headers={
                "Content-Disposition": (
                    f"attachment; filename=tts_output.{request.encoding}"
                )
            },
        )
    except HTTPException:
        # Errors raised on purpose above already carry the intended status
        # code and detail. Without this clause the generic handler below
        # would turn the 400 responses into 500s.
        raise
    except Exception as e:
        logger.exception("Error in TTS endpoint: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -44,3 +44,19 @@ class ChatRequest(BaseModel):
|
||||
interrupt_feedback: Optional[str] = Field(
|
||||
None, description="Interrupt feedback from the user on the plan"
|
||||
)
|
||||
|
||||
|
||||
class TTSRequest(BaseModel):
    # Request payload for a text-to-speech call.
    # `text` is the only required field; every other field declares a
    # default and is typed Optional, so an explicit null is also accepted.

    # Input to synthesize (required).
    text: str = Field(..., description="The text to convert to speech")

    # Voice and output container.
    voice_type: Optional[str] = Field(
        default="BV700_V2_streaming",
        description="The voice type to use",
    )
    encoding: Optional[str] = Field(
        default="mp3",
        description="The audio encoding format",
    )

    # Prosody ratios, each defaulting to 1.0.
    speed_ratio: Optional[float] = Field(
        default=1.0,
        description="Speech speed ratio",
    )
    volume_ratio: Optional[float] = Field(
        default=1.0,
        description="Speech volume ratio",
    )
    pitch_ratio: Optional[float] = Field(
        default=1.0,
        description="Speech pitch ratio",
    )

    # Text interpretation and frontend options.
    text_type: Optional[str] = Field(
        default="plain",
        description="Text type (plain or ssml)",
    )
    with_frontend: Optional[int] = Field(
        default=1,
        description="Whether to use frontend processing",
    )
    frontend_type: Optional[str] = Field(
        default="unitTson",
        description="Frontend type",
    )
|
||||
|
||||
Reference in New Issue
Block a user