import type { HumanMessage } from "@langchain/core/messages";
import type { ThreadsClient } from "@langchain/langgraph-sdk/client";
import { useStream, type UseStream } from "@langchain/langgraph-sdk/react";
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
import { useCallback } from "react";

import type { PromptInputMessage } from "@/components/ai-elements/prompt-input";

import { getAPIClient } from "../api";

import type {
  AgentThread,
  AgentThreadContext,
  AgentThreadState,
} from "./types";

/**
 * Query-key prefix shared by every thread-search query. `useThreads` appends
 * its search params to it; the cache updates below match on this prefix
 * (`exact: false`) so every cached search variant is kept in sync.
 */
const THREAD_SEARCH_QUERY_KEY = ["threads", "search"] as const;

/**
 * Opens a LangGraph stream against the `lead_agent` assistant for a thread.
 *
 * For a new thread the id is withheld from `useStream` (presumably because the
 * thread has not been created server-side yet; `useSubmitThread` passes it at
 * submit time instead). The stream reconnects on mount and fetches state
 * history. When a run finishes, the `title` from the final state is copied into
 * the matching thread in every cached thread-search list.
 *
 * @param threadId - Id of the thread to stream; may be null/undefined.
 * @param isNewThread - When true, `threadId` is not passed to `useStream`.
 * @returns The `useStream` handle for the thread.
 */
export function useThreadStream({
  threadId,
  isNewThread,
}: {
  isNewThread: boolean;
  threadId: string | null | undefined;
}) {
  const queryClient = useQueryClient();
  const thread = useStream<AgentThreadState>({
    client: getAPIClient(),
    assistantId: "lead_agent",
    threadId: isNewThread ? undefined : threadId,
    reconnectOnMount: true,
    fetchStateHistory: true,
    onFinish(state) {
      // Without an id no cached thread can match, so there is nothing to patch.
      if (!threadId) {
        return;
      }
      queryClient.setQueriesData<AgentThread[]>(
        { queryKey: THREAD_SEARCH_QUERY_KEY, exact: false },
        // A matched query may have no data yet (pending or errored). Returning
        // undefined leaves that entry untouched instead of throwing on `.map`.
        (oldData) =>
          oldData?.map((t) =>
            t.thread_id === threadId
              ? { ...t, values: { ...t.values, title: state.values.title } }
              : t,
          ),
      );
    },
  });
  return thread;
}

/**
 * Returns a memoised callback that submits a prompt-input message as a single
 * human text message on the given thread stream.
 *
 * The message text is trimmed before sending. For a new thread, `threadId` is
 * passed in the submit options (presumably so the run creates the thread under
 * that id; confirm against the SDK). For an existing thread it is omitted, on
 * the assumption that `thread` is already bound to it. The run streams
 * subgraphs, is resumable, and uses a recursion limit of 100. `threadContext`
 * is sent as the run context with `thread_id` added. Once `submit` resolves, all
 * thread-search queries are invalidated and `afterSubmit` (if given) is called.
 * If `submit` rejects, neither happens and the rejection propagates.
 *
 * @param isNewThread - Whether `threadId` still has to be created by this submit.
 * @param threadId - Id of the target thread.
 * @param thread - Stream handle, e.g. from `useThreadStream`.
 * @param threadContext - Run context, excluding `thread_id` (added here).
 * @param afterSubmit - Optional hook invoked after a successful submit.
 */
export function useSubmitThread({
  threadId,
  thread,
  threadContext,
  isNewThread,
  afterSubmit,
}: {
  isNewThread: boolean;
  threadId: string | null | undefined;
  thread: UseStream<AgentThreadState>;
  threadContext: Omit<AgentThreadContext, "thread_id">;
  afterSubmit?: () => void;
}) {
  const queryClient = useQueryClient();
  const callback = useCallback(
    async (message: PromptInputMessage) => {
      const text = message.text.trim();
      await thread.submit(
        {
          messages: [
            {
              type: "human",
              content: [{ type: "text", text }],
            },
          ] as HumanMessage[],
        },
        {
          // Normalise a missing id to undefined ("not provided") rather than
          // forwarding a runtime null to the SDK.
          threadId: isNewThread ? (threadId ?? undefined) : undefined,
          streamSubgraphs: true,
          streamResumable: true,
          config: {
            recursion_limit: 100,
          },
          context: {
            ...threadContext,
            thread_id: threadId,
          },
        },
      );
      void queryClient.invalidateQueries({ queryKey: THREAD_SEARCH_QUERY_KEY });
      afterSubmit?.();
    },
    [thread, isNewThread, threadId, threadContext, queryClient, afterSubmit],
  );
  return callback;
}

/**
 * Fetches threads via `threads.search` and caches the result under
 * `["threads", "search", params]`.
 *
 * @param params - Search parameters; by default the 50 threads with the most
 *   recent `updated_at`, newest first.
 * @returns The React Query result holding the thread list.
 */
export function useThreads(
  params: Parameters<ThreadsClient["search"]>[0] = {
    limit: 50,
    sortBy: "updated_at",
    sortOrder: "desc",
  },
) {
  const apiClient = getAPIClient();
  return useQuery({
    queryKey: [...THREAD_SEARCH_QUERY_KEY, params],
    queryFn: async () => {
      const response = await apiClient.threads.search(params);
      return response as AgentThread[];
    },
  });
}

/**
 * Mutation that deletes a thread on the server and, on success, removes it
 * from every cached thread-search list.
 *
 * @returns A React Query mutation taking `{ threadId }`.
 */
export function useDeleteThread() {
  const queryClient = useQueryClient();
  const apiClient = getAPIClient();
  return useMutation({
    mutationFn: async ({ threadId }: { threadId: string }) => {
      await apiClient.threads.delete(threadId);
    },
    onSuccess(_, { threadId }) {
      queryClient.setQueriesData<AgentThread[]>(
        { queryKey: THREAD_SEARCH_QUERY_KEY, exact: false },
        // Skip (return undefined for) matched queries that have no data yet.
        (oldData) => oldData?.filter((t) => t.thread_id !== threadId),
      );
    },
  });
}