2025-04-17 11:34:42 +08:00
|
|
|
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
|
|
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
|
|
2025-04-18 15:28:31 +08:00
|
|
|
import base64
|
2025-04-13 21:14:31 +08:00
|
|
|
import json
|
|
|
|
|
import logging
|
2025-05-28 14:13:46 +08:00
|
|
|
from typing import Annotated, Any, List, cast
|
2025-04-13 21:14:31 +08:00
|
|
|
from uuid import uuid4
|
|
|
|
|
|
2025-05-28 14:13:46 +08:00
|
|
|
from fastapi import FastAPI, HTTPException, Query
|
2025-04-13 21:14:31 +08:00
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
2025-04-19 17:37:40 +08:00
|
|
|
from fastapi.responses import Response, StreamingResponse
|
2025-07-04 08:27:20 +08:00
|
|
|
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage
|
2025-04-14 18:01:50 +08:00
|
|
|
from langgraph.types import Command
|
2025-08-16 21:03:12 +08:00
|
|
|
from langgraph.store.memory import InMemoryStore
|
|
|
|
|
from langgraph.checkpoint.mongodb import AsyncMongoDBSaver
|
|
|
|
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
|
|
|
|
from psycopg_pool import AsyncConnectionPool
|
2025-04-13 21:14:31 +08:00
|
|
|
|
2025-08-16 21:03:12 +08:00
|
|
|
from src.config.configuration import get_recursion_limit, get_bool_env, get_str_env
|
2025-06-07 20:48:39 +08:00
|
|
|
from src.config.report_style import ReportStyle
|
2025-05-28 14:13:46 +08:00
|
|
|
from src.config.tools import SELECTED_RAG_PROVIDER
|
2025-04-22 15:33:53 +08:00
|
|
|
from src.graph.builder import build_graph_with_memory
|
2025-07-04 08:27:20 +08:00
|
|
|
from src.llms.llm import get_configured_llm_models
|
2025-04-19 17:37:40 +08:00
|
|
|
from src.podcast.graph.builder import build_graph as build_podcast_graph
|
2025-04-21 16:43:06 +08:00
|
|
|
from src.ppt.graph.builder import build_graph as build_ppt_graph
|
2025-06-08 19:41:59 +08:00
|
|
|
from src.prompt_enhancer.graph.builder import build_graph as build_prompt_enhancer_graph
|
2025-07-04 08:27:20 +08:00
|
|
|
from src.prose.graph.builder import build_graph as build_prose_graph
|
2025-05-28 14:13:46 +08:00
|
|
|
from src.rag.builder import build_retriever
|
|
|
|
|
from src.rag.retriever import Resource
|
2025-04-19 22:11:41 +08:00
|
|
|
from src.server.chat_request import (
|
|
|
|
|
ChatRequest,
|
2025-06-08 19:41:59 +08:00
|
|
|
EnhancePromptRequest,
|
2025-04-19 22:11:41 +08:00
|
|
|
GeneratePodcastRequest,
|
2025-04-21 16:43:06 +08:00
|
|
|
GeneratePPTRequest,
|
2025-04-26 23:12:13 +08:00
|
|
|
GenerateProseRequest,
|
2025-04-19 22:11:41 +08:00
|
|
|
TTSRequest,
|
|
|
|
|
)
|
2025-07-04 08:27:20 +08:00
|
|
|
from src.server.config_request import ConfigResponse
|
2025-04-23 14:38:04 +08:00
|
|
|
from src.server.mcp_request import MCPServerMetadataRequest, MCPServerMetadataResponse
|
|
|
|
|
from src.server.mcp_utils import load_mcp_tools
|
2025-05-28 14:13:46 +08:00
|
|
|
from src.server.rag_request import (
|
|
|
|
|
RAGConfigResponse,
|
|
|
|
|
RAGResourceRequest,
|
|
|
|
|
RAGResourcesResponse,
|
|
|
|
|
)
|
2025-04-18 15:28:31 +08:00
|
|
|
from src.tools import VolcengineTTS
|
2025-08-16 21:03:12 +08:00
|
|
|
from src.graph.checkpoint import chat_stream_message
|
|
|
|
|
from src.utils.json_utils import sanitize_args
|
2025-04-13 21:14:31 +08:00
|
|
|
|
|
|
|
|
# Module-level logger for this API server.
logger = logging.getLogger(__name__)

# Generic detail string for unexpected 500 responses so internal error
# information is never leaked to clients.
INTERNAL_SERVER_ERROR_DETAIL = "Internal Server Error"

app = FastAPI(
    title="DeerFlow API",
    description="API for Deer",
    version="0.1.0",
)

# Add CORS middleware
# It's recommended to load the allowed origins from an environment variable
# for better security and flexibility across different environments.
allowed_origins_str = get_str_env("ALLOWED_ORIGINS", "http://localhost:3000")
# ALLOWED_ORIGINS is a comma-separated list; strip whitespace around entries.
allowed_origins = [origin.strip() for origin in allowed_origins_str.split(",")]

logger.info(f"Allowed origins: {allowed_origins}")

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,  # Restrict to specific origins
    allow_credentials=True,
    allow_methods=["GET", "POST", "OPTIONS"],  # Use the configured list of methods
    allow_headers=["*"],  # Now allow all headers, but can be restricted further
)

# Shared in-memory store attached to the graph when a checkpointer is used.
in_memory_store = InMemoryStore()
# The main chat workflow graph, built once at import time with memory support.
graph = build_graph_with_memory()
|
2025-04-13 21:14:31 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/api/chat/stream")
async def chat_stream(request: ChatRequest):
    """Run the research chat workflow and stream it as server-sent events."""
    # MCP server configuration must be explicitly enabled via environment.
    mcp_enabled = get_bool_env("ENABLE_MCP_SERVER_CONFIGURATION", False)
    if request.mcp_settings and not mcp_enabled:
        raise HTTPException(
            status_code=403,
            detail="MCP server configuration is disabled. Set ENABLE_MCP_SERVER_CONFIGURATION=true to enable MCP features.",
        )

    # "__default__" asks the server to mint a fresh thread id.
    thread_id = request.thread_id
    if thread_id == "__default__":
        thread_id = str(uuid4())

    event_generator = _astream_workflow_generator(
        request.model_dump()["messages"],
        thread_id,
        request.resources,
        request.max_plan_iterations,
        request.max_step_num,
        request.max_search_results,
        request.auto_accepted_plan,
        request.interrupt_feedback,
        # Never forward MCP settings when the feature is disabled.
        request.mcp_settings if mcp_enabled else {},
        request.enable_background_investigation,
        request.report_style,
        request.enable_deep_thinking,
    )
    return StreamingResponse(event_generator, media_type="text/event-stream")
|
|
|
|
|
|
|
|
|
|
|
2025-08-16 21:03:12 +08:00
|
|
|
def _process_tool_call_chunks(tool_call_chunks):
    """Normalize raw tool-call chunks into plain dicts with sanitized args."""
    return [
        {
            "name": chunk.get("name", ""),
            "args": sanitize_args(chunk.get("args", "")),
            "id": chunk.get("id", ""),
            "index": chunk.get("index", 0),
            "type": chunk.get("type", ""),
        }
        for chunk in tool_call_chunks
    ]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_agent_name(agent, message_metadata):
|
|
|
|
|
"""Extract agent name from agent tuple."""
|
|
|
|
|
agent_name = "unknown"
|
|
|
|
|
if agent and len(agent) > 0:
|
|
|
|
|
agent_name = agent[0].split(":")[0] if ":" in agent[0] else agent[0]
|
|
|
|
|
else:
|
|
|
|
|
agent_name = message_metadata.get("langgraph_node", "unknown")
|
|
|
|
|
return agent_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_event_stream_message(
|
|
|
|
|
message_chunk, message_metadata, thread_id, agent_name
|
|
|
|
|
):
|
|
|
|
|
"""Create base event stream message."""
|
|
|
|
|
event_stream_message = {
|
|
|
|
|
"thread_id": thread_id,
|
|
|
|
|
"agent": agent_name,
|
|
|
|
|
"id": message_chunk.id,
|
|
|
|
|
"role": "assistant",
|
|
|
|
|
"checkpoint_ns": message_metadata.get("checkpoint_ns", ""),
|
|
|
|
|
"langgraph_node": message_metadata.get("langgraph_node", ""),
|
|
|
|
|
"langgraph_path": message_metadata.get("langgraph_path", ""),
|
|
|
|
|
"langgraph_step": message_metadata.get("langgraph_step", ""),
|
|
|
|
|
"content": message_chunk.content,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# Add optional fields
|
|
|
|
|
if message_chunk.additional_kwargs.get("reasoning_content"):
|
|
|
|
|
event_stream_message["reasoning_content"] = message_chunk.additional_kwargs[
|
|
|
|
|
"reasoning_content"
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
if message_chunk.response_metadata.get("finish_reason"):
|
|
|
|
|
event_stream_message["finish_reason"] = message_chunk.response_metadata.get(
|
|
|
|
|
"finish_reason"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return event_stream_message
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_interrupt_event(thread_id, event_data):
    """Format a LangGraph interrupt as an SSE event asking the user to act."""
    interrupt = event_data["__interrupt__"][0]
    payload = {
        "thread_id": thread_id,
        "id": interrupt.ns[0],
        "role": "assistant",
        "content": interrupt.value,
        "finish_reason": "interrupt",
        # Choices presented to the user while the plan awaits review.
        "options": [
            {"text": "Edit plan", "value": "edit_plan"},
            {"text": "Start research", "value": "accepted"},
        ],
    }
    return _make_event("interrupt", payload)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _process_initial_messages(message, thread_id):
    """Persist an incoming user message to the chat stream as a message_chunk event.

    Note: this only records the message via ``chat_stream_message``; it does
    not yield anything to the HTTP response.
    """
    payload = {
        "thread_id": thread_id,
        "id": "run--" + message.get("id", uuid4().hex),
        "role": "user",
        "content": message.get("content", ""),
    }
    # Compact separators keep the stored frame small; ensure_ascii=False
    # preserves non-ASCII text verbatim.
    json_data = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
    chat_stream_message(
        thread_id, f"event: message_chunk\ndata: {json_data}\n\n", "none"
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _process_message_chunk(message_chunk, message_metadata, thread_id, agent):
    """Translate one streamed LangChain message into SSE event strings."""
    agent_name = _get_agent_name(agent, message_metadata)
    stream_message = _create_event_stream_message(
        message_chunk, message_metadata, thread_id, agent_name
    )

    if isinstance(message_chunk, ToolMessage):
        # The result of a completed tool invocation.
        stream_message["tool_call_id"] = message_chunk.tool_call_id
        yield _make_event("tool_call_result", stream_message)
        return

    if not isinstance(message_chunk, AIMessageChunk):
        # Other message types are not forwarded to the client.
        return

    if message_chunk.tool_calls:
        # Fully-formed tool invocations requested by the model.
        stream_message["tool_calls"] = message_chunk.tool_calls
        stream_message["tool_call_chunks"] = _process_tool_call_chunks(
            message_chunk.tool_call_chunks
        )
        yield _make_event("tool_calls", stream_message)
    elif message_chunk.tool_call_chunks:
        # Incremental fragments of a tool invocation still being streamed.
        stream_message["tool_call_chunks"] = _process_tool_call_chunks(
            message_chunk.tool_call_chunks
        )
        yield _make_event("tool_call_chunks", stream_message)
    else:
        # Plain token stream from the model.
        yield _make_event("message_chunk", stream_message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _stream_graph_events(
    graph_instance, workflow_input, workflow_config, thread_id
):
    """Stream events from the graph and process them into SSE strings.

    Iterates the graph with both "messages" and "updates" stream modes.
    Dict-shaped events come from the "updates" mode: interrupts are turned
    into interrupt events and all other updates are skipped. Tuple-shaped
    events are message chunks and are forwarded to ``_process_message_chunk``.
    """
    async for agent, _, event_data in graph_instance.astream(
        workflow_input,
        config=workflow_config,
        stream_mode=["messages", "updates"],
        subgraphs=True,
    ):
        if isinstance(event_data, dict):
            if "__interrupt__" in event_data:
                yield _create_interrupt_event(thread_id, event_data)
            # Skip every dict-shaped "updates" event; only tuple-shaped
            # message events are unpacked below.
            continue

        # Fix: annotate with typing.Any — the previous dict[str, any]
        # referenced the builtin any() function, not a type.
        message_chunk, message_metadata = cast(
            tuple[BaseMessage, dict[str, Any]], event_data
        )

        async for event in _process_message_chunk(
            message_chunk, message_metadata, thread_id, agent
        ):
            yield event
|
|
|
|
|
|
|
|
|
|
|
2025-04-13 21:14:31 +08:00
|
|
|
async def _astream_workflow_generator(
    messages: List[dict],
    thread_id: str,
    resources: List[Resource],
    max_plan_iterations: int,
    max_step_num: int,
    max_search_results: int,
    auto_accepted_plan: bool,
    interrupt_feedback: str,
    mcp_settings: dict,
    enable_background_investigation: bool,
    report_style: ReportStyle,
    enable_deep_thinking: bool,
):
    """Run the chat workflow graph and yield SSE-formatted event strings.

    Persists the incoming user messages, builds the graph input and config,
    optionally attaches a Postgres or MongoDB checkpointer depending on
    environment configuration, and streams the resulting events.
    """
    # Process initial messages: persist each user message to the chat stream.
    for message in messages:
        if isinstance(message, dict) and "content" in message:
            _process_initial_messages(message, thread_id)

    # Prepare workflow input
    workflow_input = {
        "messages": messages,
        "plan_iterations": 0,
        "final_report": "",
        "current_plan": None,
        "observations": [],
        "auto_accepted_plan": auto_accepted_plan,
        "enable_background_investigation": enable_background_investigation,
        # The latest user message doubles as the research topic.
        "research_topic": messages[-1]["content"] if messages else "",
    }

    # When resuming from an interrupt, replace the input with a resume
    # Command carrying the user's feedback (e.g. "[accepted] ...").
    if not auto_accepted_plan and interrupt_feedback:
        resume_msg = f"[{interrupt_feedback}]"
        # Append the latest user message so the graph sees any edits.
        if messages:
            resume_msg += f" {messages[-1]['content']}"
        workflow_input = Command(resume=resume_msg)

    # Prepare workflow config
    workflow_config = {
        "thread_id": thread_id,
        "resources": resources,
        "max_plan_iterations": max_plan_iterations,
        "max_step_num": max_step_num,
        "max_search_results": max_search_results,
        "mcp_settings": mcp_settings,
        "report_style": report_style.value,
        "enable_deep_thinking": enable_deep_thinking,
        "recursion_limit": get_recursion_limit(),
    }

    checkpoint_saver = get_bool_env("LANGGRAPH_CHECKPOINT_SAVER", False)
    checkpoint_url = get_str_env("LANGGRAPH_CHECKPOINT_DB_URL", "")
    # Handle checkpointer if configured
    # psycopg connection options used by the Postgres checkpointer pool.
    connection_kwargs = {
        "autocommit": True,
        "row_factory": "dict_row",
        "prepare_threshold": 0,
    }
    if checkpoint_saver and checkpoint_url != "":
        if checkpoint_url.startswith("postgresql://"):
            logger.info("start async postgres checkpointer.")
            async with AsyncConnectionPool(
                checkpoint_url, kwargs=connection_kwargs
            ) as conn:
                checkpointer = AsyncPostgresSaver(conn)
                # Create checkpoint tables if they do not exist yet.
                await checkpointer.setup()
                graph.checkpointer = checkpointer
                graph.store = in_memory_store
                async for event in _stream_graph_events(
                    graph, workflow_input, workflow_config, thread_id
                ):
                    yield event

        if checkpoint_url.startswith("mongodb://"):
            logger.info("start async mongodb checkpointer.")
            async with AsyncMongoDBSaver.from_conn_string(
                checkpoint_url
            ) as checkpointer:
                graph.checkpointer = checkpointer
                graph.store = in_memory_store
                async for event in _stream_graph_events(
                    graph, workflow_input, workflow_config, thread_id
                ):
                    yield event
        # NOTE(review): if checkpointing is enabled but the URL scheme is
        # neither postgresql:// nor mongodb://, no branch runs and nothing
        # is streamed — confirm whether a no-checkpointer fallback was
        # intended here.
    else:
        # Use graph without MongoDB checkpointer
        async for event in _stream_graph_events(
            graph, workflow_input, workflow_config, thread_id
        ):
            yield event
|
2025-04-13 21:14:31 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def _make_event(event_type: str, data: dict[str, Any]):
    """Serialize *data* as one SSE frame, persist it, and return the frame.

    Fix: the parameter annotation previously used the builtin ``any``
    function instead of ``typing.Any``.

    Args:
        event_type: SSE event name (e.g. "message_chunk", "tool_calls").
        data: JSON-serializable payload; an empty "content" field is dropped.

    Returns:
        The SSE frame string, or a generic error frame if serialization fails.
    """
    # Empty content only adds noise to the stream; drop the field entirely.
    if data.get("content") == "":
        data.pop("content")
    # Ensure JSON serialization with proper encoding
    try:
        json_data = json.dumps(data, ensure_ascii=False)

        finish_reason = data.get("finish_reason", "")
        # Persist the frame so the conversation can be replayed later.
        chat_stream_message(
            data.get("thread_id", ""),
            f"event: {event_type}\ndata: {json_data}\n\n",
            finish_reason,
        )

        return f"event: {event_type}\ndata: {json_data}\n\n"
    except (TypeError, ValueError) as e:
        logger.error(f"Error serializing event data: {e}")
        # Return a safe error event
        error_data = json.dumps({"error": "Serialization failed"}, ensure_ascii=False)
        return f"event: error\ndata: {error_data}\n\n"
|
2025-04-18 15:28:31 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/api/tts")
async def text_to_speech(request: TTSRequest):
    """Convert text to speech using volcengine TTS API.

    Returns the synthesized audio as a downloadable file. Raises 400 when
    Volcengine credentials are missing, 500 when synthesis fails.
    """
    app_id = get_str_env("VOLCENGINE_TTS_APPID", "")
    if not app_id:
        raise HTTPException(status_code=400, detail="VOLCENGINE_TTS_APPID is not set")
    access_token = get_str_env("VOLCENGINE_TTS_ACCESS_TOKEN", "")
    if not access_token:
        raise HTTPException(
            status_code=400, detail="VOLCENGINE_TTS_ACCESS_TOKEN is not set"
        )

    try:
        cluster = get_str_env("VOLCENGINE_TTS_CLUSTER", "volcano_tts")
        voice_type = get_str_env("VOLCENGINE_TTS_VOICE_TYPE", "BV700_V2_streaming")

        tts_client = VolcengineTTS(
            appid=app_id,
            access_token=access_token,
            cluster=cluster,
            voice_type=voice_type,
        )
        # Call the TTS API; the text is truncated to 1024 characters to stay
        # within the provider's request limits.
        result = tts_client.text_to_speech(
            text=request.text[:1024],
            encoding=request.encoding,
            speed_ratio=request.speed_ratio,
            volume_ratio=request.volume_ratio,
            pitch_ratio=request.pitch_ratio,
            text_type=request.text_type,
            with_frontend=request.with_frontend,
            frontend_type=request.frontend_type,
        )

        if not result["success"]:
            raise HTTPException(status_code=500, detail=str(result["error"]))

        # Decode the base64 audio data
        audio_data = base64.b64decode(result["audio_data"])

        # Return the audio file
        return Response(
            content=audio_data,
            media_type=f"audio/{request.encoding}",
            headers={
                "Content-Disposition": (
                    f"attachment; filename=tts_output.{request.encoding}"
                )
            },
        )
    except HTTPException:
        # Bug fix: the generic handler below used to swallow the deliberate
        # 500 raised on TTS API failure and replace its detail with the
        # generic message. Re-raise HTTP errors unchanged.
        raise
    except Exception as e:
        logger.exception(f"Error in TTS endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
2025-04-19 17:37:40 +08:00
|
|
|
|
|
|
|
|
|
2025-04-19 22:11:41 +08:00
|
|
|
@app.post("/api/podcast/generate")
async def generate_podcast(request: GeneratePodcastRequest):
    """Generate a podcast audio file (MP3) from the given report content."""
    try:
        report_content = request.content
        # Bug fix: replaced a leftover debug print() that dumped the entire
        # report to stdout with a debug-level log entry.
        logger.debug("Generating podcast for report content: %s", report_content)
        workflow = build_podcast_graph()
        final_state = workflow.invoke({"input": report_content})
        audio_bytes = final_state["output"]
        return Response(content=audio_bytes, media_type="audio/mp3")
    except Exception as e:
        logger.exception(f"Error occurred during podcast generation: {str(e)}")
        raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
2025-04-21 16:43:06 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/api/ppt/generate")
async def generate_ppt(request: GeneratePPTRequest):
    """Generate a PPTX file from the given report content."""
    try:
        report_content = request.content
        # Bug fix: replaced a leftover debug print() that dumped the entire
        # report to stdout with a debug-level log entry.
        logger.debug("Generating PPT for report content: %s", report_content)
        workflow = build_ppt_graph()
        final_state = workflow.invoke({"input": report_content})
        # The PPT workflow writes the deck to disk; read it back for the response.
        generated_file_path = final_state["generated_file_path"]
        with open(generated_file_path, "rb") as f:
            ppt_bytes = f.read()
        return Response(
            content=ppt_bytes,
            media_type="application/vnd.openxmlformats-officedocument.presentationml.presentation",
        )
    except Exception as e:
        logger.exception(f"Error occurred during ppt generation: {str(e)}")
        raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
2025-04-23 14:38:04 +08:00
|
|
|
|
|
|
|
|
|
2025-04-26 23:12:13 +08:00
|
|
|
@app.post("/api/prose/generate")
async def generate_prose(request: GenerateProseRequest):
    """Stream generated prose for the given prompt as server-sent events."""
    try:
        # Strip newlines so a crafted prompt cannot forge extra log lines.
        sanitized_prompt = request.prompt.replace("\r\n", "").replace("\n", "")
        logger.info(f"Generating prose for prompt: {sanitized_prompt}")

        workflow = build_prose_graph()
        workflow_state = {
            "content": request.prompt,
            "option": request.option,
            "command": request.command,
        }
        events = workflow.astream(
            workflow_state,
            stream_mode="messages",
            subgraphs=True,
        )
        return StreamingResponse(
            (f"data: {event[0].content}\n\n" async for _, event in events),
            media_type="text/event-stream",
        )
    except Exception as e:
        logger.exception(f"Error occurred during prose generation: {str(e)}")
        raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
2025-04-26 23:12:13 +08:00
|
|
|
|
|
|
|
|
|
2025-06-08 19:41:59 +08:00
|
|
|
@app.post("/api/prompt/enhance")
async def enhance_prompt(request: EnhancePromptRequest):
    """Enhance a user prompt, optionally tailored to a report style."""
    try:
        # Strip newlines so a crafted prompt cannot forge extra log lines.
        sanitized_prompt = request.prompt.replace("\r\n", "").replace("\n", "")
        logger.info(f"Enhancing prompt: {sanitized_prompt}")

        # Map the free-form style string onto the ReportStyle enum,
        # defaulting to ACADEMIC for missing or unrecognized values.
        report_style = ReportStyle.ACADEMIC
        if request.report_style:
            try:
                # Accept both uppercase and lowercase input.
                style_mapping = {
                    "ACADEMIC": ReportStyle.ACADEMIC,
                    "POPULAR_SCIENCE": ReportStyle.POPULAR_SCIENCE,
                    "NEWS": ReportStyle.NEWS,
                    "SOCIAL_MEDIA": ReportStyle.SOCIAL_MEDIA,
                }
                report_style = style_mapping.get(
                    request.report_style.upper(), ReportStyle.ACADEMIC
                )
            except Exception:
                # Non-string styles (no .upper()) also fall back to ACADEMIC.
                report_style = ReportStyle.ACADEMIC

        workflow = build_prompt_enhancer_graph()
        final_state = workflow.invoke(
            {
                "prompt": request.prompt,
                "context": request.context,
                "report_style": report_style,
            }
        )
        return {"result": final_state["output"]}
    except Exception as e:
        logger.exception(f"Error occurred during prompt enhancement: {str(e)}")
        raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
|
|
|
|
|
|
|
|
|
|
2025-04-23 14:38:04 +08:00
|
|
|
@app.post("/api/mcp/server/metadata", response_model=MCPServerMetadataResponse)
async def mcp_server_metadata(request: MCPServerMetadataRequest):
    """Get information about an MCP server."""
    # This endpoint is gated behind an explicit opt-in environment flag.
    if not get_bool_env("ENABLE_MCP_SERVER_CONFIGURATION", False):
        raise HTTPException(
            status_code=403,
            detail="MCP server configuration is disabled. Set ENABLE_MCP_SERVER_CONFIGURATION=true to enable MCP features.",
        )

    try:
        # Tool discovery can be slow, so default to a generous 300-second
        # timeout unless the caller supplied one explicitly.
        timeout = 300 if request.timeout_seconds is None else request.timeout_seconds

        # Load tools from the MCP server using the utility function.
        tools = await load_mcp_tools(
            server_type=request.transport,
            command=request.command,
            args=request.args,
            url=request.url,
            env=request.env,
            headers=request.headers,
            timeout_seconds=timeout,
        )

        # Echo the request parameters back alongside the discovered tools.
        return MCPServerMetadataResponse(
            transport=request.transport,
            command=request.command,
            args=request.args,
            url=request.url,
            env=request.env,
            headers=request.headers,
            tools=tools,
        )
    except Exception as e:
        logger.exception(f"Error in MCP server metadata endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
2025-05-28 14:13:46 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/api/rag/config", response_model=RAGConfigResponse)
async def rag_config():
    """Return which RAG provider this deployment is configured with."""
    response = RAGConfigResponse(provider=SELECTED_RAG_PROVIDER)
    return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/api/rag/resources", response_model=RAGResourcesResponse)
async def rag_resources(request: Annotated[RAGResourceRequest, Query()]):
    """List RAG resources matching the query, or none if RAG is unavailable."""
    retriever = build_retriever()
    if not retriever:
        # No RAG provider is configured — report an empty resource list.
        return RAGResourcesResponse(resources=[])
    return RAGResourcesResponse(resources=retriever.list_resources(request.query))
|
2025-06-14 13:12:43 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/api/config", response_model=ConfigResponse)
async def config():
    """Expose server configuration: the RAG provider and configured LLM models."""
    rag_settings = RAGConfigResponse(provider=SELECTED_RAG_PROVIDER)
    return ConfigResponse(
        rag=rag_settings,
        models=get_configured_llm_models(),
    )
|