diff --git a/pages/chat/@id/+Page.tsx b/pages/chat/@id/+Page.tsx index a3d7993..2188c4b 100644 --- a/pages/chat/@id/+Page.tsx +++ b/pages/chat/@id/+Page.tsx @@ -19,10 +19,20 @@ import { import { usePageContext } from "vike-react/usePageContext"; import { useData } from "vike-react/useData"; import type { Data } from "./+data"; -import type { CommittedMessage, DraftMessage } from "../../../types"; +import type { + CommittedMessage, + DraftMessage, + OtherParameters, +} from "../../../types"; import Markdown from "react-markdown"; -import { IconTrash, IconEdit, IconCheck, IconX } from "@tabler/icons-react"; -import { useTRPC } from "../../../trpc/client"; +import { + IconTrash, + IconEdit, + IconCheck, + IconX, + IconLoaderQuarter, +} from "@tabler/icons-react"; +import { useTRPC, useTRPCClient } from "../../../trpc/client"; import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; import { nanoid } from "nanoid"; import type { Conversation } from "../../../database/common"; @@ -49,6 +59,7 @@ export default function ChatPage() { const setParameters = useStore((state) => state.setParameters); const setLoading = useStore((state) => state.setLoading); const trpc = useTRPC(); + const trpcClient = useTRPCClient(); const queryClient = useQueryClient(); const messagesResult = useQuery( @@ -334,83 +345,94 @@ export default function ChatPage() { }) ); - const sendMessage = useMutation( - trpc.chat.sendMessage.mutationOptions({ - onMutate: async ({ - conversationId, - messages, - systemPrompt, - parameters, - }) => { - /** Cancel affected queries that may be in-flight: */ - await queryClient.cancelQueries({ - queryKey: trpc.chat.messages.fetchByConversationId.queryKey({ - conversationId, - }), - }); - /** Optimistically update the affected queries in react-query's cache: */ - const previousMessages: Array | undefined = - await queryClient.getQueryData( - trpc.chat.messages.fetchByConversationId.queryKey({ - conversationId, - }) - ); - if 
(!previousMessages) { - return { - previousMessages: [], - newMessages: [], - }; + // Get state from Zustand store + const sendMessageStatus = useStore((state) => state.sendMessageStatus); + const isSendingMessage = useStore((state) => state.isSendingMessage); + const setSendMessageStatus = useStore((state) => state.setSendMessageStatus); + const setIsSendingMessage = useStore((state) => state.setIsSendingMessage); + + // Function to send message using subscription + const sendSubscriptionMessage = async ({ + conversationId, + messages, + systemPrompt, + parameters, + }: { + conversationId: string; + messages: Array; + systemPrompt: string; + parameters: OtherParameters; + }) => { + setIsSendingMessage(true); + setSendMessageStatus(null); + + try { + // Create an abort controller for the subscription + const abortController = new AbortController(); + + // Start the subscription + const subscription = trpcClient.chat.sendMessage.subscribe( + { + conversationId, + messages, + systemPrompt, + parameters, + }, + { + signal: abortController.signal, + onData: (data) => { + setSendMessageStatus(data); + + // If we've completed, update the UI and invalidate queries + if (data.status === "completed") { + setIsSendingMessage(false); + // Invalidate queries to refresh the data + queryClient.invalidateQueries({ + queryKey: trpc.chat.messages.fetchByConversationId.queryKey({ + conversationId, + }), + }); + queryClient.invalidateQueries({ + queryKey: trpc.chat.facts.fetchByConversationId.queryKey({ + conversationId, + }), + }); + queryClient.invalidateQueries({ + queryKey: trpc.chat.factTriggers.fetchByConversationId.queryKey( + { + conversationId, + } + ), + }); + } else { + setSendMessageStatus(data); + } + }, + onError: (error) => { + console.error("Subscription error:", error); + setIsSendingMessage(false); + setSendMessageStatus({ + status: "error", + message: "An error occurred while sending the message", + }); + }, } - const newMessages: Array = [ - ...previousMessages, - 
{ - /** placeholder id; will be overwritten when we get the true id from the backend */ - id: nanoid(), - conversationId, - // content: messages[messages.length - 1].content, - // role: "user" as const, - ...messages[messages.length - 1], - index: previousMessages.length, - createdAt: new Date().toISOString(), - } as CommittedMessage, - ]; - queryClient.setQueryData( - trpc.chat.messages.fetchByConversationId.queryKey({ - conversationId, - }), - newMessages - ); - return { previousMessages, newMessages }; - }, - onSettled: async (data, variables, context) => { - await queryClient.invalidateQueries({ - queryKey: trpc.chat.messages.fetchByConversationId.queryKey({ - conversationId, - }), - }); - await queryClient.invalidateQueries({ - queryKey: trpc.chat.facts.fetchByConversationId.queryKey({ - conversationId, - }), - }); - await queryClient.invalidateQueries({ - queryKey: trpc.chat.factTriggers.fetchByConversationId.queryKey({ - conversationId, - }), - }); - }, - onError: async (error, variables, context) => { - console.error(error); - if (!context) return; - queryClient.setQueryData( - trpc.chat.messages.fetchByConversationId.queryKey({ - conversationId, - }), - context.previousMessages - ); - }, - }) - ); + ); + + // Return a function to unsubscribe if needed + return () => { + abortController.abort(); + subscription.unsubscribe(); + }; + } catch (error) { + console.error("Failed to start subscription:", error); + setIsSendingMessage(false); + setSendMessageStatus({ + status: "error", + message: "Failed to start message sending process", + }); + } + }; // State for editing facts const [editingFactId, setEditingFactId] = useState(null); @@ -483,6 +505,8 @@ export default function ChatPage() { }); }} /> + {isSendingMessage && } + {sendMessageStatus && {sendMessageStatus.message}} @@ -504,7 +528,7 @@ export default function ChatPage() { if (e.key === "Enter") { e.preventDefault(); setLoading(true); - await sendMessage.mutateAsync({ + await sendSubscriptionMessage({ 
conversationId, messages: [ ...(messages || []), diff --git a/pages/chat/trpc.ts b/pages/chat/trpc.ts index d590294..c189fc5 100644 --- a/pages/chat/trpc.ts +++ b/pages/chat/trpc.ts @@ -74,173 +74,213 @@ export const chat = router({ parameters: OtherParameters; } ) - .mutation( - async ({ - input: { conversationId, messages, systemPrompt, parameters }, - }) => { - /** TODO: Save all unsaved messages (i.e. those without an `id`) to the - * database. Is this dangerous? Can an attacker just send a bunch of - * messages, omitting the ids, causing me to save a bunch of them to the - * database? I guess it's no worse than starting new converations, which - * anyone can freely do. */ - const previousRunningSummaryIndex = messages.findLastIndex( - (message) => - typeof (message as CommittedMessage).runningSummary !== "undefined" - ); - const previousRunningSummary = - previousRunningSummaryIndex >= 0 - ? ((messages[previousRunningSummaryIndex] as CommittedMessage) - .runningSummary as string) - : ""; - const messagesSincePreviousRunningSummary = messages.slice( - previousRunningSummaryIndex + 1 - ); - /** Save the incoming message to the database. */ - const insertedUserMessage = await db.messages.create({ - conversationId, - // content: messages[messages.length - 1].content, - // role: "user" as const, - ...messages[messages.length - 1], - index: messages.length - 1, - createdAt: new Date().toISOString(), + .subscription(async function* ({ + input: { conversationId, messages, systemPrompt, parameters }, + }) { + /** TODO: Save all unsaved messages (i.e. those without an `id`) to the + * database. Is this dangerous? Can an attacker just send a bunch of + * messages, omitting the ids, causing me to save a bunch of them to the + * database? I guess it's no worse than starting new conversations, which + * anyone can freely do. 
*/ + const previousRunningSummaryIndex = messages.findLastIndex( + (message) => + typeof (message as CommittedMessage).runningSummary !== "undefined" + ); + const previousRunningSummary = + previousRunningSummaryIndex >= 0 + ? ((messages[previousRunningSummaryIndex] as CommittedMessage) + .runningSummary as string) + : ""; + const messagesSincePreviousRunningSummary = messages.slice( + previousRunningSummaryIndex + 1 + ); + + // Emit status update + yield { + status: "saving_user_message", + message: "Saving user message...", + } as const; + + /** Save the incoming message to the database. */ + const insertedUserMessage = await db.messages.create({ + conversationId, + // content: messages[messages.length - 1].content, + // role: "user" as const, + ...messages[messages.length - 1], + index: messages.length - 1, + createdAt: new Date().toISOString(), + }); + + // Emit status update + yield { + status: "generating_response", + message: "Generating AI response...", + } as const; + + /** Generate a new message from the model, but hold-off on adding it to + * the database until we produce the associated running-summary, below. + * The model should be given the conversation summary thus far, and of + * course the user's latest message, unmodified. Invite the model to + * create any tools it needs. The tool needs to be implemented in a + * language which this system can execute; usually an interpreted + * language like Python or JavaScript. */ + const mainResponse = await generateText({ + model: openrouter(MODEL_NAME), + messages: [ + previousRunningSummary === "" + ? 
{ + role: "system" as const, + content: systemPrompt, + } + : { + role: "system" as const, + content: mainSystemPrompt({ + systemPrompt, + previousRunningSummary, + }), + }, + ...messagesSincePreviousRunningSummary.map((m) => ({ + role: m.role, + content: m.parts + .filter((p) => p.type === "text") + .map((p) => p.text) + .join(""), + })), + ], + tools: undefined, + ...parameters, + }); + + // Emit status update + yield { + status: "extracting_facts_from_user", + message: "Extracting facts from user message...", + } as const; + + /** Extract Facts from the user's message, and add them to the database, + * linking the Facts with the messages they came from. (Yes, this should + * be done *after* the model response, not before; because when we run a + * query to find Facts to inject into the context sent to the model, we + * don't want Facts from the user's current message to be candidates for + * injection, because we're sending the user's message unadulterated to + * the model; there's no reason to inject the same Facts that the model is + * already using to generate its response.) */ + const factsFromUserMessageResponse = + await factsCaller.extractFromNewMessages({ + previousRunningSummary, + messagesSincePreviousRunningSummary: [], + newMessages: messagesSincePreviousRunningSummary, + }); + const insertedFactsFromUserMessage = await db.facts.createMany( + factsFromUserMessageResponse.object.facts.map((fact) => ({ + userId: "019900bb-61b3-7333-b760-b27784dfe33b", + sourceMessageId: insertedUserMessage.id, + content: fact, + })) + ); + + // Emit status update + yield { + status: "generating_summary", + message: "Generating conversation summary...", + } as const; + + /** Produce a running summary of the conversation, and save that along + * with the model's response to the database. The new running summary is + * based on the previous running summary combined with all the messages + * since that summary was produced. 
*/ + const runningSummaryResponse = + await messagesCaller.generateRunningSummary({ + messagesSincePreviousRunningSummary, + mainResponseContent: mainResponse.text, + previousRunningSummary, }); + const insertedAssistantMessage = await db.messages.create({ + conversationId, + // content: mainResponse.text, + parts: [{ type: "text", text: mainResponse.text }], + runningSummary: runningSummaryResponse.text, + role: "assistant" as const, + index: messages.length, + createdAt: new Date().toISOString(), + }); - /** Generate a new message from the model, but hold-off on adding it to - * the database until we produce the associated running-summary, below. - * The model should be given the conversation summary thus far, and of - * course the user's latest message, unmodified. Invite the model to - * create any tools it needs. The tool needs to be implemented in a - * language which this system can execute; usually an interpretted - * language like Python or JavaScript. */ - const mainResponse = await generateText({ - model: openrouter(MODEL_NAME), - messages: [ - previousRunningSummary === "" - ? { - role: "system" as const, - content: systemPrompt, - } - : { - role: "system" as const, - content: mainSystemPrompt({ - systemPrompt, - previousRunningSummary, - }), - }, - ...messagesSincePreviousRunningSummary.map((m) => ({ - role: m.role, - content: m.parts - .filter((p) => p.type === "text") - .map((p) => p.text) - .join(""), - })), + // Emit status update + yield { + status: "extracting_facts_from_assistant", + message: "Extracting facts from assistant response...", + } as const; + + /** Extract Facts from the model's response, and add them to the database, + * linking the Facts with the messages they came from. 
*/ + const factsFromAssistantMessageResponse = + await factsCaller.extractFromNewMessages({ + previousRunningSummary, + messagesSincePreviousRunningSummary, + newMessages: [ + { + role: "assistant" as const, + // content: mainResponse.text, + parts: [{ type: "text", text: mainResponse.text }], + }, ], - tools: undefined, - ...parameters, }); - /** Extract Facts from the user's message, and add them to the database, - * linking the Facts with the messages they came from. (Yes, this should - * be done *after* the model response, not before; because when we run a - * query to find Facts to inject into the context sent to the model, we - * don't want Facts from the user's current message to be candidates for - * injection, because we're sending the user's message unadulterated to - * the model; there's no reason to inject the same Facts that the model is - * already using to generate its response.) */ - const factsFromUserMessageResponse = - await factsCaller.extractFromNewMessages({ - previousRunningSummary, - messagesSincePreviousRunningSummary: [], - newMessages: messagesSincePreviousRunningSummary, - }); - const insertedFactsFromUserMessage = await db.facts.createMany( - factsFromUserMessageResponse.object.facts.map((fact) => ({ - userId: "019900bb-61b3-7333-b760-b27784dfe33b", - sourceMessageId: insertedUserMessage.id, - content: fact, - })) - ); - - /** Produce a running summary of the conversation, and save that along - * with the model's response to the database. The new running summary is - * based on the previous running summary combined with the all messages - * since that summary was produced. 
*/ - const runningSummaryResponse = - await messagesCaller.generateRunningSummary({ - messagesSincePreviousRunningSummary, - mainResponseContent: mainResponse.text, - previousRunningSummary, - }); - const insertedAssistantMessage = await db.messages.create({ - conversationId, - // content: mainResponse.text, - parts: [{ type: "text", text: mainResponse.text }], - runningSummary: runningSummaryResponse.text, - role: "assistant" as const, - index: messages.length, + + const insertedFactsFromAssistantMessage = await db.facts.createMany( + factsFromAssistantMessageResponse.object.facts.map((factContent) => ({ + userId: "019900bb-61b3-7333-b760-b27784dfe33b", + sourceMessageId: insertedAssistantMessage.id, + content: factContent, createdAt: new Date().toISOString(), + })) + ); + + const insertedFacts = [ + ...insertedFactsFromUserMessage, + ...insertedFactsFromAssistantMessage, + ]; + + // Emit status update + yield { + status: "generating_fact_triggers", + message: "Generating fact triggers...", + } as const; + + /** For each Fact produced in the two fact-extraction steps, generate + * FactTriggers and add them to the database, linking the FactTriggers + * with the Facts they came from. A FactTrigger is a natural language + * phrase that describes a situation in which it would be useful to invoke + * the Fact. (e.g., "When food preferences are discussed"). */ + for (const fact of insertedFacts) { + const factTriggers = await factTriggerCaller.generateFromFact({ + mainResponseContent: mainResponse.text, + previousRunningSummary, + messagesSincePreviousRunningSummary, + fact, }); - /** Extract Facts from the model's response, and add them to the database, - * linking the Facts with the messages they came from. 
*/ - const factsFromAssistantMessageResponse = - await factsCaller.extractFromNewMessages({ - previousRunningSummary, - messagesSincePreviousRunningSummary, - newMessages: [ - { - role: "assistant" as const, - // content: mainResponse.text, - parts: [{ type: "text", text: mainResponse.text }], - }, - ], - }); - - const insertedFactsFromAssistantMessage = await db.facts.createMany( - factsFromAssistantMessageResponse.object.facts.map((factContent) => ({ - userId: "019900bb-61b3-7333-b760-b27784dfe33b", - sourceMessageId: insertedAssistantMessage.id, - content: factContent, + const insertedFactTriggers: Array> = + factTriggers.object.factTriggers.map((factTrigger) => ({ + sourceFactId: fact.id, + content: factTrigger, + priorityMultiplier: 1, + priorityMultiplierReason: "", + scopeConversationId: conversationId, createdAt: new Date().toISOString(), - })) - ); - - const insertedFacts = [ - ...insertedFactsFromUserMessage, - ...insertedFactsFromAssistantMessage, - ]; - - /** For each Fact produced in the two fact-extraction steps, generate - * FactTriggers and add them to the database, linking the FactTriggers - * with the Facts they came from. A FactTrigger is a natural language - * phrase that describes a situation in which it would be useful to invoke - * the Fact. (e.g., "When food preferences are discussed"). 
*/ - for (const fact of insertedFacts) { - const factTriggers = await factTriggerCaller.generateFromFact({ - mainResponseContent: mainResponse.text, - previousRunningSummary, - messagesSincePreviousRunningSummary, - fact, - }); - const insertedFactTriggers: Array> = - factTriggers.object.factTriggers.map((factTrigger) => ({ - sourceFactId: fact.id, - content: factTrigger, - priorityMultiplier: 1, - priorityMultiplierReason: "", - scopeConversationId: conversationId, - createdAt: new Date().toISOString(), - })); - db.factTriggers.createMany(insertedFactTriggers); - } - - // await db.write(); + })); + await db.factTriggers.createMany(insertedFactTriggers); + } - return { + // Emit final result + yield { + status: "completed", + message: "Completed!", + result: { insertedAssistantMessage, insertedUserMessage, insertedFacts, - }; - } - ), + }, + } as const; + }), }); export const createCaller = createCallerFactory(chat); diff --git a/state.ts b/state.ts index 980d9fa..928bdf1 100644 --- a/state.ts +++ b/state.ts @@ -19,6 +19,8 @@ export const useStore = create()( facts: [], factTriggers: [], loading: false, + sendMessageStatus: null, + isSendingMessage: false, setConversationId: (conversationId) => set((stateDraft) => { stateDraft.selectedConversationId = conversationId; @@ -92,5 +94,13 @@ export const useStore = create()( set((stateDraft) => { stateDraft.loading = loading; }), + setSendMessageStatus: (status) => + set((stateDraft) => { + stateDraft.sendMessageStatus = status; + }), + setIsSendingMessage: (isSending) => + set((stateDraft) => { + stateDraft.isSendingMessage = isSending; + }), })), ); diff --git a/types.ts b/types.ts index 05463f6..0791fbc 100644 --- a/types.ts +++ b/types.ts @@ -9,6 +9,12 @@ export type OtherParameters = Omit< export type ConversationUI = Conversation & {}; +export type SendMessageStatus = { + status: string; + message: string; + result?: any; +}; + export type Store = { /** This is a string because Milvus sends it as a string, and the 
value * overflows the JS integer anyway. */ @@ -21,6 +27,8 @@ export type Store = { facts: Array; factTriggers: Array; loading: boolean; + sendMessageStatus: SendMessageStatus | null; + isSendingMessage: boolean; setConversationId: (conversationId: string) => void; setConversationTitle: (conversationTitle: string) => void; setConversations: (conversations: Array) => void; @@ -35,6 +43,8 @@ export type Store = { removeFact: (factId: string) => void; removeFactTrigger: (factTriggerId: string) => void; setLoading: (loading: boolean) => void; + setSendMessageStatus: (status: SendMessageStatus | null) => void; + setIsSendingMessage: (isSending: boolean) => void; }; /** The message while it's being typed in the input box. */