"""SSE streaming API for the Lite Deep Research workflow graph."""

import json
import logging
from typing import Any, List, cast
from uuid import uuid4

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from langchain_core.messages import AIMessageChunk, ToolMessage
from langgraph.types import Command

from src.graph.builder import build_graph
from src.server.chat_request import ChatMessage, ChatRequest

logger = logging.getLogger(__name__)

app = FastAPI(
    title="Lite Deep Research API",
    description="API for Lite Deep Research",
    version="0.1.0",
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# Compiled once at import time; shared across requests.
graph = build_graph()


@app.post("/api/chat/stream")
async def chat_stream(request: ChatRequest):
    """Run the research workflow and stream its events back as SSE.

    A request with the sentinel thread id ``"__default__"`` is given a
    fresh UUID so independent default sessions do not share graph state.
    """
    thread_id = request.thread_id
    if thread_id == "__default__":
        thread_id = str(uuid4())
    return StreamingResponse(
        _astream_workflow_generator(
            # model_dump() converts ChatMessage objects into plain dicts.
            request.model_dump()["messages"],
            thread_id,
            request.max_plan_iterations,
            request.max_step_num,
            request.auto_accepted_plan,
            request.interrupt_feedback,
        ),
        media_type="text/event-stream",
    )


async def _astream_workflow_generator(
    messages: List[dict],
    thread_id: str,
    max_plan_iterations: int,
    max_step_num: int,
    auto_accepted_plan: bool,
    interrupt_feedback: str,
):
    """Yield SSE-formatted events produced by streaming the graph.

    Args:
        messages: Chat history as plain dicts (the ``model_dump()`` form of
            ``ChatMessage``), each with at least a ``"content"`` key.
        thread_id: Conversation/thread identifier passed through to the graph.
        max_plan_iterations: Upper bound on planning iterations.
        max_step_num: Upper bound on research steps.
        auto_accepted_plan: When False, the graph interrupts for plan review.
        interrupt_feedback: User's verdict on an interrupted plan; when set
            (and the plan is not auto-accepted) the graph is *resumed* with a
            ``Command`` instead of being started fresh.

    Yields:
        ``str`` chunks in ``text/event-stream`` format (see ``_make_event``).
    """
    input_: Any = {"messages": messages, "auto_accepted_plan": auto_accepted_plan}
    if not auto_accepted_plan and interrupt_feedback:
        resume_msg = f"[{interrupt_feedback}]"
        # add the last message to the resume message
        if messages:
            # NOTE: quote style matters here — nested double quotes inside a
            # double-quoted f-string are a SyntaxError before Python 3.12.
            resume_msg += f" {messages[-1]['content']}"
        input_ = Command(resume=resume_msg)
    async for agent, _, event_data in graph.astream(
        input_,
        config={
            # NOTE(review): LangGraph checkpointers conventionally read
            # thread_id from config["configurable"] — confirm build_graph()
            # consumes these top-level keys as intended.
            "thread_id": thread_id,
            "max_plan_iterations": max_plan_iterations,
            "max_step_num": max_step_num,
        },
        stream_mode=["messages", "updates"],
        subgraphs=True,
    ):
        if isinstance(event_data, dict):
            if "__interrupt__" in event_data:
                # Graph paused for human plan review: surface the interrupt
                # payload plus the choices the UI should offer.
                yield _make_event(
                    "interrupt",
                    {
                        "thread_id": thread_id,
                        "id": event_data["__interrupt__"][0].ns[0],
                        "role": "assistant",
                        "content": event_data["__interrupt__"][0].value,
                        "finish_reason": "interrupt",
                        "options": [
                            {"text": "Edit plan", "value": "edit_plan"},
                            {"text": "Start research", "value": "accepted"},
                        ],
                    },
                )
            continue
        message_chunk, message_metadata = cast(
            tuple[AIMessageChunk, dict[str, Any]], event_data
        )
        event_stream_message: dict[str, Any] = {
            "thread_id": thread_id,
            # agent is a namespace tuple; the first segment before ":" names
            # the producing node.
            "agent": agent[0].split(":")[0],
            "id": message_chunk.id,
            "role": "assistant",
            "content": message_chunk.content,
        }
        if message_chunk.response_metadata.get("finish_reason"):
            event_stream_message["finish_reason"] = message_chunk.response_metadata.get(
                "finish_reason"
            )
        if isinstance(message_chunk, ToolMessage):
            # Tool Message - Return the result of the tool call
            event_stream_message["tool_call_id"] = message_chunk.tool_call_id
            yield _make_event("tool_call_result", event_stream_message)
        else:
            # AI Message - Raw message tokens
            if message_chunk.tool_calls:
                # AI Message - Tool Call
                event_stream_message["tool_calls"] = message_chunk.tool_calls
                event_stream_message["tool_call_chunks"] = (
                    message_chunk.tool_call_chunks
                )
                yield _make_event("tool_calls", event_stream_message)
            elif message_chunk.tool_call_chunks:
                # AI Message - Tool Call Chunks
                event_stream_message["tool_call_chunks"] = (
                    message_chunk.tool_call_chunks
                )
                yield _make_event("tool_call_chunks", event_stream_message)
            else:
                # AI Message - Raw message tokens
                yield _make_event("message_chunk", event_stream_message)


def _make_event(event_type: str, data: dict[str, Any]) -> str:
    """Serialize one server-sent event.

    Empty ``content`` is dropped to avoid emitting useless delta frames.
    ``ensure_ascii=False`` keeps non-ASCII text readable in the stream.
    """
    if data.get("content") == "":
        data.pop("content")
    return f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"