diff --git a/database/common.ts b/database/common.ts index a5c66a5..0f9992b 100644 --- a/database/common.ts +++ b/database/common.ts @@ -51,6 +51,7 @@ export interface MessageEntity extends Entity { export type FactTriggerEntity = Entity & { findByFactId: (factId: string) => Promise>; + findByConversationId: (conversationId: string) => Promise>; }; export interface ApplicationDatabase { diff --git a/database/postgres.ts b/database/postgres.ts index 31f268b..c3521d6 100644 --- a/database/postgres.ts +++ b/database/postgres.ts @@ -184,6 +184,16 @@ const factTriggers: FactTriggerEntity = { .execute(); return rows; }, + findByConversationId: async (conversationId) => { + const rows = await dbClient + .selectFrom("fact_triggers") + .innerJoin("facts", "facts.id", "fact_triggers.sourceFactId") + .innerJoin("messages", "messages.id", "facts.sourceMessageId") + .selectAll("fact_triggers") + .where("messages.conversationId", "=", conversationId) + .execute(); + return rows; + }, }; const messages: MessageEntity = { @@ -232,12 +242,12 @@ const messages: MessageEntity = { await dbClient.deleteFrom("messages").where("id", "=", id).execute(); }, findByConversationId: async (conversationId) => { - const rows = await dbClient + const rows = (await dbClient .selectFrom("messages") .selectAll() .where("conversationId", "=", conversationId) - .execute(); - return rows as Array; + .execute()) as Array; + return rows; }, }; diff --git a/pages/chat/@id/+Page.tsx b/pages/chat/@id/+Page.tsx index 067ea6b..d3bbf18 100644 --- a/pages/chat/@id/+Page.tsx +++ b/pages/chat/@id/+Page.tsx @@ -22,7 +22,10 @@ import type { Data } from "./+data"; import type { CommittedMessage, DraftMessage } from "../../../types"; import Markdown from "react-markdown"; import { IconTrash, IconEdit, IconCheck, IconX } from "@tabler/icons-react"; -import { useTRPCClient } from "../../../trpc/client"; +import { useTRPC } from "../../../trpc/client"; +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import { nanoid } from "nanoid"; +import type { Conversation } from "../../../database/common"; export default function ChatPage() { const pageContext = usePageContext(); @@ -30,31 +33,384 @@ export default function ChatPage() { const { conversation, - messages: initialMessages, - facts: initialFacts, - factTriggers: initialFactTriggers, + // messages: initialMessages, + // facts: initialFacts, + // factTriggers: initialFactTriggers, } = useData(); + const conversationTitle = conversation?.title; - const messages = useStore((state) => state.messages); const message = useStore((state) => state.message); const systemPrompt = useStore((state) => state.systemPrompt); const parameters = useStore((state) => state.parameters); - const facts = useStore((state) => state.facts); - const factTriggers = useStore((state) => state.factTriggers); const loading = useStore((state) => state.loading); - const setConversationId = useStore((state) => state.setConversationId); - const setConversationTitle = useStore((state) => state.setConversationTitle); - const setMessages = useStore((state) => state.setMessages); const setMessage = useStore((state) => state.setMessage); const setSystemPrompt = useStore((state) => state.setSystemPrompt); const setParameters = useStore((state) => state.setParameters); - const setFacts = useStore((state) => state.setFacts); - const setFactTriggers = useStore((state) => state.setFactTriggers); - const removeFact = useStore((state) => state.removeFact); - const removeFactTrigger = useStore((state) => state.removeFactTrigger); const setLoading = useStore((state) => state.setLoading); - const trpc = useTRPCClient(); + const trpc = useTRPC(); + const queryClient = useQueryClient(); + + const messagesResult = useQuery( + trpc.chat.messages.fetchByConversationId.queryOptions({ + conversationId, + }) + ); + const messages: Array | undefined = + messagesResult.data?.map((m) => ({ + ...m, + parts: m.parts.filter((p) => p.type === "text"), + })) || []; + + const facts = useQuery( + trpc.chat.facts.fetchByConversationId.queryOptions({ + conversationId, + }) + ); + const factTriggers = useQuery( + trpc.chat.factTriggers.fetchByConversationId.queryOptions({ + conversationId, + }) + ); + + const deleteFact = useMutation( + trpc.chat.facts.deleteOne.mutationOptions({ + onMutate: async ({ factId: factIdToDelete }) => { + /** Cancel affected queries that may be in-flight: */ + await queryClient.cancelQueries({ + queryKey: trpc.chat.facts.fetchByConversationId.queryKey({ + conversationId, + }), + }); + + /** Optimistically update the affected queries in react-query's cache: */ + const previousFacts = await queryClient.getQueryData( + trpc.chat.facts.fetchByConversationId.queryKey({ + conversationId, + }) + ); + if (!previousFacts) { + return { + previousFacts: [], + newFacts: [], + }; + } + const newFacts = previousFacts.filter((f) => f.id !== factIdToDelete); + queryClient.setQueryData( + trpc.chat.facts.fetchByConversationId.queryKey({ + conversationId, + }), + newFacts + ); + + return { previousFacts, newFacts }; + }, + onSettled: async (data, variables, context) => { + await queryClient.invalidateQueries({ + queryKey: trpc.chat.facts.fetchByConversationId.queryKey({ + conversationId, + }), + }); + }, + onError: async (error, variables, context) => { + console.error(error); + if (!context) return; + queryClient.setQueryData( + trpc.chat.facts.fetchByConversationId.queryKey({ + conversationId, + }), + context.previousFacts + ); + }, + }) + ); + const updateFact = useMutation( + trpc.chat.facts.update.mutationOptions({ + onMutate: async ({ factId, content }) => { + /** Cancel affected queries that may be in-flight: */ + await queryClient.cancelQueries({ + queryKey: trpc.chat.facts.fetchByConversationId.queryKey({ + conversationId, + }), + }); + + /** Optimistically update the affected queries in react-query's cache: */ + const previousFacts = await queryClient.getQueryData( + trpc.chat.facts.fetchByConversationId.queryKey({ + conversationId, + }) + ); + if (!previousFacts) { + return { + previousFacts: [], + newFacts: [], + }; + } + const newFacts = previousFacts.map((f) => + f.id === factId ? { ...f, content } : f + ); + queryClient.setQueryData( + trpc.chat.facts.fetchByConversationId.queryKey({ + conversationId, + }), + newFacts + ); + + return { previousFacts, newFacts }; + }, + onSettled: async (data, variables, context) => { + await queryClient.invalidateQueries({ + queryKey: trpc.chat.facts.fetchByConversationId.queryKey({ + conversationId, + }), + }); + }, + onError: async (error, variables, context) => { + console.error(error); + if (!context) return; + queryClient.setQueryData( + trpc.chat.facts.fetchByConversationId.queryKey({ + conversationId, + }), + context.previousFacts + ); + }, + }) + ); + + const deleteFactTrigger = useMutation( + trpc.chat.factTriggers.deleteOne.mutationOptions({ + onMutate: async ({ factTriggerId: factTriggerIdToDelete }) => { + /** Cancel affected queries that may be in-flight: */ + await queryClient.cancelQueries({ + queryKey: trpc.chat.factTriggers.fetchByConversationId.queryKey({ + conversationId, + }), + }); + + /** Optimistically update the affected queries in react-query's cache: */ + const previousFactTriggers = await queryClient.getQueryData( + trpc.chat.factTriggers.fetchByConversationId.queryKey({ + conversationId, + }) + ); + if (!previousFactTriggers) { + return { + previousFactTriggers: [], + newFactTriggers: [], + }; + } + const newFactTriggers = previousFactTriggers.filter( + (ft) => ft.id !== factTriggerIdToDelete + ); + queryClient.setQueryData( + trpc.chat.factTriggers.fetchByConversationId.queryKey({ + conversationId, + }), + newFactTriggers + ); + + return { previousFactTriggers, newFactTriggers }; + }, + onSettled: async (data, variables, context) => { + await queryClient.invalidateQueries({ + queryKey: trpc.chat.factTriggers.fetchByConversationId.queryKey({ + conversationId, + }), + }); + }, + onError: async (error, variables, context) => { + console.error(error); + if (!context) return; + queryClient.setQueryData( + trpc.chat.factTriggers.fetchByConversationId.queryKey({ + conversationId, + }), + context.previousFactTriggers + ); + }, + }) + ); + const updateFactTrigger = useMutation( + trpc.chat.factTriggers.update.mutationOptions({ + onMutate: async ({ factTriggerId, content }) => { + /** Cancel affected queries that may be in-flight: */ + await queryClient.cancelQueries({ + queryKey: trpc.chat.factTriggers.fetchByConversationId.queryKey({ + conversationId, + }), + }); + /** Optimistically update the affected queries in react-query's cache: */ + const previousFactTriggers = await queryClient.getQueryData( + trpc.chat.factTriggers.fetchByConversationId.queryKey({ + conversationId, + }) + ); + if (!previousFactTriggers) { + return { + previousFactTriggers: [], + newFactTriggers: [], + }; + } + const newFactTriggers = previousFactTriggers.map((ft) => + ft.id === factTriggerId ? { ...ft, content } : ft + ); + queryClient.setQueryData( + trpc.chat.factTriggers.fetchByConversationId.queryKey({ + conversationId, + }), + newFactTriggers + ); + return { previousFactTriggers, newFactTriggers }; + }, + onSettled: async (data, variables, context) => { + await queryClient.invalidateQueries({ + queryKey: trpc.chat.factTriggers.fetchByConversationId.queryKey({ + conversationId, + }), + }); + }, + onError: async (error, variables, context) => { + console.error(error); + if (!context) return; + queryClient.setQueryData( + trpc.chat.factTriggers.fetchByConversationId.queryKey({ + conversationId, + }), + context.previousFactTriggers + ); + }, + }) + ); + + const updateConversationTitle = useMutation( + trpc.chat.conversations.updateTitle.mutationOptions({ + onMutate: async ({ id, title }) => { + /** Cancel affected queries that may be in-flight: */ + await queryClient.cancelQueries({ + queryKey: trpc.chat.conversations.fetchAll.queryKey(), + }); + /** Optimistically update the affected queries in react-query's cache: */ + const previousConversations = await queryClient.getQueryData( + trpc.chat.conversations.fetchAll.queryKey() + ); + if (!previousConversations) { + return { + previousConversations: [], + newConversations: null, + }; + } + const newConversations: Array = [ + ...previousConversations, + { + ...conversation, + title, + } as Conversation, + ]; + queryClient.setQueryData( + trpc.chat.conversations.fetchAll.queryKey(), + newConversations + ); + return { previousConversations, newConversations }; + }, + onSettled: async (data, variables, context) => { + await queryClient.invalidateQueries({ + queryKey: trpc.chat.conversations.fetchOne.queryKey({ + id: conversationId, + }), + }); + await queryClient.invalidateQueries({ + queryKey: trpc.chat.conversations.fetchAll.queryKey(), + }); + }, + onError: async (error, variables, context) => { + console.error(error); + if (!context) return; + queryClient.setQueryData( + trpc.chat.conversations.fetchAll.queryKey(), + context.previousConversations + ); + }, + }) + ); + + const sendMessage = useMutation( + trpc.chat.sendMessage.mutationOptions({ + onMutate: async ({ + conversationId, + messages, + systemPrompt, + parameters, + }) => { + /** Cancel affected queries that may be in-flight: */ + await queryClient.cancelQueries({ + queryKey: trpc.chat.messages.fetchByConversationId.queryKey({ + conversationId, + }), + }); + /** Optimistically update the affected queries in react-query's cache: */ + const previousMessages: Array | undefined = + await queryClient.getQueryData( + trpc.chat.messages.fetchByConversationId.queryKey({ + conversationId, + }) + ); + if (!previousMessages) { + return { + previousMessages: [], + newMessages: [], + }; + } + const newMessages: Array = [ + ...previousMessages, + { + /** placeholder id; will be overwritten when we get the true id from the backend */ + id: nanoid(), + conversationId, + // content: messages[messages.length - 1].content, + // role: "user" as const, + ...messages[messages.length - 1], + index: previousMessages.length, + createdAt: new Date().toISOString(), + } as CommittedMessage, + ]; + queryClient.setQueryData( + trpc.chat.messages.fetchByConversationId.queryKey({ + conversationId, + }), + newMessages + ); + return { previousMessages, newMessages }; + }, + onSettled: async (data, variables, context) => { + await queryClient.invalidateQueries({ + queryKey: trpc.chat.messages.fetchByConversationId.queryKey({ + conversationId, + }), + }); + await queryClient.invalidateQueries({ + queryKey: trpc.chat.facts.fetchByConversationId.queryKey({ + conversationId, + }), + }); + await queryClient.invalidateQueries({ + queryKey: trpc.chat.factTriggers.fetchByConversationId.queryKey({ + conversationId, + }), + }); + }, + onError: async (error, variables, context) => { + console.error(error); + if (!context) return; + queryClient.setQueryData( + trpc.chat.messages.fetchByConversationId.queryKey({ + conversationId, + }), + context.previousMessages + ); + }, + }) + ); // State for editing facts const [editingFactId, setEditingFactId] = useState(null); @@ -91,69 +447,23 @@ export default function ChatPage() { }; }, [editingFactId, editingFactTriggerId]); - useEffect(() => { - setConversationId(conversationId); - }, [conversationId, setConversationId]); - - useEffect(() => { - if (conversation?.id && conversation?.title) { - setConversationId(conversation.id); - setConversationTitle(conversation.title); - } - }, [ - conversation?.id, - conversation?.title, - setConversationId, - setConversationTitle, - ]); - - useEffect(() => { - setMessages(initialMessages); - }, [initialMessages, setMessages]); - - useEffect(() => { - setFacts(initialFacts); - }, [initialFacts, setFacts]); - - useEffect(() => { - setFactTriggers(initialFactTriggers); - }, [initialFactTriggers, setFactTriggers]); - async function handleDeleteFact(factId: string) { - removeFact(factId); - await trpc.chat.facts.deleteOne.mutate({ factId }); + await deleteFact.mutateAsync({ factId }); } async function handleUpdateFact(factId: string, content: string) { - // Update the local state first - setFacts( - facts.map((fact) => (fact.id === factId ? { ...fact, content } : fact)) - ); - - // Then update the database - await trpc.chat.facts.update.mutate({ factId, content }); + await updateFact.mutateAsync({ factId, content }); } async function handleDeleteFactTrigger(factTriggerId: string) { - removeFactTrigger(factTriggerId); - await trpc.chat.factTriggers.deleteOne.mutate({ factTriggerId }); + await deleteFactTrigger.mutateAsync({ factTriggerId }); } async function handleUpdateFactTrigger( factTriggerId: string, content: string ) { - // Update the local state first - setFactTriggers( - factTriggers.map((factTrigger) => - factTrigger.id === factTriggerId - ? { ...factTrigger, content } - : factTrigger - ) - ); - - // Then update the database - await trpc.chat.factTriggers.update.mutate({ factTriggerId, content }); + await updateFactTrigger.mutateAsync({ factTriggerId, content }); } return ( @@ -163,11 +473,11 @@ export default function ChatPage() { { - setConversationTitle(e.target.value); - }} + // onChange={(e) => { + // setConversationTitle(e.target.value); + // }} onBlur={(e) => { - trpc.chat.conversations.updateTitle.mutate({ + updateConversationTitle.mutateAsync({ id: conversationId, title: e.target.value, }); @@ -183,7 +493,7 @@ export default function ChatPage() { Fact Triggers - +