Mirror of https://gitee.com/wanwujie/deer-flow, synced 2026-04-21 05:14:45 +08:00
feat: implement basic server logic
src/server/app.py | 107 (new file)
@@ -0,0 +1,107 @@
import json
import logging
from typing import Any, List, cast
from uuid import uuid4

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from langchain_core.messages import AIMessageChunk, ToolMessage

from src.graph.builder import build_graph
from src.server.chat_request import ChatMessage, ChatRequest

logger = logging.getLogger(__name__)

app = FastAPI(
    title="Lite Deep Research API",
    description="API for Lite Deep Research",
    version="0.1.0",
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# The workflow graph is built once at import time and shared across requests.
graph = build_graph()


@app.post("/api/chat/stream")
async def chat_stream(request: ChatRequest):
    thread_id = request.thread_id
    # "__default__" is a sentinel asking the server to start a fresh thread.
    if thread_id == "__default__":
        thread_id = str(uuid4())
    return StreamingResponse(
        _astream_workflow_generator(
            request.model_dump()["messages"],
            thread_id,
            request.max_plan_iterations,
            request.max_step_num,
        ),
        media_type="text/event-stream",
    )


async def _astream_workflow_generator(
    messages: List[ChatMessage],
    thread_id: str,
    max_plan_iterations: int,
    max_step_num: int,
):
    # With subgraphs=True and a list of stream modes, astream yields
    # (namespace, mode, payload) triples; the namespace identifies the agent.
    async for agent, _, event_data in graph.astream(
        {"messages": messages},
        config={
            "thread_id": thread_id,
            "max_plan_iterations": max_plan_iterations,
            "max_step_num": max_step_num,
        },
        stream_mode=["messages"],
        subgraphs=True,
    ):
        message_chunk, message_metadata = cast(
            tuple[AIMessageChunk, dict[str, Any]], event_data
        )
        event_stream_message: dict[str, Any] = {
            "thread_id": thread_id,
            # Namespace entries look like "node:task_id"; keep the node name.
            "agent": agent[0].split(":")[0],
            "id": message_chunk.id,
            "role": "assistant",
            "content": message_chunk.content,
        }
        if message_chunk.response_metadata.get("finish_reason"):
            event_stream_message["finish_reason"] = message_chunk.response_metadata.get(
                "finish_reason"
            )
        if isinstance(message_chunk, ToolMessage):
            # Tool Message - Return the result of the tool call
            event_stream_message["tool_call_id"] = message_chunk.tool_call_id
            yield _make_event("tool_call_result", event_stream_message)
        else:
            if message_chunk.tool_calls:
                # AI Message - Tool Call
                event_stream_message["tool_calls"] = message_chunk.tool_calls
                event_stream_message["tool_call_chunks"] = (
                    message_chunk.tool_call_chunks
                )
                yield _make_event("tool_calls", event_stream_message)
            elif message_chunk.tool_call_chunks:
                # AI Message - Tool Call Chunks
                event_stream_message["tool_call_chunks"] = (
                    message_chunk.tool_call_chunks
                )
                yield _make_event("tool_call_chunks", event_stream_message)
            else:
                # AI Message - Raw message tokens
                yield _make_event("message_chunk", event_stream_message)


def _make_event(event_type: str, data: dict[str, Any]):
    # Omit empty content so token-less chunks stay compact on the wire.
    if data.get("content") == "":
        data.pop("content")
    # Server-Sent Events framing: an event name plus one JSON data payload.
    return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
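
To exercise the endpoint end to end, here is a small client sketch. It assumes the app is served locally (e.g. with uvicorn src.server.app:app) and uses httpx for streaming; any HTTP client that can iterate a response body works the same way. Each SSE frame is an "event:" line naming one of message_chunk, tool_calls, tool_call_chunks, or tool_call_result, followed by a "data:" line carrying the JSON payload built above.

# Minimal client sketch; the URL, port, and prompt are assumptions.
import httpx

payload = {
    "messages": [{"role": "user", "content": "Hello"}],
    "thread_id": "__default__",  # the server swaps this sentinel for a uuid4
    "max_plan_iterations": 1,
    "max_step_num": 3,
}

with httpx.stream(
    "POST",
    "http://localhost:8000/api/chat/stream",
    json=payload,
    timeout=None,  # the stream stays open for the whole run
) as response:
    for line in response.iter_lines():
        if line:  # skip blank separators between SSE frames
            print(line)  # alternating "event: ..." and "data: {...}" lines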