Files
deer-flow/src/server/app.py
2025-04-14 19:53:00 +08:00

116 lines
3.8 KiB
Python

import json
import logging
from typing import Any, List, cast
from uuid import uuid4

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from langchain_core.messages import AIMessageChunk, ToolMessage
from langgraph.types import Command

from src.graph.builder import build_graph
from src.server.chat_request import ChatMessage, ChatRequest
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# FastAPI application object; the streaming chat route below is registered on it.
app = FastAPI(
title="Lite Deep Research API",
description="API for Lite Deep Research",
version="0.1.0",
)
# Add CORS middleware: every origin, HTTP method and request header is accepted.
# NOTE(review): FastAPI's CORS documentation advises against combining wildcard
# origins/methods/headers with allow_credentials=True -- confirm this permissive
# configuration is intended outside local development.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allows all origins
allow_credentials=True,
allow_methods=["*"], # Allows all methods
allow_headers=["*"], # Allows all headers
)
# Build the workflow graph once, at import time; the streaming generator below
# calls graph.astream() on this single module-level instance.
graph = build_graph()
@app.post("/api/chat/stream")
async def chat_stream(request: ChatRequest):
    """Handle ``POST /api/chat/stream`` by streaming workflow events.

    A ``thread_id`` equal to ``"__default__"`` is replaced with a freshly
    generated UUID string; any other value is forwarded as given.

    Args:
        request: The chat request body. Its messages, plan limits,
            auto-accept flag and feedback are handed to the workflow
            generator.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream`` whose
        body is produced by ``_astream_workflow_generator``.
    """
    conversation_id = (
        str(uuid4()) if request.thread_id == "__default__" else request.thread_id
    )
    # Calling the async generator function only creates the generator; it is
    # consumed by StreamingResponse while the response is being sent.
    event_source = _astream_workflow_generator(
        request.model_dump()["messages"],
        conversation_id,
        request.max_plan_iterations,
        request.max_step_num,
        request.auto_accepted_plan,
        request.feedback,
    )
    return StreamingResponse(event_source, media_type="text/event-stream")
async def _astream_workflow_generator(
    messages: List[ChatMessage],
    thread_id: str,
    max_plan_iterations: int,
    max_step_num: int,
    auto_accepted_plan: bool,
    feedback: str,
):
    """Run the workflow graph and yield its streamed messages as SSE strings.

    The graph is started from ``messages`` unless the plan is not
    auto-accepted and ``feedback`` is non-empty; in that case it is resumed
    with ``Command(resume=feedback)`` instead.

    Every streamed item becomes exactly one event, formatted by
    ``_make_event``. The event type depends on the chunk:

    * ``tool_call_result``: the chunk is a ``ToolMessage``.
    * ``tool_calls``: the chunk carries ``tool_calls``.
    * ``tool_call_chunks``: the chunk carries only ``tool_call_chunks``.
    * ``message_chunk``: anything else (plain message content).

    Args:
        messages: Conversation messages placed under the ``"messages"`` key
            of the graph input. NOTE(review): the visible caller passes
            ``request.model_dump()["messages"]``, so these are presumably
            plain dicts rather than ``ChatMessage`` instances -- confirm.
        thread_id: Forwarded in the graph config and echoed in every event.
        max_plan_iterations: Forwarded unchanged in the graph config.
        max_step_num: Forwarded unchanged in the graph config.
        auto_accepted_plan: Included in the initial graph input; when falsy
            and ``feedback`` is set, the graph is resumed instead.
        feedback: Value used to resume the graph; ignored when empty.

    Yields:
        str: One server-sent-event formatted string per streamed chunk.
    """
    if not auto_accepted_plan and feedback:
        input_ = Command(resume=feedback)
    else:
        input_ = {"messages": messages, "auto_accepted_plan": auto_accepted_plan}

    async for agent, _, event_data in graph.astream(
        input_,
        config={
            "thread_id": thread_id,
            "max_plan_iterations": max_plan_iterations,
            "max_step_num": max_step_num,
        },
        stream_mode=["messages"],
        subgraphs=True,
    ):
        # Each item's payload is a (message chunk, metadata) pair; only the
        # chunk is used here.
        message_chunk, _message_metadata = cast(
            tuple[AIMessageChunk, dict[str, Any]], event_data
        )
        event_stream_message: dict[str, Any] = {
            "thread_id": thread_id,
            # Keep only the text before the first ":" of the first namespace
            # entry (presumably "<node name>:<task id>" -- TODO confirm).
            "agent": agent[0].split(":")[0],
            "id": message_chunk.id,
            # The role is reported as "assistant" for every event type.
            "role": "assistant",
            "content": message_chunk.content,
        }
        finish_reason = message_chunk.response_metadata.get("finish_reason")
        if finish_reason:
            event_stream_message["finish_reason"] = finish_reason

        if isinstance(message_chunk, ToolMessage):
            # Tool Message - Return the result of the tool call
            event_stream_message["tool_call_id"] = message_chunk.tool_call_id
            yield _make_event("tool_call_result", event_stream_message)
        elif message_chunk.tool_calls:
            # AI Message - Tool Call
            event_stream_message["tool_calls"] = message_chunk.tool_calls
            event_stream_message["tool_call_chunks"] = message_chunk.tool_call_chunks
            yield _make_event("tool_calls", event_stream_message)
        elif message_chunk.tool_call_chunks:
            # AI Message - Tool Call Chunks
            event_stream_message["tool_call_chunks"] = message_chunk.tool_call_chunks
            yield _make_event("tool_call_chunks", event_stream_message)
        else:
            # AI Message - Raw message tokens
            yield _make_event("message_chunk", event_stream_message)
def _make_event(event_type: str, data: dict[str, any]):
if data.get("content") == "":
data.pop("content")
return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"