# deer-flow/src/server/app.py
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import base64
import json
import logging
import os
from typing import Annotated, Any, List, cast
from uuid import uuid4

from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, StreamingResponse
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage
from langgraph.types import Command

from src.config.report_style import ReportStyle
from src.config.tools import SELECTED_RAG_PROVIDER
from src.graph.builder import build_graph_with_memory
from src.podcast.graph.builder import build_graph as build_podcast_graph
from src.ppt.graph.builder import build_graph as build_ppt_graph
from src.prompt_enhancer.graph.builder import build_graph as build_prompt_enhancer_graph
from src.prose.graph.builder import build_graph as build_prose_graph
from src.rag.builder import build_retriever
from src.rag.retriever import Resource
from src.server.chat_request import (
    ChatMessage,
    ChatRequest,
    EnhancePromptRequest,
    GeneratePodcastRequest,
    GeneratePPTRequest,
    GenerateProseRequest,
    TTSRequest,
)
from src.server.mcp_request import MCPServerMetadataRequest, MCPServerMetadataResponse
from src.server.mcp_utils import load_mcp_tools
from src.server.rag_request import (
    RAGConfigResponse,
    RAGResourceRequest,
    RAGResourcesResponse,
)
from src.tools import VolcengineTTS
2025-04-13 21:14:31 +08:00
# Module-level logger for this server module.
logger = logging.getLogger(__name__)

# Generic detail string returned on unexpected 500 errors so internal
# details are never leaked to clients.
INTERNAL_SERVER_ERROR_DETAIL = "Internal Server Error"

app = FastAPI(
    title="DeerFlow API",
    description="API for Deer",
    version="0.1.0",
)

# Add CORS middleware
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is a
# very permissive CORS configuration — consider restricting origins in
# production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# Single compiled LangGraph instance with a memory checkpointer, shared by
# all /api/chat/stream requests (thread state is keyed by thread_id).
graph = build_graph_with_memory()
2025-04-13 21:14:31 +08:00
@app.post("/api/chat/stream")
async def chat_stream(request: ChatRequest):
thread_id = request.thread_id
if thread_id == "__default__":
thread_id = str(uuid4())
return StreamingResponse(
_astream_workflow_generator(
request.model_dump()["messages"],
thread_id,
request.resources,
2025-04-13 21:14:31 +08:00
request.max_plan_iterations,
request.max_step_num,
request.max_search_results,
2025-04-14 18:01:50 +08:00
request.auto_accepted_plan,
2025-04-15 16:36:02 +08:00
request.interrupt_feedback,
2025-04-23 16:00:01 +08:00
request.mcp_settings,
request.enable_background_investigation,
request.report_style,
2025-04-13 21:14:31 +08:00
),
media_type="text/event-stream",
)
async def _astream_workflow_generator(
    messages: List[ChatMessage],
    thread_id: str,
    resources: List[Resource],
    max_plan_iterations: int,
    max_step_num: int,
    max_search_results: int,
    auto_accepted_plan: bool,
    interrupt_feedback: str,
    mcp_settings: dict,
    enable_background_investigation: bool,
    report_style: ReportStyle,
):
    """Run the chat graph and yield SSE-formatted event strings.

    Emits "interrupt", "tool_call_result", "tool_calls", "tool_call_chunks"
    and "message_chunk" events, each serialized by ``_make_event``.
    """
    input_ = {
        "messages": messages,
        "plan_iterations": 0,
        "final_report": "",
        "current_plan": None,
        "observations": [],
        "auto_accepted_plan": auto_accepted_plan,
        "enable_background_investigation": enable_background_investigation,
    }
    # When the user is answering a previous interrupt (plan review), resume
    # the checkpointed graph run with the feedback instead of starting anew.
    if not auto_accepted_plan and interrupt_feedback:
        resume_msg = f"[{interrupt_feedback}]"
        # add the last message to the resume message
        if messages:
            resume_msg += f" {messages[-1]['content']}"
        input_ = Command(resume=resume_msg)
    async for agent, _, event_data in graph.astream(
        input_,
        config={
            "thread_id": thread_id,
            "resources": resources,
            "max_plan_iterations": max_plan_iterations,
            "max_step_num": max_step_num,
            "max_search_results": max_search_results,
            "mcp_settings": mcp_settings,
            "report_style": report_style.value,
        },
        stream_mode=["messages", "updates"],
        subgraphs=True,
    ):
        if isinstance(event_data, dict):
            # "updates" stream mode yields dict payloads; only interrupts are
            # forwarded to the client, all other updates are skipped.
            if "__interrupt__" in event_data:
                yield _make_event(
                    "interrupt",
                    {
                        "thread_id": thread_id,
                        "id": event_data["__interrupt__"][0].ns[0],
                        "role": "assistant",
                        "content": event_data["__interrupt__"][0].value,
                        "finish_reason": "interrupt",
                        "options": [
                            {"text": "Edit plan", "value": "edit_plan"},
                            {"text": "Start research", "value": "accepted"},
                        ],
                    },
                )
            continue
        # "messages" stream mode yields (message_chunk, metadata) tuples.
        # Fix: annotate with typing.Any, not the builtin function `any`.
        message_chunk, _message_metadata = cast(
            tuple[BaseMessage, dict[str, Any]], event_data
        )
        event_stream_message: dict[str, Any] = {
            "thread_id": thread_id,
            # `agent` is the subgraph namespace tuple, e.g. ("researcher:<id>",);
            # keep only the agent name before the colon.
            "agent": agent[0].split(":")[0],
            "id": message_chunk.id,
            "role": "assistant",
            "content": message_chunk.content,
        }
        if message_chunk.response_metadata.get("finish_reason"):
            event_stream_message["finish_reason"] = message_chunk.response_metadata.get(
                "finish_reason"
            )
        if isinstance(message_chunk, ToolMessage):
            # Tool Message - Return the result of the tool call
            event_stream_message["tool_call_id"] = message_chunk.tool_call_id
            yield _make_event("tool_call_result", event_stream_message)
        elif isinstance(message_chunk, AIMessageChunk):
            if message_chunk.tool_calls:
                # AI Message - Tool Call
                event_stream_message["tool_calls"] = message_chunk.tool_calls
                event_stream_message["tool_call_chunks"] = (
                    message_chunk.tool_call_chunks
                )
                yield _make_event("tool_calls", event_stream_message)
            elif message_chunk.tool_call_chunks:
                # AI Message - Tool Call Chunks
                event_stream_message["tool_call_chunks"] = (
                    message_chunk.tool_call_chunks
                )
                yield _make_event("tool_call_chunks", event_stream_message)
            else:
                # AI Message - Raw message tokens
                yield _make_event("message_chunk", event_stream_message)
def _make_event(event_type: str, data: dict[str, any]):
if data.get("content") == "":
data.pop("content")
return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
@app.post("/api/tts")
async def text_to_speech(request: TTSRequest):
"""Convert text to speech using volcengine TTS API."""
try:
app_id = os.getenv("VOLCENGINE_TTS_APPID", "")
if not app_id:
raise HTTPException(
status_code=400, detail="VOLCENGINE_TTS_APPID is not set"
)
access_token = os.getenv("VOLCENGINE_TTS_ACCESS_TOKEN", "")
if not access_token:
raise HTTPException(
status_code=400, detail="VOLCENGINE_TTS_ACCESS_TOKEN is not set"
)
cluster = os.getenv("VOLCENGINE_TTS_CLUSTER", "volcano_tts")
voice_type = os.getenv("VOLCENGINE_TTS_VOICE_TYPE", "BV700_V2_streaming")
tts_client = VolcengineTTS(
appid=app_id,
access_token=access_token,
cluster=cluster,
voice_type=voice_type,
)
# Call the TTS API
result = tts_client.text_to_speech(
text=request.text[:1024],
encoding=request.encoding,
speed_ratio=request.speed_ratio,
volume_ratio=request.volume_ratio,
pitch_ratio=request.pitch_ratio,
text_type=request.text_type,
with_frontend=request.with_frontend,
frontend_type=request.frontend_type,
)
if not result["success"]:
raise HTTPException(status_code=500, detail=str(result["error"]))
# Decode the base64 audio data
audio_data = base64.b64decode(result["audio_data"])
# Return the audio file
return Response(
content=audio_data,
media_type=f"audio/{request.encoding}",
headers={
"Content-Disposition": (
f"attachment; filename=tts_output.{request.encoding}"
)
},
)
except Exception as e:
logger.exception(f"Error in TTS endpoint: {str(e)}")
raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
2025-04-19 17:37:40 +08:00
2025-04-19 22:11:41 +08:00
@app.post("/api/podcast/generate")
async def generate_podcast(request: GeneratePodcastRequest):
2025-04-19 17:37:40 +08:00
try:
2025-04-19 22:11:41 +08:00
report_content = request.content
print(report_content)
2025-04-19 17:37:40 +08:00
workflow = build_podcast_graph()
final_state = workflow.invoke({"input": report_content})
audio_bytes = final_state["output"]
return Response(content=audio_bytes, media_type="audio/mp3")
except Exception as e:
logger.exception(f"Error occurred during podcast generation: {str(e)}")
raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
2025-04-21 16:43:06 +08:00
@app.post("/api/ppt/generate")
async def generate_ppt(request: GeneratePPTRequest):
try:
report_content = request.content
print(report_content)
workflow = build_ppt_graph()
final_state = workflow.invoke({"input": report_content})
generated_file_path = final_state["generated_file_path"]
with open(generated_file_path, "rb") as f:
ppt_bytes = f.read()
return Response(
content=ppt_bytes,
media_type="application/vnd.openxmlformats-officedocument.presentationml.presentation",
)
except Exception as e:
logger.exception(f"Error occurred during ppt generation: {str(e)}")
raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
2025-04-23 14:38:04 +08:00
2025-04-26 23:12:13 +08:00
@app.post("/api/prose/generate")
async def generate_prose(request: GenerateProseRequest):
try:
sanitized_prompt = request.prompt.replace("\r\n", "").replace("\n", "")
logger.info(f"Generating prose for prompt: {sanitized_prompt}")
2025-04-26 23:12:13 +08:00
workflow = build_prose_graph()
events = workflow.astream(
{
"content": request.prompt,
"option": request.option,
"command": request.command,
},
stream_mode="messages",
subgraphs=True,
)
return StreamingResponse(
(f"data: {event[0].content}\n\n" async for _, event in events),
media_type="text/event-stream",
)
except Exception as e:
logger.exception(f"Error occurred during prose generation: {str(e)}")
raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
2025-04-26 23:12:13 +08:00
@app.post("/api/prompt/enhance")
async def enhance_prompt(request: EnhancePromptRequest):
try:
sanitized_prompt = request.prompt.replace("\r\n", "").replace("\n", "")
logger.info(f"Enhancing prompt: {sanitized_prompt}")
# Convert string report_style to ReportStyle enum
report_style = None
if request.report_style:
try:
# Handle both uppercase and lowercase input
style_mapping = {
"ACADEMIC": ReportStyle.ACADEMIC,
"POPULAR_SCIENCE": ReportStyle.POPULAR_SCIENCE,
"NEWS": ReportStyle.NEWS,
"SOCIAL_MEDIA": ReportStyle.SOCIAL_MEDIA,
"academic": ReportStyle.ACADEMIC,
"popular_science": ReportStyle.POPULAR_SCIENCE,
"news": ReportStyle.NEWS,
"social_media": ReportStyle.SOCIAL_MEDIA,
}
report_style = style_mapping.get(
request.report_style, ReportStyle.ACADEMIC
)
except Exception:
# If invalid style, default to ACADEMIC
report_style = ReportStyle.ACADEMIC
else:
report_style = ReportStyle.ACADEMIC
workflow = build_prompt_enhancer_graph()
final_state = workflow.invoke(
{
"prompt": request.prompt,
"context": request.context,
"report_style": report_style,
}
)
return {"result": final_state["output"]}
except Exception as e:
logger.exception(f"Error occurred during prompt enhancement: {str(e)}")
raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
2025-04-23 14:38:04 +08:00
@app.post("/api/mcp/server/metadata", response_model=MCPServerMetadataResponse)
async def mcp_server_metadata(request: MCPServerMetadataRequest):
"""Get information about an MCP server."""
try:
# Set default timeout with a longer value for this endpoint
timeout = 300 # Default to 300 seconds for this endpoint
2025-05-08 08:59:18 +08:00
# Use custom timeout from request if provided
if request.timeout_seconds is not None:
timeout = request.timeout_seconds
2025-05-08 08:59:18 +08:00
2025-04-23 14:38:04 +08:00
# Load tools from the MCP server using the utility function
tools = await load_mcp_tools(
2025-04-23 16:00:01 +08:00
server_type=request.transport,
2025-04-23 14:38:04 +08:00
command=request.command,
args=request.args,
url=request.url,
env=request.env,
timeout_seconds=timeout,
2025-04-23 14:38:04 +08:00
)
# Create the response with tools
response = MCPServerMetadataResponse(
2025-04-23 16:00:01 +08:00
transport=request.transport,
2025-04-23 14:38:04 +08:00
command=request.command,
args=request.args,
url=request.url,
env=request.env,
tools=tools,
)
return response
except Exception as e:
if not isinstance(e, HTTPException):
logger.exception(f"Error in MCP server metadata endpoint: {str(e)}")
raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
2025-04-23 14:38:04 +08:00
raise
@app.get("/api/rag/config", response_model=RAGConfigResponse)
async def rag_config():
"""Get the config of the RAG."""
return RAGConfigResponse(provider=SELECTED_RAG_PROVIDER)
@app.get("/api/rag/resources", response_model=RAGResourcesResponse)
async def rag_resources(request: Annotated[RAGResourceRequest, Query()]):
"""Get the resources of the RAG."""
retriever = build_retriever()
if retriever:
return RAGResourcesResponse(resources=retriever.list_resources(request.query))
return RAGResourcesResponse(resources=[])