"""FastAPI application that streams graph output as server-sent events."""

import json
import logging
from typing import Any, AsyncIterator, List, Union, cast
from uuid import uuid4

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from langchain_core.messages import AIMessageChunk, ToolMessage
from langgraph.types import Command

from src.graph.builder import build_graph
from src.server.chat_request import ChatMessage, ChatRequest

logger = logging.getLogger(__name__)

app = FastAPI(
    title="Lite Deep Research API",
    description="API for Lite Deep Research",
    version="0.1.0",
)

# Add CORS middleware
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive -- confirm this is intended outside of local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

graph = build_graph()


@app.post("/api/chat/stream")
async def chat_stream(request: ChatRequest):
    """Stream the workflow output for a chat request as server-sent events.

    A request whose ``thread_id`` is the placeholder ``"__default__"`` is
    assigned a freshly generated UUID4 thread id.

    Args:
        request: The chat request; its messages, plan limits, auto-accept
            flag and feedback are forwarded to the workflow generator.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """
    thread_id = request.thread_id
    if thread_id == "__default__":
        thread_id = str(uuid4())
    return StreamingResponse(
        _astream_workflow_generator(
            request.model_dump()["messages"],
            thread_id,
            request.max_plan_iterations,
            request.max_step_num,
            request.auto_accepted_plan,
            request.feedback,
        ),
        media_type="text/event-stream",
    )


async def _astream_workflow_generator(
    messages: List[ChatMessage],
    thread_id: str,
    max_plan_iterations: int,
    max_step_num: int,
    auto_accepted_plan: bool,
    feedback: str,
) -> AsyncIterator[str]:
    """Run the graph and yield one server-sent event string per streamed chunk.

    Args:
        messages: Conversation messages passed to the graph as its
            ``"messages"`` input. ``chat_stream`` passes the value taken
            from ``request.model_dump()``.
        thread_id: Thread id placed in the graph config and echoed in
            every event.
        max_plan_iterations: Forwarded to the graph config unchanged.
        max_step_num: Forwarded to the graph config unchanged.
        auto_accepted_plan: Forwarded in the graph input. When false and
            ``feedback`` is truthy, the graph is resumed with the feedback
            instead of being given new input.
        feedback: Value used to resume the graph (see above).

    Yields:
        Strings produced by ``_make_event`` with event type
        ``tool_call_result``, ``tool_calls``, ``tool_call_chunks`` or
        ``message_chunk``.
    """
    if not auto_accepted_plan and feedback:
        # Resume with the supplied feedback rather than starting from new input.
        input_ = Command(resume=feedback)
    else:
        input_ = {"messages": messages, "auto_accepted_plan": auto_accepted_plan}

    async for agent, _, event_data in graph.astream(
        input_,
        config={
            "thread_id": thread_id,
            "max_plan_iterations": max_plan_iterations,
            "max_step_num": max_step_num,
        },
        stream_mode=["messages"],
        subgraphs=True,
    ):
        # The streamed message may be an AI chunk or a tool result, so the
        # cast names both instead of claiming everything is an AIMessageChunk.
        message_chunk, message_metadata = cast(
            tuple[Union[AIMessageChunk, ToolMessage], dict[str, Any]], event_data
        )
        event_stream_message: dict[str, Any] = {
            "thread_id": thread_id,
            # Keep only the text before the first ":" of the first namespace
            # entry. NOTE(review): assumes `agent` is a non-empty sequence of
            # strings -- confirm against the graph's streaming output.
            "agent": agent[0].split(":")[0],
            "id": message_chunk.id,
            "role": "assistant",
            "content": message_chunk.content,
        }

        # Look the value up once; only truthy finish reasons are forwarded.
        finish_reason = message_chunk.response_metadata.get("finish_reason")
        if finish_reason:
            event_stream_message["finish_reason"] = finish_reason

        if isinstance(message_chunk, ToolMessage):
            # Tool Message - Return the result of the tool call
            event_stream_message["tool_call_id"] = message_chunk.tool_call_id
            yield _make_event("tool_call_result", event_stream_message)
        elif message_chunk.tool_calls:
            # AI Message - Tool Call
            event_stream_message["tool_calls"] = message_chunk.tool_calls
            event_stream_message["tool_call_chunks"] = message_chunk.tool_call_chunks
            yield _make_event("tool_calls", event_stream_message)
        elif message_chunk.tool_call_chunks:
            # AI Message - Tool Call Chunks
            event_stream_message["tool_call_chunks"] = message_chunk.tool_call_chunks
            yield _make_event("tool_call_chunks", event_stream_message)
        else:
            # AI Message - Raw message tokens
            yield _make_event("message_chunk", event_stream_message)


def _make_event(event_type: str, data: dict[str, Any]) -> str:
    """Format ``data`` as a single server-sent event.

    A ``"content"`` entry equal to the empty string is left out of the
    serialized payload. The caller's dict is not modified.

    Args:
        event_type: Value written on the ``event:`` line.
        data: JSON-serializable payload written on the ``data:`` line.

    Returns:
        The event text, terminated by a blank line.
    """
    if data.get("content") == "":
        # Drop the empty content from a copy so the argument stays untouched.
        data = {key: value for key, value in data.items() if key != "content"}
    return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"