diff --git a/web/src/app/chat/components/message-list-view.tsx b/web/src/app/chat/components/message-list-view.tsx index d745006..c4f2bf8 100644 --- a/web/src/app/chat/components/message-list-view.tsx +++ b/web/src/app/chat/components/message-list-view.tsx @@ -27,8 +27,11 @@ import type { Message, Option } from "~/core/messages"; import { closeResearch, openResearch, + useLastFeedbackMessageId, + useLastInterruptMessage, useMessage, - useResearchTitle, + useMessageIds, + useResearchMessage, useStore, } from "~/core/store"; import { parseJSON } from "~/core/utils"; @@ -47,27 +50,9 @@ export function MessageListView({ ) => void; }) { const scrollContainerRef = useRef(null); - const messageIds = useStore((state) => state.messageIds); - const interruptMessage = useStore((state) => { - if (messageIds.length >= 2) { - const lastMessage = state.messages.get( - messageIds[messageIds.length - 1]!, - ); - return lastMessage?.finishReason === "interrupt" ? lastMessage : null; - } - return null; - }); - const waitingForFeedbackMessageId = useStore((state) => { - if (messageIds.length >= 2) { - const lastMessage = state.messages.get( - messageIds[messageIds.length - 1]!, - ); - if (lastMessage && lastMessage.finishReason === "interrupt") { - return state.messageIds[state.messageIds.length - 2]; - } - } - return null; - }); + const messageIds = useMessageIds(); + const interruptMessage = useLastInterruptMessage(); + const waitingForFeedbackMessageId = useLastFeedbackMessageId(); const responding = useStore((state) => state.responding); const noOngoingResearch = useStore( (state) => state.ongoingResearchId === null, @@ -138,9 +123,10 @@ function MessageListItem({ onToggleResearch?: () => void; }) { const message = useMessage(messageId); - const startOfResearch = useStore((state) => - state.researchIds.includes(messageId), - ); + const researchIds = useStore((state) => state.researchIds); + const startOfResearch = useMemo(() => { + return researchIds.includes(messageId); + }, [researchIds, messageId]); if (message) { if ( message.role === "user" || @@ -214,90 +200,92 @@ function MessageListItem({ } return null; } +} - function MessageBubble({ - className, - message, - children, - }: { - className?: string; - message: Message; - children: React.ReactNode; - }) { - return ( -
- {children} -
- ); - } +function MessageBubble({ + className, + message, + children, +}: { + className?: string; + message: Message; + children: React.ReactNode; +}) { + return ( +
+ {children} +
+ ); +} - function ResearchCard({ - className, - researchId, - onToggleResearch, - }: { - className?: string; - researchId: string; - onToggleResearch?: () => void; - }) { - const reportId = useStore((state) => - state.researchReportIds.get(researchId), - ); - const hasReport = useStore((state) => - state.researchReportIds.has(researchId), - ); - const reportGenerating = useStore( - (state) => hasReport && state.messages.get(reportId!)!.isStreaming, - ); - const openResearchId = useStore((state) => state.openResearchId); - const state = useMemo(() => { - if (hasReport) { - return reportGenerating ? "Generating report..." : "Report generated"; - } - return "Researching..."; - }, [hasReport, reportGenerating]); - const title = useResearchTitle(researchId); - const handleOpen = useCallback(() => { - if (openResearchId === researchId) { - closeResearch(); - } else { - openResearch(researchId); - } - onToggleResearch?.(); - }, [openResearchId, researchId, onToggleResearch]); - return ( - - - - - {title !== undefined && title !== "" ? title : "Deep Research"} - - - - -
- - {state} - - -
-
-
- ); - } +function ResearchCard({ + className, + researchId, + onToggleResearch, +}: { + className?: string; + researchId: string; + onToggleResearch?: () => void; +}) { + const reportId = useStore((state) => state.researchReportIds.get(researchId)); + const hasReport = reportId !== undefined; + const reportGenerating = useStore( + (state) => hasReport && state.messages.get(reportId)!.isStreaming, + ); + const openResearchId = useStore((state) => state.openResearchId); + const state = useMemo(() => { + if (hasReport) { + return reportGenerating ? "Generating report..." : "Report generated"; + } + return "Researching..."; + }, [hasReport, reportGenerating]); + const msg = useResearchMessage(researchId); + const title = useMemo(() => { + if (msg) { + return parseJSON(msg.content ?? "", { title: "" }).title; + } + return undefined; + }, [msg]); + const handleOpen = useCallback(() => { + if (openResearchId === researchId) { + closeResearch(); + } else { + openResearch(researchId); + } + onToggleResearch?.(); + }, [openResearchId, researchId, onToggleResearch]); + return ( + + + + + {title !== undefined && title !== "" ? title : "Deep Research"} + + + + +
+ + {state} + + +
+
+
+ ); } const GREETINGS = ["Cool", "Sounds great", "Looks good", "Great", "Awesome"]; diff --git a/web/src/app/chat/components/messages-block.tsx b/web/src/app/chat/components/messages-block.tsx index 2fe6ea4..bf36494 100644 --- a/web/src/app/chat/components/messages-block.tsx +++ b/web/src/app/chat/components/messages-block.tsx @@ -17,7 +17,7 @@ import { fastForwardReplay } from "~/core/api"; import { useReplayMetadata } from "~/core/api/hooks"; import type { Option } from "~/core/messages"; import { useReplay } from "~/core/replay"; -import { sendMessage, useStore } from "~/core/store"; +import { sendMessage, useMessageIds, useStore } from "~/core/store"; import { env } from "~/env"; import { cn } from "~/lib/utils"; @@ -27,7 +27,8 @@ import { MessageListView } from "./message-list-view"; import { Welcome } from "./welcome"; export function MessagesBlock({ className }: { className?: string }) { - const messageCount = useStore((state) => state.messageIds.length); + const messageIds = useMessageIds(); + const messageCount = messageIds.length; const responding = useStore((state) => state.responding); const { isReplay } = useReplay(); const { title: replayTitle, hasError: replayHasError } = useReplayMetadata(); diff --git a/web/src/core/store/store.ts b/web/src/core/store/store.ts index ee333fe..3078842 100644 --- a/web/src/core/store/store.ts +++ b/web/src/core/store/store.ts @@ -4,6 +4,7 @@ import { nanoid } from "nanoid"; import { toast } from "sonner"; import { create } from "zustand"; +import { useShallow } from "zustand/react/shallow"; import { chatStream, generatePodcast } from "../api"; import type { Message } from "../messages"; @@ -305,17 +306,54 @@ export async function listenToPodcast(researchId: string) { } } -export function useResearchTitle(researchId: string) { - const planMessage = useMessage( - useStore.getState().researchPlanIds.get(researchId), +export function useResearchMessage(researchId: string) { + return useStore( + useShallow((state) => { + const messageId = state.researchPlanIds.get(researchId); + return messageId ? state.messages.get(messageId) : undefined; + }), ); - return planMessage - ? parseJSON(planMessage.content, { title: "" }).title - : undefined; } export function useMessage(messageId: string | null | undefined) { - return useStore((state) => - messageId ? state.messages.get(messageId) : undefined, + return useStore( + useShallow((state) => + messageId ? state.messages.get(messageId) : undefined, + ), ); } + +export function useMessageIds() { + return useStore(useShallow((state) => state.messageIds)); +} + +export function useLastInterruptMessage() { + return useStore( + useShallow((state) => { + if (state.messageIds.length >= 2) { + const lastMessage = state.messages.get( + state.messageIds[state.messageIds.length - 1]!, + ); + return lastMessage?.finishReason === "interrupt" ? lastMessage : null; + } + return null; + }), + ); +} + +export function useLastFeedbackMessageId() { + const waitingForFeedbackMessageId = useStore( + useShallow((state) => { + if (state.messageIds.length >= 2) { + const lastMessage = state.messages.get( + state.messageIds[state.messageIds.length - 1]!, + ); + if (lastMessage && lastMessage.finishReason === "interrupt") { + return state.messageIds[state.messageIds.length - 2]; + } + } + return null; + }), + ); + return waitingForFeedbackMessageId; +}