mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 14:22:13 +08:00
* feat: add resource upload support for RAG - Backend: Added ingest_file method to Retriever and MilvusRetriever - Backend: Added /api/rag/upload endpoint - Frontend: Added RAGTab in settings for uploading resources - Frontend: Updated translations and settings registration * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply suggestions from code review * Apply suggestions from code review of src/rag/milvus.py --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
979 lines
38 KiB
Python
979 lines
38 KiB
Python
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import asyncio
|
|
import base64
|
|
import json
|
|
import logging
|
|
import os
|
|
from typing import Annotated, Any, List, Optional, cast
|
|
from uuid import uuid4
|
|
|
|
from fastapi import FastAPI, HTTPException, Query, UploadFile
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import Response, StreamingResponse
|
|
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage
|
|
from langgraph.checkpoint.mongodb import AsyncMongoDBSaver
|
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
|
from langgraph.store.memory import InMemoryStore
|
|
from langgraph.types import Command
|
|
from psycopg_pool import AsyncConnectionPool
|
|
|
|
from src.config.configuration import get_recursion_limit
|
|
from src.config.loader import get_bool_env, get_str_env
|
|
from src.config.report_style import ReportStyle
|
|
from src.config.tools import SELECTED_RAG_PROVIDER
|
|
from src.graph.builder import build_graph_with_memory
|
|
from src.graph.checkpoint import chat_stream_message
|
|
from src.graph.utils import (
|
|
build_clarified_topic_from_history,
|
|
reconstruct_clarification_history,
|
|
)
|
|
from src.llms.llm import get_configured_llm_models
|
|
from src.podcast.graph.builder import build_graph as build_podcast_graph
|
|
from src.ppt.graph.builder import build_graph as build_ppt_graph
|
|
from src.prompt_enhancer.graph.builder import build_graph as build_prompt_enhancer_graph
|
|
from src.prose.graph.builder import build_graph as build_prose_graph
|
|
from src.rag.builder import build_retriever
|
|
from src.rag.milvus import load_examples as load_milvus_examples
|
|
from src.rag.qdrant import load_examples as load_qdrant_examples
|
|
from src.rag.retriever import Resource
|
|
from src.server.chat_request import (
|
|
ChatRequest,
|
|
EnhancePromptRequest,
|
|
GeneratePodcastRequest,
|
|
GeneratePPTRequest,
|
|
GenerateProseRequest,
|
|
TTSRequest,
|
|
)
|
|
from src.server.config_request import ConfigResponse
|
|
from src.server.mcp_request import MCPServerMetadataRequest, MCPServerMetadataResponse
|
|
from src.server.mcp_utils import load_mcp_tools
|
|
from src.server.rag_request import (
|
|
RAGConfigResponse,
|
|
RAGResourceRequest,
|
|
RAGResourcesResponse,
|
|
)
|
|
from src.tools import VolcengineTTS
|
|
from src.utils.json_utils import sanitize_args
|
|
from src.utils.log_sanitizer import (
|
|
sanitize_agent_name,
|
|
sanitize_log_input,
|
|
sanitize_thread_id,
|
|
sanitize_tool_name,
|
|
sanitize_user_content,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)

# Configure Windows event loop policy for PostgreSQL compatibility.
# On Windows, psycopg requires a selector-based event loop, not the default
# ProactorEventLoop.
if os.name == "nt":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

# Generic detail string used for the HTTP 500 responses raised by the
# endpoints in this module.
INTERNAL_SERVER_ERROR_DETAIL = "Internal Server Error"

app = FastAPI(
    title="DeerFlow API",
    description="API for Deer",
    version="0.1.0",
)

# Add CORS middleware.
# It's recommended to load the allowed origins from an environment variable
# for better security and flexibility across different environments.
allowed_origins_str = get_str_env("ALLOWED_ORIGINS", "http://localhost:3000")
# ALLOWED_ORIGINS is a comma-separated list. Empty entries (a trailing comma,
# or an empty variable) are dropped so "" is never registered as an origin.
allowed_origins = [
    origin.strip() for origin in allowed_origins_str.split(",") if origin.strip()
]

logger.info(f"Allowed origins: {allowed_origins}")

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,  # Restrict to specific origins
    allow_credentials=True,
    allow_methods=["GET", "POST", "OPTIONS"],  # Use the configured list of methods
    allow_headers=["*"],  # Now allow all headers, but can be restricted further
)

# Load examples into RAG providers if configured
load_milvus_examples()
load_qdrant_examples()

# Module-level singletons shared by every request handled by this process.
in_memory_store = InMemoryStore()
graph = build_graph_with_memory()
|
|
|
|
|
|
@app.post("/api/chat/stream")
async def chat_stream(request: ChatRequest):
    """Run the chat workflow and stream its events as ``text/event-stream``.

    Raises:
        HTTPException: 403 when the request carries MCP settings while
            ``ENABLE_MCP_SERVER_CONFIGURATION`` is not enabled.
    """
    mcp_enabled = get_bool_env("ENABLE_MCP_SERVER_CONFIGURATION", False)

    logger.debug(f"get the request locale : {request.locale}")

    # MCP settings are only accepted when the feature flag is on.
    if request.mcp_settings and not mcp_enabled:
        raise HTTPException(
            status_code=403,
            detail="MCP server configuration is disabled. Set ENABLE_MCP_SERVER_CONFIGURATION=true to enable MCP features.",
        )

    # "__default__" is the placeholder for "no thread yet": mint a new id.
    if request.thread_id == "__default__":
        thread_id = str(uuid4())
    else:
        thread_id = request.thread_id

    event_source = _astream_workflow_generator(
        messages=request.model_dump()["messages"],
        thread_id=thread_id,
        resources=request.resources,
        max_plan_iterations=request.max_plan_iterations,
        max_step_num=request.max_step_num,
        max_search_results=request.max_search_results,
        auto_accepted_plan=request.auto_accepted_plan,
        interrupt_feedback=request.interrupt_feedback,
        mcp_settings=request.mcp_settings if mcp_enabled else {},
        enable_background_investigation=request.enable_background_investigation,
        enable_web_search=request.enable_web_search,
        report_style=request.report_style,
        enable_deep_thinking=request.enable_deep_thinking,
        enable_clarification=request.enable_clarification,
        max_clarification_rounds=request.max_clarification_rounds,
        locale=request.locale,
        interrupt_before_tools=request.interrupt_before_tools,
    )
    return StreamingResponse(event_source, media_type="text/event-stream")
|
|
|
|
|
|
def _validate_tool_call_chunks(tool_call_chunks):
    """Log the structure of tool call chunks for debugging.

    Purely diagnostic: the chunks are not modified and nothing is returned.
    Each chunk is read with ``.get`` for its ``index``, ``id``, ``name`` and
    ``type`` keys, and a note is logged when more than one distinct index is
    present.

    Args:
        tool_call_chunks: Sequence of dict-like tool call chunks; may be
            empty or ``None``.
    """
    # Everything below only emits DEBUG records, so skip the work entirely
    # when there is nothing to inspect or DEBUG logging is disabled.
    if not tool_call_chunks or not logger.isEnabledFor(logging.DEBUG):
        return

    logger.debug(f"Validating tool_call_chunks: count={len(tool_call_chunks)}")

    indices_seen = set()

    for i, chunk in enumerate(tool_call_chunks):
        index = chunk.get("index")
        tool_id = chunk.get("id")
        name = chunk.get("name", "")
        has_args = "args" in chunk

        logger.debug(
            f"Chunk {i}: index={index}, id={tool_id}, name={name}, "
            f"has_args={has_args}, type={chunk.get('type')}"
        )

        if index is not None:
            indices_seen.add(index)

    if len(indices_seen) > 1:
        logger.debug(
            f"Multiple indices detected: {sorted(indices_seen)} - "
            f"This may indicate consecutive tool calls"
        )
|
|
|
|
|
|
def _process_tool_call_chunks(tool_call_chunks):
    """Merge streamed tool call chunks into one entry per tool call index.

    Chunks that share an ``index`` belong to the same tool call: their
    ``args`` fragments are concatenated, the first non-empty ``name`` and
    ``id`` are kept, and a warning is logged if a later chunk carries a
    different name for the same index. Chunks with no ``index`` are passed
    through individually with ``index`` set to 0.

    Args:
        tool_call_chunks: Sequence of dict-like chunks; may be empty or None.

    Returns:
        A list of dicts with keys ``name``, ``args``, ``id``, ``index`` and
        ``type``. Un-indexed chunks come first in arrival order, followed by
        the merged entries in ascending index order. ``args`` has been passed
        through ``sanitize_args``. An empty list is returned for empty input.
    """
    if not tool_call_chunks:
        return []

    _validate_tool_call_chunks(tool_call_chunks)

    result = []
    merged = {}  # index -> accumulated entry for that tool call

    for piece in tool_call_chunks:
        idx = piece.get("index")
        piece_id = piece.get("id")

        if idx is None:
            # Edge case: no index to group on, so emit the chunk as-is.
            logger.debug(f"Chunk without index encountered: {piece}")
            result.append({
                "name": piece.get("name", ""),
                "args": sanitize_args(piece.get("args", "")),
                "id": piece.get("id", ""),
                "index": 0,
                "type": piece.get("type", ""),
            })
            continue

        # The first chunk seen for an index seeds the entry.
        entry = merged.setdefault(
            idx,
            {
                "name": "",
                "args": "",
                "id": piece_id or "",
                "index": idx,
                "type": piece.get("type", ""),
            },
        )

        incoming_name = piece.get("name", "")
        if incoming_name:
            current_name = entry["name"]
            if not current_name or current_name == incoming_name:
                entry["name"] = incoming_name
            else:
                # Keep the first name rather than concatenating or replacing.
                logger.warning(
                    f"Tool name mismatch detected at index {idx}: "
                    f"'{current_name}' != '{incoming_name}'. "
                    f"This may indicate a streaming artifact or consecutive tool calls "
                    f"with the same index assignment."
                )

        # Fill in the id only if none has been recorded yet.
        if piece_id and not entry["id"]:
            entry["id"] = piece_id

        fragment = piece.get("args")
        if fragment:
            entry["args"] += fragment

    # Append the merged entries in ascending index order.
    for idx in sorted(merged):
        entry = merged[idx]
        entry["args"] = sanitize_args(entry["args"])
        result.append(entry)
        logger.debug(
            f"Processed tool call: index={idx}, name={entry['name']}, "
            f"id={entry['id']}"
        )

    return result
|
|
|
|
|
|
def _get_agent_name(agent, message_metadata):
|
|
"""Extract agent name from agent tuple."""
|
|
agent_name = "unknown"
|
|
if agent and len(agent) > 0:
|
|
agent_name = agent[0].split(":")[0] if ":" in agent[0] else agent[0]
|
|
else:
|
|
agent_name = message_metadata.get("langgraph_node", "unknown")
|
|
return agent_name
|
|
|
|
|
|
def _create_event_stream_message(
|
|
message_chunk, message_metadata, thread_id, agent_name
|
|
):
|
|
"""Create base event stream message."""
|
|
content = message_chunk.content
|
|
if not isinstance(content, str):
|
|
content = json.dumps(content, ensure_ascii=False)
|
|
|
|
event_stream_message = {
|
|
"thread_id": thread_id,
|
|
"agent": agent_name,
|
|
"id": message_chunk.id,
|
|
"role": "assistant",
|
|
"checkpoint_ns": message_metadata.get("checkpoint_ns", ""),
|
|
"langgraph_node": message_metadata.get("langgraph_node", ""),
|
|
"langgraph_path": message_metadata.get("langgraph_path", ""),
|
|
"langgraph_step": message_metadata.get("langgraph_step", ""),
|
|
"content": content,
|
|
}
|
|
|
|
# Add optional fields
|
|
if message_chunk.additional_kwargs.get("reasoning_content"):
|
|
event_stream_message["reasoning_content"] = message_chunk.additional_kwargs[
|
|
"reasoning_content"
|
|
]
|
|
|
|
if message_chunk.response_metadata.get("finish_reason"):
|
|
event_stream_message["finish_reason"] = message_chunk.response_metadata.get(
|
|
"finish_reason"
|
|
)
|
|
|
|
return event_stream_message
|
|
|
|
|
|
def _create_interrupt_event(thread_id, event_data):
    """Build the ``interrupt`` event for the first pending graph interrupt.

    Args:
        thread_id: Conversation thread identifier; also used as the event id
            when the interrupt has no usable ``id`` attribute.
        event_data: Mapping whose ``"__interrupt__"`` entry is a sequence of
            interrupt objects; only the first one is used.

    Returns:
        The formatted event string produced by ``_make_event``.
    """
    first_interrupt = event_data["__interrupt__"][0]
    # Use the 'id' attribute (LangGraph 1.0+) instead of deprecated 'ns[0]'
    payload = {
        "thread_id": thread_id,
        "id": getattr(first_interrupt, "id", None) or thread_id,
        "role": "assistant",
        "content": first_interrupt.value,
        "finish_reason": "interrupt",
        "options": [
            {"text": "Edit plan", "value": "edit_plan"},
            {"text": "Start research", "value": "accepted"},
        ],
    }
    return _make_event("interrupt", payload)
|
|
|
|
|
|
def _process_initial_messages(message, thread_id):
    """Format an incoming user message as a ``message_chunk`` event and record it.

    The event is handed to ``chat_stream_message`` with finish reason
    ``"none"``; nothing is returned or yielded to the caller.

    Args:
        message: Dict-like message; ``id`` and ``content`` are read with
            ``.get``. A random hex id is used when ``id`` is missing.
        thread_id: Conversation thread identifier.
    """
    payload = {
        "thread_id": thread_id,
        "id": "run--" + message.get("id", uuid4().hex),
        "role": "user",
        "content": message.get("content", ""),
    }
    # Compact separators keep the serialized event as small as possible.
    serialized = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
    event_text = f"event: message_chunk\ndata: {serialized}\n\n"
    chat_stream_message(thread_id, event_text, "none")
|
|
|
|
|
|
async def _process_message_chunk(message_chunk, message_metadata, thread_id, agent):
    """Turn one streamed graph message into server-sent event strings.

    Async generator that yields at most one event per call:

    * ``tool_call_result`` for a ``ToolMessage``;
    * ``tool_calls`` for an ``AIMessageChunk`` with complete tool calls;
    * ``tool_call_chunks`` for an ``AIMessageChunk`` that only carries
      streaming tool call chunks;
    * ``message_chunk`` for any other ``AIMessageChunk``.

    Messages of any other type produce no event.

    Args:
        message_chunk: The message emitted by the graph.
        message_metadata: Metadata mapping that accompanies the message.
        thread_id: Conversation thread identifier.
        agent: Namespace sequence from the graph stream, used to derive the
            agent name.
    """
    agent_name = _get_agent_name(agent, message_metadata)
    safe_agent_name = sanitize_agent_name(agent_name)
    safe_thread_id = sanitize_thread_id(thread_id)
    logger.debug(f"[{safe_thread_id}] _process_message_chunk started for agent={safe_agent_name}")
    logger.debug(f"[{safe_thread_id}] Extracted agent_name: {safe_agent_name}")

    event_stream_message = _create_event_stream_message(
        message_chunk, message_metadata, thread_id, agent_name
    )

    if isinstance(message_chunk, ToolMessage):
        # Tool Message - Return the result of the tool call
        logger.debug(f"[{safe_thread_id}] Processing ToolMessage")
        tool_call_id = message_chunk.tool_call_id
        event_stream_message["tool_call_id"] = tool_call_id

        # Validate tool_call_id for debugging
        if tool_call_id:
            safe_tool_id = sanitize_log_input(tool_call_id, max_length=100)
            logger.debug(f"[{safe_thread_id}] ToolMessage with tool_call_id: {safe_tool_id}")
        else:
            logger.warning(f"[{safe_thread_id}] ToolMessage received without tool_call_id")

        logger.debug(f"[{safe_thread_id}] Yielding tool_call_result event")
        yield _make_event("tool_call_result", event_stream_message)
    elif isinstance(message_chunk, AIMessageChunk):
        # AI Message - Raw message tokens
        has_tool_calls = bool(message_chunk.tool_calls)
        has_chunks = bool(message_chunk.tool_call_chunks)
        logger.debug(f"[{safe_thread_id}] Processing AIMessageChunk, tool_calls={has_tool_calls}, tool_call_chunks={has_chunks}")

        if message_chunk.tool_calls:
            # AI Message - Tool Call (complete tool calls)
            safe_tool_names = [sanitize_tool_name(tc.get('name', 'unknown')) for tc in message_chunk.tool_calls]
            logger.debug(f"[{safe_thread_id}] AIMessageChunk has complete tool_calls: {safe_tool_names}")
            event_stream_message["tool_calls"] = message_chunk.tool_calls

            # Process tool_call_chunks with proper index-based grouping
            processed_chunks = _process_tool_call_chunks(
                message_chunk.tool_call_chunks
            )
            if processed_chunks:
                event_stream_message["tool_call_chunks"] = processed_chunks
                logger.debug(
                    f"[{safe_thread_id}] Tool calls: {safe_tool_names}, "
                    f"Processed chunks: {len(processed_chunks)}"
                )

            logger.debug(f"[{safe_thread_id}] Yielding tool_calls event")
            yield _make_event("tool_calls", event_stream_message)
        elif message_chunk.tool_call_chunks:
            # AI Message - Tool Call Chunks (streaming)
            chunks_count = len(message_chunk.tool_call_chunks)
            logger.debug(f"[{safe_thread_id}] AIMessageChunk has streaming tool_call_chunks: {chunks_count} chunks")
            processed_chunks = _process_tool_call_chunks(
                message_chunk.tool_call_chunks
            )

            if processed_chunks:
                # Diagnostic only: log where the index changes between
                # consecutive chunks (a tool call boundary).
                prev_chunk = None
                for chunk in processed_chunks:
                    current_index = chunk.get("index")

                    if prev_chunk is not None and current_index != prev_chunk.get("index"):
                        prev_name = sanitize_tool_name(prev_chunk.get('name'))
                        curr_name = sanitize_tool_name(chunk.get('name'))
                        logger.debug(
                            f"[{safe_thread_id}] Tool call boundary detected: "
                            f"index {prev_chunk.get('index')} ({prev_name}) -> "
                            f"{current_index} ({curr_name})"
                        )

                    prev_chunk = chunk

                # Include all processed chunks in the event
                event_stream_message["tool_call_chunks"] = processed_chunks
                safe_chunk_names = [sanitize_tool_name(c.get('name')) for c in processed_chunks]
                logger.debug(
                    f"[{safe_thread_id}] Streamed {len(processed_chunks)} tool call chunk(s): "
                    f"{safe_chunk_names}"
                )

            logger.debug(f"[{safe_thread_id}] Yielding tool_call_chunks event")
            yield _make_event("tool_call_chunks", event_stream_message)
        else:
            # AI Message - Raw message tokens
            content_len = len(message_chunk.content) if isinstance(message_chunk.content, str) else 0
            logger.debug(f"[{safe_thread_id}] AIMessageChunk is raw message tokens, content_len={content_len}")
            yield _make_event("message_chunk", event_stream_message)
|
|
|
|
|
|
async def _stream_graph_events(
    graph_instance, workflow_input, workflow_config, thread_id
):
    """Stream events from the graph and convert them to event strings.

    Async generator. Iterates ``graph_instance.astream`` in ``messages`` and
    ``updates`` mode with subgraphs enabled. Dict events containing
    ``"__interrupt__"`` become an interrupt event; other dict events are
    skipped; everything else is treated as a ``(message, metadata)`` pair and
    passed to ``_process_message_chunk``.

    Args:
        graph_instance: Graph object exposing ``astream``.
        workflow_input: Input (or resume command) passed to the graph.
        workflow_config: Config passed to the graph.
        thread_id: Conversation thread identifier.

    Raises:
        asyncio.CancelledError: Re-raised when the stream is cancelled. Any
            other exception is logged and reported as a single ``error``
            event instead of propagating.
    """
    safe_thread_id = sanitize_thread_id(thread_id)
    logger.debug(f"[{safe_thread_id}] Starting graph event stream with agent nodes")
    # Initialised before the try so the handlers below can always report it.
    event_count = 0
    try:
        async for agent, _, event_data in graph_instance.astream(
            workflow_input,
            config=workflow_config,
            stream_mode=["messages", "updates"],
            subgraphs=True,
        ):
            event_count += 1
            safe_agent = sanitize_agent_name(agent)
            logger.debug(f"[{safe_thread_id}] Graph event #{event_count} received from agent: {safe_agent}")

            if isinstance(event_data, dict):
                if "__interrupt__" in event_data:
                    # Collect diagnostics defensively: the payload shape is
                    # only inspected for logging here.
                    interrupts = event_data["__interrupt__"]
                    interrupt_id = "unknown"
                    value_len = "unknown"
                    if isinstance(interrupts, (list, tuple)) and len(interrupts) > 0:
                        first = interrupts[0]
                        interrupt_id = getattr(first, "id", "unknown")
                        if hasattr(first, "value") and hasattr(first.value, "__len__"):
                            value_len = len(first.value)
                    logger.debug(
                        f"[{safe_thread_id}] Processing interrupt event: "
                        f"id={interrupt_id}, "
                        f"value_len={value_len}"
                    )
                    yield _create_interrupt_event(thread_id, event_data)
                else:
                    logger.debug(f"[{safe_thread_id}] Dict event without interrupt, skipping")
                # Dict events never carry a message pair.
                continue

            message_chunk, message_metadata = cast(
                tuple[BaseMessage, dict[str, Any]], event_data
            )

            safe_node = sanitize_agent_name(message_metadata.get('langgraph_node', 'unknown'))
            safe_step = sanitize_log_input(message_metadata.get('langgraph_step', 'unknown'))
            logger.debug(
                f"[{safe_thread_id}] Processing message chunk: "
                f"type={type(message_chunk).__name__}, "
                f"node={safe_node}, "
                f"step={safe_step}"
            )

            async for event in _process_message_chunk(
                message_chunk, message_metadata, thread_id, agent
            ):
                yield event

        logger.debug(f"[{safe_thread_id}] Graph event stream completed. Total events: {event_count}")
    except asyncio.CancelledError:
        # User cancelled/interrupted the stream - this is normal, not an error
        logger.info(f"[{safe_thread_id}] Graph event stream cancelled by user after {event_count} events")
        # Re-raise to signal cancellation properly without yielding an error event
        raise
    except Exception:
        logger.exception(f"[{safe_thread_id}] Error during graph execution")
        yield _make_event(
            "error",
            {
                "thread_id": thread_id,
                "error": "Error during graph execution",
            },
        )
|
|
|
|
|
|
async def _astream_workflow_generator(
    messages: List[dict],
    thread_id: str,
    resources: List[Resource],
    max_plan_iterations: int,
    max_step_num: int,
    max_search_results: int,
    auto_accepted_plan: bool,
    interrupt_feedback: str,
    mcp_settings: dict,
    enable_background_investigation: bool,
    enable_web_search: bool,
    report_style: ReportStyle,
    enable_deep_thinking: bool,
    enable_clarification: bool,
    max_clarification_rounds: int,
    locale: str = "en-US",
    interrupt_before_tools: Optional[List[str]] = None,
):
    """Run the research workflow for one chat request and yield its events.

    Async generator. Records the incoming messages, builds the workflow input
    (or a resume ``Command`` when ``interrupt_feedback`` is given and the plan
    is not auto-accepted), builds the workflow config, optionally attaches a
    PostgreSQL or MongoDB checkpointer to the module-level ``graph``, and then
    relays every event produced by ``_stream_graph_events``.

    Checkpointer selection is driven by the ``LANGGRAPH_CHECKPOINT_SAVER`` and
    ``LANGGRAPH_CHECKPOINT_DB_URL`` environment settings. When the saver is
    enabled but the URL is neither ``postgresql://`` nor ``mongodb://``, a
    warning is logged and the graph is run without attaching a checkpointer,
    so the client still receives a response.

    Args:
        messages: Chat messages as dicts; the last one is the current input.
        thread_id: Conversation thread identifier.
        interrupt_feedback: Feedback for a pending interrupt, if any.
        report_style: Report style; its ``value`` is put into the config.
        The remaining parameters are forwarded unchanged into the workflow
        input or the workflow config.
    """
    safe_thread_id = sanitize_thread_id(thread_id)
    safe_feedback = sanitize_log_input(interrupt_feedback) if interrupt_feedback else ""
    logger.debug(
        f"[{safe_thread_id}] _astream_workflow_generator starting: "
        f"messages_count={len(messages)}, "
        f"auto_accepted_plan={auto_accepted_plan}, "
        f"interrupt_feedback={safe_feedback}, "
        f"interrupt_before_tools={interrupt_before_tools}"
    )

    # Process initial messages
    logger.debug(f"[{safe_thread_id}] Processing {len(messages)} initial messages")
    for message in messages:
        if isinstance(message, dict) and "content" in message:
            safe_content = sanitize_user_content(message.get('content', ''))
            logger.debug(f"[{safe_thread_id}] Sending initial message to client: {safe_content}")
            _process_initial_messages(message, thread_id)

    logger.debug(f"[{safe_thread_id}] Reconstructing clarification history")
    clarification_history = reconstruct_clarification_history(messages)

    logger.debug(f"[{safe_thread_id}] Building clarified topic from history")
    clarified_topic, clarification_history = build_clarified_topic_from_history(
        clarification_history
    )
    latest_message_content = messages[-1]["content"] if messages else ""
    clarified_research_topic = clarified_topic or latest_message_content
    safe_topic = sanitize_user_content(clarified_research_topic)
    logger.debug(f"[{safe_thread_id}] Clarified research topic: {safe_topic}")

    # Prepare workflow input
    logger.debug(f"[{safe_thread_id}] Preparing workflow input")
    workflow_input = {
        "messages": messages,
        "plan_iterations": 0,
        "final_report": "",
        "current_plan": None,
        "observations": [],
        "auto_accepted_plan": auto_accepted_plan,
        "enable_background_investigation": enable_background_investigation,
        "research_topic": latest_message_content,
        "clarification_history": clarification_history,
        "clarified_research_topic": clarified_research_topic,
        "enable_clarification": enable_clarification,
        "max_clarification_rounds": max_clarification_rounds,
        "locale": locale,
    }

    # Feedback on a pending interrupt resumes the graph instead of starting
    # a fresh run.
    if not auto_accepted_plan and interrupt_feedback:
        logger.debug(f"[{safe_thread_id}] Creating resume command with interrupt_feedback: {safe_feedback}")
        resume_msg = f"[{interrupt_feedback}]"
        if messages:
            resume_msg += f" {messages[-1]['content']}"
        workflow_input = Command(resume=resume_msg)

    # Prepare workflow config
    logger.debug(
        f"[{safe_thread_id}] Preparing workflow config: "
        f"max_plan_iterations={max_plan_iterations}, "
        f"max_step_num={max_step_num}, "
        f"report_style={report_style.value}, "
        f"enable_deep_thinking={enable_deep_thinking}"
    )
    workflow_config = {
        "thread_id": thread_id,
        "resources": resources,
        "max_plan_iterations": max_plan_iterations,
        "max_step_num": max_step_num,
        "max_search_results": max_search_results,
        "mcp_settings": mcp_settings,
        "enable_web_search": enable_web_search,
        "report_style": report_style.value,
        "enable_deep_thinking": enable_deep_thinking,
        "interrupt_before_tools": interrupt_before_tools,
        "recursion_limit": get_recursion_limit(),
    }

    checkpoint_saver = get_bool_env("LANGGRAPH_CHECKPOINT_SAVER", False)
    checkpoint_url = get_str_env("LANGGRAPH_CHECKPOINT_DB_URL", "")

    logger.debug(
        f"[{safe_thread_id}] Checkpoint configuration: "
        f"saver_enabled={checkpoint_saver}, "
        f"url_configured={bool(checkpoint_url)}"
    )

    # Handle checkpointer if configured
    connection_kwargs = {
        "autocommit": True,
        "row_factory": "dict_row",
        "prepare_threshold": 0,
    }
    if checkpoint_saver and checkpoint_url.startswith("postgresql://"):
        logger.info(f"[{safe_thread_id}] Starting async postgres checkpointer")
        logger.debug(f"[{safe_thread_id}] Setting up PostgreSQL connection pool")
        async with AsyncConnectionPool(
            checkpoint_url, kwargs=connection_kwargs
        ) as conn:
            logger.debug(f"[{safe_thread_id}] Initializing AsyncPostgresSaver")
            checkpointer = AsyncPostgresSaver(conn)
            await checkpointer.setup()
            logger.debug(f"[{safe_thread_id}] Attaching checkpointer to graph")
            graph.checkpointer = checkpointer
            graph.store = in_memory_store
            logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
            async for event in _stream_graph_events(
                graph, workflow_input, workflow_config, thread_id
            ):
                yield event
            logger.debug(f"[{safe_thread_id}] Graph event streaming completed")
    elif checkpoint_saver and checkpoint_url.startswith("mongodb://"):
        logger.info(f"[{safe_thread_id}] Starting async mongodb checkpointer")
        logger.debug(f"[{safe_thread_id}] Setting up MongoDB connection")
        async with AsyncMongoDBSaver.from_conn_string(
            checkpoint_url
        ) as checkpointer:
            logger.debug(f"[{safe_thread_id}] Attaching MongoDB checkpointer to graph")
            graph.checkpointer = checkpointer
            graph.store = in_memory_store
            logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
            async for event in _stream_graph_events(
                graph, workflow_input, workflow_config, thread_id
            ):
                yield event
            logger.debug(f"[{safe_thread_id}] Graph event streaming completed")
    else:
        if checkpoint_saver and checkpoint_url != "":
            # Previously this case streamed nothing at all. The URL itself is
            # not logged because it may embed credentials.
            logger.warning(
                f"[{safe_thread_id}] Unsupported checkpoint URL scheme "
                f"(expected postgresql:// or mongodb://), using in-memory graph"
            )
        else:
            logger.debug(f"[{safe_thread_id}] No checkpointer configured, using in-memory graph")
        # Use the graph without attaching a database checkpointer
        logger.debug(f"[{safe_thread_id}] Starting to stream graph events")
        async for event in _stream_graph_events(
            graph, workflow_input, workflow_config, thread_id
        ):
            yield event
        logger.debug(f"[{safe_thread_id}] Graph event streaming completed")
|
|
|
|
|
|
def _make_event(event_type: str, data: dict[str, Any]):
    """Serialize an event as a server-sent event string and record it.

    The payload is JSON-encoded, the resulting event text is passed to
    ``chat_stream_message`` together with the payload's ``thread_id`` and
    ``finish_reason``, and the same text is returned.

    Args:
        event_type: Value for the SSE ``event:`` field.
        data: Event payload. An empty-string ``content`` entry is omitted
            from the output. The dict itself is not modified.

    Returns:
        ``"event: <type>\\ndata: <json>\\n\\n"``, or a generic ``error`` event
        if a ``TypeError``/``ValueError`` is raised while serializing or
        recording the payload.
    """
    # Shallow copy so dropping "content" does not mutate the caller's dict.
    payload = dict(data)
    if payload.get("content") == "":
        payload.pop("content")
    # Ensure JSON serialization with proper encoding
    try:
        json_data = json.dumps(payload, ensure_ascii=False)
        event_text = f"event: {event_type}\ndata: {json_data}\n\n"

        finish_reason = payload.get("finish_reason", "")
        chat_stream_message(
            payload.get("thread_id", ""),
            event_text,
            finish_reason,
        )

        return event_text
    except (TypeError, ValueError) as e:
        logger.error(f"Error serializing event data: {e}")
        # Return a safe error event
        error_data = json.dumps({"error": "Serialization failed"}, ensure_ascii=False)
        return f"event: error\ndata: {error_data}\n\n"
|
|
|
|
|
|
@app.post("/api/tts")
async def text_to_speech(request: TTSRequest):
    """Convert text to speech using volcengine TTS API.

    Only the first 1024 characters of ``request.text`` are sent. On success
    the decoded audio is returned as an attachment named
    ``tts_output.<encoding>``.

    Raises:
        HTTPException: 400 when ``VOLCENGINE_TTS_APPID`` or
            ``VOLCENGINE_TTS_ACCESS_TOKEN`` is not set; 500 with a generic
            detail for any failure while calling the TTS API or building
            the response.
    """
    app_id = get_str_env("VOLCENGINE_TTS_APPID", "")
    if not app_id:
        raise HTTPException(status_code=400, detail="VOLCENGINE_TTS_APPID is not set")
    access_token = get_str_env("VOLCENGINE_TTS_ACCESS_TOKEN", "")
    if not access_token:
        raise HTTPException(
            status_code=400, detail="VOLCENGINE_TTS_ACCESS_TOKEN is not set"
        )

    try:
        cluster = get_str_env("VOLCENGINE_TTS_CLUSTER", "volcano_tts")
        voice_type = get_str_env("VOLCENGINE_TTS_VOICE_TYPE", "BV700_V2_streaming")

        tts_client = VolcengineTTS(
            appid=app_id,
            access_token=access_token,
            cluster=cluster,
            voice_type=voice_type,
        )
        # Call the TTS API. The client call is synchronous, so run it in a
        # worker thread to keep the event loop free for other requests.
        result = await asyncio.to_thread(
            tts_client.text_to_speech,
            text=request.text[:1024],
            encoding=request.encoding,
            speed_ratio=request.speed_ratio,
            volume_ratio=request.volume_ratio,
            pitch_ratio=request.pitch_ratio,
            text_type=request.text_type,
            with_frontend=request.with_frontend,
            frontend_type=request.frontend_type,
        )

        if not result["success"]:
            # NOTE: this is caught by the handler below, so the client sees
            # the generic detail; the specific error is only logged.
            raise HTTPException(status_code=500, detail=str(result["error"]))

        # Decode the base64 audio data
        audio_data = base64.b64decode(result["audio_data"])

        # Return the audio file
        return Response(
            content=audio_data,
            media_type=f"audio/{request.encoding}",
            headers={
                "Content-Disposition": (
                    f"attachment; filename=tts_output.{request.encoding}"
                )
            },
        )

    except Exception as e:
        logger.exception(f"Error in TTS endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
|
|
|
|
|
@app.post("/api/podcast/generate")
async def generate_podcast(request: GeneratePodcastRequest):
    """Generate a podcast from report content.

    Runs the podcast graph on ``request.content`` and returns the graph's
    ``output`` as an ``audio/mp3`` response.

    Raises:
        HTTPException: 500 with a generic detail if generation fails.
    """
    try:
        report_content = request.content
        # The report body is deliberately not echoed to stdout or the log.
        logger.debug("Starting podcast generation")
        workflow = build_podcast_graph()
        final_state = workflow.invoke({"input": report_content})
        audio_bytes = final_state["output"]
        return Response(content=audio_bytes, media_type="audio/mp3")
    except Exception as e:
        logger.exception(f"Error occurred during podcast generation: {str(e)}")
        raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
|
|
|
|
|
@app.post("/api/ppt/generate")
async def generate_ppt(request: GeneratePPTRequest):
    """Generate a PowerPoint presentation from report content.

    Runs the PPT graph on ``request.content`` and ``request.locale``, reads
    the file at the graph's ``generated_file_path`` and returns its bytes
    with the PowerPoint (``.pptx``) media type.

    Raises:
        HTTPException: 500 with a generic detail if generation or reading
            the generated file fails.
    """
    try:
        report_content = request.content
        # The report body is deliberately not echoed to stdout or the log.
        logger.debug("Starting ppt generation")
        workflow = build_ppt_graph()
        final_state = workflow.invoke({"input": report_content, "locale": request.locale})
        generated_file_path = final_state["generated_file_path"]
        with open(generated_file_path, "rb") as f:
            ppt_bytes = f.read()
        return Response(
            content=ppt_bytes,
            media_type="application/vnd.openxmlformats-officedocument.presentationml.presentation",
        )
    except Exception as e:
        logger.exception(f"Error occurred during ppt generation: {str(e)}")
        raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
|
|
|
|
|
@app.post("/api/prose/generate")
async def generate_prose(request: GenerateProseRequest):
    """Stream prose generated for a prompt as ``text/event-stream``.

    Each message produced by the prose graph is sent as one ``data:`` line.

    Raises:
        HTTPException: 500 with a generic detail if setting up the stream
            fails.
    """
    try:
        # Strip line breaks so the prompt stays on a single log line.
        sanitized_prompt = request.prompt.replace("\r\n", "").replace("\n", "")
        logger.info(f"Generating prose for prompt: {sanitized_prompt}")

        graph_input = {
            "content": request.prompt,
            "option": request.option,
            "command": request.command,
        }
        events = build_prose_graph().astream(
            graph_input,
            stream_mode="messages",
            subgraphs=True,
        )

        async def sse_lines():
            # Each stream item is (namespace, (message, metadata)).
            async for _, event in events:
                yield f"data: {event[0].content}\n\n"

        return StreamingResponse(sse_lines(), media_type="text/event-stream")
    except Exception as e:
        logger.exception(f"Error occurred during prose generation: {str(e)}")
        raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
|
|
|
|
|
@app.post("/api/prompt/enhance")
async def enhance_prompt(request: EnhancePromptRequest):
    """Enhance a prompt with the prompt enhancer graph.

    ``request.report_style`` is matched case-insensitively against the known
    style names; a missing or unrecognised value falls back to
    ``ReportStyle.ACADEMIC``.

    Returns:
        ``{"result": <enhanced prompt>}`` taken from the graph's ``output``.

    Raises:
        HTTPException: 500 with a generic detail if enhancement fails.
    """
    try:
        # Strip line breaks so the prompt stays on a single log line.
        sanitized_prompt = request.prompt.replace("\r\n", "").replace("\n", "")
        logger.info(f"Enhancing prompt: {sanitized_prompt}")

        # Convert string report_style to ReportStyle enum
        report_style = ReportStyle.ACADEMIC
        if request.report_style:
            try:
                known_styles = {
                    "ACADEMIC": ReportStyle.ACADEMIC,
                    "POPULAR_SCIENCE": ReportStyle.POPULAR_SCIENCE,
                    "NEWS": ReportStyle.NEWS,
                    "SOCIAL_MEDIA": ReportStyle.SOCIAL_MEDIA,
                    "STRATEGIC_INVESTMENT": ReportStyle.STRATEGIC_INVESTMENT,
                }
                # Upper-casing accepts both uppercase and lowercase input.
                requested = request.report_style.upper()
                report_style = known_styles.get(requested, ReportStyle.ACADEMIC)
            except Exception:
                # If invalid style, default to ACADEMIC
                report_style = ReportStyle.ACADEMIC

        graph_input = {
            "prompt": request.prompt,
            "context": request.context,
            "report_style": report_style,
        }
        final_state = build_prompt_enhancer_graph().invoke(graph_input)
        return {"result": final_state["output"]}
    except Exception as e:
        logger.exception(f"Error occurred during prompt enhancement: {str(e)}")
        raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
|
|
|
|
|
@app.post("/api/mcp/server/metadata", response_model=MCPServerMetadataResponse)
async def mcp_server_metadata(request: MCPServerMetadataRequest):
    """Get information about an MCP server.

    Loads the server's tools with ``load_mcp_tools`` and echoes the request's
    connection settings back together with the tools.

    Raises:
        HTTPException: 403 when ``ENABLE_MCP_SERVER_CONFIGURATION`` is not
            enabled; 500 with a generic detail if loading the tools fails.
    """
    if not get_bool_env("ENABLE_MCP_SERVER_CONFIGURATION", False):
        raise HTTPException(
            status_code=403,
            detail="MCP server configuration is disabled. Set ENABLE_MCP_SERVER_CONFIGURATION=true to enable MCP features.",
        )

    try:
        # Default to 300 seconds unless the request supplies its own timeout.
        if request.timeout_seconds is None:
            timeout = 300
        else:
            timeout = request.timeout_seconds

        # Settings passed to the loader and echoed back in the response.
        connection = {
            "command": request.command,
            "args": request.args,
            "url": request.url,
            "env": request.env,
            "headers": request.headers,
        }

        tools = await load_mcp_tools(
            server_type=request.transport,
            timeout_seconds=timeout,
            **connection,
        )

        return MCPServerMetadataResponse(
            transport=request.transport,
            tools=tools,
            **connection,
        )
    except Exception as e:
        logger.exception(f"Error in MCP server metadata endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
|
|
|
|
|
@app.get("/api/rag/config", response_model=RAGConfigResponse)
async def rag_config():
    """Report the RAG provider selected in the server configuration."""
    active_provider = SELECTED_RAG_PROVIDER
    return RAGConfigResponse(provider=active_provider)
|
|
|
|
|
|
@app.get("/api/rag/resources", response_model=RAGResourcesResponse)
async def rag_resources(request: Annotated[RAGResourceRequest, Query()]):
    """List the RAG resources matching the request's query.

    Returns an empty resource list when no retriever is available.
    """
    retriever = build_retriever()
    if not retriever:
        # No retriever could be built: answer with an empty listing.
        return RAGResourcesResponse(resources=[])

    matches = retriever.list_resources(request.query)
    return RAGResourcesResponse(resources=matches)
|
|
|
|
|
|
# Limits applied to files uploaded as RAG resources.
MAX_UPLOAD_SIZE_BYTES = 10 * 1024**2  # 10 MB ceiling per uploaded file
ALLOWED_EXTENSIONS = {".md", ".txt"}  # accepted file name suffixes
|
|
|
|
|
|
def _sanitize_filename(filename: str) -> str:
    """Reduce a client-supplied file name to a safe bare file name.

    Path components are discarded for both ``/`` and ``\\`` separators, so
    Windows-style paths are stripped on POSIX hosts as well. NUL and the
    other ASCII control characters are removed, and surrounding whitespace
    is trimmed.

    Args:
        filename: The file name as sent by the client; may contain path
            components or control characters.

    Returns:
        The sanitized base name, or ``"unnamed_file"`` when nothing usable
        remains (empty result, ``"."`` or ``".."``).
    """
    # Drop NUL and every other ASCII control character (including CR/LF and
    # DEL) so the name cannot break log lines or confuse downstream storage.
    printable = "".join(ch for ch in filename if ch >= " " and ch != "\x7f")

    # os.path.basename only splits on the host's separator; normalise
    # backslashes first so "..\\..\\x" is handled on POSIX too.
    sanitized = os.path.basename(printable.replace("\\", "/")).strip()

    # Never hand back an empty name or a directory reference.
    if not sanitized or sanitized in (".", ".."):
        return "unnamed_file"
    return sanitized
|
|
|
|
|
|
@app.post("/api/rag/upload", response_model=Resource)
async def upload_rag_resource(file: UploadFile):
    """Validate an uploaded file and ingest it through the RAG retriever.

    The file name is sanitized, its extension is checked against
    ``ALLOWED_EXTENSIONS``, and the content must be non-empty and no larger
    than ``MAX_UPLOAD_SIZE_BYTES`` before it is handed to the retriever.

    Args:
        file: The uploaded file.

    Returns:
        Whatever ``retriever.ingest_file`` returns for the ingested content.

    Raises:
        HTTPException: 400 for a missing file name, a disallowed extension,
            an empty file, or a ``ValueError`` from ingestion; 413 when the
            file exceeds the size limit; 500 when no retriever is configured
            or ingestion raises ``RuntimeError``; 501 when the retriever
            does not implement uploads.
    """
    # Validate filename exists
    if not file.filename:
        raise HTTPException(status_code=400, detail="Filename is required for upload")

    # Sanitize filename to prevent path traversal
    safe_filename = _sanitize_filename(file.filename)

    # Validate file extension (case-insensitive)
    _, ext = os.path.splitext(safe_filename.lower())
    if ext not in ALLOWED_EXTENSIONS:
        # Sort so the message is stable; set iteration order is arbitrary.
        allowed = ", ".join(sorted(ALLOWED_EXTENSIONS))
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file type. Only {allowed} files are allowed.",
        )

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without loading an arbitrarily large file into memory.
    content = await file.read(MAX_UPLOAD_SIZE_BYTES + 1)
    if not content:
        raise HTTPException(status_code=400, detail="Cannot upload an empty file")
    if len(content) > MAX_UPLOAD_SIZE_BYTES:
        raise HTTPException(
            status_code=413,
            detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE_BYTES // (1024 * 1024)} MB.",
        )

    retriever = build_retriever()
    if not retriever:
        raise HTTPException(status_code=500, detail="RAG provider not configured")

    try:
        return retriever.ingest_file(content, safe_filename)
    except NotImplementedError as exc:
        raise HTTPException(
            status_code=501, detail="Upload not supported by current RAG provider"
        ) from exc
    except ValueError as exc:
        # Invalid user input or unsupported file content; treat as a client error
        logger.warning("Invalid RAG resource upload: %s", exc)
        raise HTTPException(
            status_code=400,
            detail="Invalid RAG resource. Please check the file and try again.",
        ) from exc
    except RuntimeError as exc:
        # Internal error during ingestion; log and return a generic server error
        logger.exception("Runtime error while ingesting RAG resource: %s", exc)
        raise HTTPException(
            status_code=500,
            detail="Failed to ingest RAG resource due to an internal error.",
        ) from exc
|
|
|
|
|
|
@app.get("/api/config", response_model=ConfigResponse)
async def config():
    """Return the server configuration: the RAG provider and the configured LLM models."""
    rag_section = RAGConfigResponse(provider=SELECTED_RAG_PROVIDER)
    llm_models = get_configured_llm_models()
    return ConfigResponse(rag=rag_section, models=llm_models)
|