deer-flow/src/server/app.py

# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import base64
import json
import logging
import os

from typing import Any, List, cast
from uuid import uuid4

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, Response
from langchain_core.messages import AIMessageChunk, ToolMessage
from langgraph.types import Command

from src.graph.builder import build_graph
from src.server.chat_request import ChatMessage, ChatRequest, TTSRequest
from src.tools import VolcengineTTS

logger = logging.getLogger(__name__)

app = FastAPI(
    title="Deer API",
    description="API for Deer",
    version="0.1.0",
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

graph = build_graph()
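# (Built once at import time; every request reuses this graph and keeps its own
# conversation separate via the thread_id passed in the run config below.)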


@app.post("/api/chat/stream")
async def chat_stream(request: ChatRequest):
    thread_id = request.thread_id
    if thread_id == "__default__":
        thread_id = str(uuid4())
    return StreamingResponse(
        _astream_workflow_generator(
            request.model_dump()["messages"],
            thread_id,
            request.max_plan_iterations,
            request.max_step_num,
            request.auto_accepted_plan,
            request.interrupt_feedback,
        ),
        media_type="text/event-stream",
    )
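
# Example request (a sketch: host and port assume a default uvicorn run, and the
# JSON fields mirror the ChatRequest model consumed above; values are made up):
#
#   curl -N -X POST http://localhost:8000/api/chat/stream \
#     -H "Content-Type: application/json" \
#     -d '{"thread_id": "__default__",
#          "messages": [{"role": "user", "content": "Hi"}],
#          "max_plan_iterations": 1, "max_step_num": 3,
#          "auto_accepted_plan": true, "interrupt_feedback": ""}'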


async def _astream_workflow_generator(
    messages: List[ChatMessage],
    thread_id: str,
    max_plan_iterations: int,
    max_step_num: int,
    auto_accepted_plan: bool,
    interrupt_feedback: str,
):
    input_ = {"messages": messages, "auto_accepted_plan": auto_accepted_plan}
    if not auto_accepted_plan and interrupt_feedback:
        resume_msg = f"[{interrupt_feedback}]"
        # Append the last message to the resume payload
        if messages:
            resume_msg += f" {messages[-1]['content']}"
        input_ = Command(resume=resume_msg)
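        # Illustrative resume value, assuming interrupt_feedback="edit_plan" and
        # a trailing user message: "[edit_plan] Please add a data-cleaning step"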
    async for agent, _, event_data in graph.astream(
        input_,
        config={
            "thread_id": thread_id,
            "max_plan_iterations": max_plan_iterations,
            "max_step_num": max_step_num,
        },
        stream_mode=["messages", "updates"],
        subgraphs=True,
    ):
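        # With subgraphs=True and a list of stream modes, each item is a
        # (namespace, mode, payload) triple; the namespace identifies the
        # (sub)graph node, i.e. the agent, that produced the payload.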
        if isinstance(event_data, dict):
            if "__interrupt__" in event_data:
                yield _make_event(
                    "interrupt",
                    {
                        "thread_id": thread_id,
                        "id": event_data["__interrupt__"][0].ns[0],
                        "role": "assistant",
                        "content": event_data["__interrupt__"][0].value,
                        "finish_reason": "interrupt",
                        "options": [
                            {"text": "Edit plan", "value": "edit_plan"},
                            {"text": "Start research", "value": "accepted"},
                        ],
                    },
                )
            continue
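        # Anything else comes from the "messages" stream mode as a
        # (message_chunk, metadata) tuple.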
        message_chunk, message_metadata = cast(
            tuple[AIMessageChunk, dict[str, Any]], event_data
        )
        event_stream_message: dict[str, Any] = {
            "thread_id": thread_id,
            "agent": agent[0].split(":")[0],
            "id": message_chunk.id,
            "role": "assistant",
            "content": message_chunk.content,
        }
        if message_chunk.response_metadata.get("finish_reason"):
            event_stream_message["finish_reason"] = message_chunk.response_metadata.get(
                "finish_reason"
            )
        if isinstance(message_chunk, ToolMessage):
            # Tool Message - Return the result of the tool call
            event_stream_message["tool_call_id"] = message_chunk.tool_call_id
            yield _make_event("tool_call_result", event_stream_message)
        elif message_chunk.tool_calls:
            # AI Message - Tool Call
            event_stream_message["tool_calls"] = message_chunk.tool_calls
            event_stream_message["tool_call_chunks"] = message_chunk.tool_call_chunks
            yield _make_event("tool_calls", event_stream_message)
        elif message_chunk.tool_call_chunks:
            # AI Message - Tool Call Chunks
            event_stream_message["tool_call_chunks"] = message_chunk.tool_call_chunks
            yield _make_event("tool_call_chunks", event_stream_message)
        else:
            # AI Message - Raw message tokens
            yield _make_event("message_chunk", event_stream_message)


def _make_event(event_type: str, data: dict[str, Any]):
    if data.get("content") == "":
        data.pop("content")
    return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
@app.post("/api/tts")
async def text_to_speech(request: TTSRequest):
"""Convert text to speech using volcengine TTS API."""
try:
app_id = os.getenv("VOLCENGINE_TTS_APPID", "")
if not app_id:
raise HTTPException(
status_code=400, detail="VOLCENGINE_TTS_APPID is not set"
)
access_token = os.getenv("VOLCENGINE_TTS_ACCESS_TOKEN", "")
if not access_token:
raise HTTPException(
status_code=400, detail="VOLCENGINE_TTS_ACCESS_TOKEN is not set"
)
cluster = os.getenv("VOLCENGINE_TTS_CLUSTER", "volcano_tts")
voice_type = os.getenv("VOLCENGINE_TTS_VOICE_TYPE", "BV700_V2_streaming")
tts_client = VolcengineTTS(
appid=app_id,
access_token=access_token,
cluster=cluster,
voice_type=voice_type,
)
# Call the TTS API
result = tts_client.text_to_speech(
text=request.text[:1024],
encoding=request.encoding,
speed_ratio=request.speed_ratio,
volume_ratio=request.volume_ratio,
pitch_ratio=request.pitch_ratio,
text_type=request.text_type,
with_frontend=request.with_frontend,
frontend_type=request.frontend_type,
)
if not result["success"]:
raise HTTPException(status_code=500, detail=str(result["error"]))
# Decode the base64 audio data
audio_data = base64.b64decode(result["audio_data"])
# Return the audio file
return Response(
content=audio_data,
media_type=f"audio/{request.encoding}",
headers={
"Content-Disposition": (
f"attachment; filename=tts_output.{request.encoding}"
)
},
)
except Exception as e:
logger.exception(f"Error in TTS endpoint: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
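

# Illustrative call (assumes a default uvicorn run on port 8000 and that the
# remaining TTSRequest fields have defaults; "mp3" stands in for request.encoding):
#
#   curl -X POST http://localhost:8000/api/tts \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Hello from DeerFlow"}' \
#     --output tts_output.mp3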