expand use of react-query for state management

master
Avraham Sakal 1 month ago
parent 20f6f6918f
commit 84eaab08ea

@ -51,6 +51,7 @@ export interface MessageEntity extends Entity<CommittedMessage> {
export type FactTriggerEntity = Entity<FactTrigger> & { export type FactTriggerEntity = Entity<FactTrigger> & {
findByFactId: (factId: string) => Promise<Array<FactTrigger>>; findByFactId: (factId: string) => Promise<Array<FactTrigger>>;
findByConversationId: (conversationId: string) => Promise<Array<FactTrigger>>;
}; };
export interface ApplicationDatabase { export interface ApplicationDatabase {

@ -184,6 +184,16 @@ const factTriggers: FactTriggerEntity = {
.execute(); .execute();
return rows; return rows;
}, },
findByConversationId: async (conversationId) => {
const rows = await dbClient
.selectFrom("fact_triggers")
.innerJoin("facts", "facts.id", "fact_triggers.sourceFactId")
.innerJoin("messages", "messages.id", "facts.sourceMessageId")
.selectAll("fact_triggers")
.where("messages.conversationId", "=", conversationId)
.execute();
return rows;
},
}; };
const messages: MessageEntity = { const messages: MessageEntity = {
@ -232,12 +242,12 @@ const messages: MessageEntity = {
await dbClient.deleteFrom("messages").where("id", "=", id).execute(); await dbClient.deleteFrom("messages").where("id", "=", id).execute();
}, },
findByConversationId: async (conversationId) => { findByConversationId: async (conversationId) => {
const rows = await dbClient const rows = (await dbClient
.selectFrom("messages") .selectFrom("messages")
.selectAll() .selectAll()
.where("conversationId", "=", conversationId) .where("conversationId", "=", conversationId)
.execute(); .execute()) as Array<CommittedMessage>;
return rows as Array<CommittedMessage>; return rows;
}, },
}; };

@ -22,7 +22,10 @@ import type { Data } from "./+data";
import type { CommittedMessage, DraftMessage } from "../../../types"; import type { CommittedMessage, DraftMessage } from "../../../types";
import Markdown from "react-markdown"; import Markdown from "react-markdown";
import { IconTrash, IconEdit, IconCheck, IconX } from "@tabler/icons-react"; import { IconTrash, IconEdit, IconCheck, IconX } from "@tabler/icons-react";
import { useTRPCClient } from "../../../trpc/client"; import { useTRPC } from "../../../trpc/client";
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
import { nanoid } from "nanoid";
import type { Conversation } from "../../../database/common";
export default function ChatPage() { export default function ChatPage() {
const pageContext = usePageContext(); const pageContext = usePageContext();
@ -30,31 +33,384 @@ export default function ChatPage() {
const { const {
conversation, conversation,
messages: initialMessages, // messages: initialMessages,
facts: initialFacts, // facts: initialFacts,
factTriggers: initialFactTriggers, // factTriggers: initialFactTriggers,
} = useData<Data>(); } = useData<Data>();
const conversationTitle = conversation?.title; const conversationTitle = conversation?.title;
const messages = useStore((state) => state.messages);
const message = useStore((state) => state.message); const message = useStore((state) => state.message);
const systemPrompt = useStore((state) => state.systemPrompt); const systemPrompt = useStore((state) => state.systemPrompt);
const parameters = useStore((state) => state.parameters); const parameters = useStore((state) => state.parameters);
const facts = useStore((state) => state.facts);
const factTriggers = useStore((state) => state.factTriggers);
const loading = useStore((state) => state.loading); const loading = useStore((state) => state.loading);
const setConversationId = useStore((state) => state.setConversationId);
const setConversationTitle = useStore((state) => state.setConversationTitle);
const setMessages = useStore((state) => state.setMessages);
const setMessage = useStore((state) => state.setMessage); const setMessage = useStore((state) => state.setMessage);
const setSystemPrompt = useStore((state) => state.setSystemPrompt); const setSystemPrompt = useStore((state) => state.setSystemPrompt);
const setParameters = useStore((state) => state.setParameters); const setParameters = useStore((state) => state.setParameters);
const setFacts = useStore((state) => state.setFacts);
const setFactTriggers = useStore((state) => state.setFactTriggers);
const removeFact = useStore((state) => state.removeFact);
const removeFactTrigger = useStore((state) => state.removeFactTrigger);
const setLoading = useStore((state) => state.setLoading); const setLoading = useStore((state) => state.setLoading);
const trpc = useTRPCClient(); const trpc = useTRPC();
const queryClient = useQueryClient();
const messagesResult = useQuery(
trpc.chat.messages.fetchByConversationId.queryOptions({
conversationId,
})
);
const messages: Array<CommittedMessage> | undefined =
messagesResult.data?.map((m) => ({
...m,
parts: m.parts.filter((p) => p.type === "text"),
})) || [];
const facts = useQuery(
trpc.chat.facts.fetchByConversationId.queryOptions({
conversationId,
})
);
const factTriggers = useQuery(
trpc.chat.factTriggers.fetchByConversationId.queryOptions({
conversationId,
})
);
const deleteFact = useMutation(
trpc.chat.facts.deleteOne.mutationOptions({
onMutate: async ({ factId: factIdToDelete }) => {
/** Cancel affected queries that may be in-flight: */
await queryClient.cancelQueries({
queryKey: trpc.chat.facts.fetchByConversationId.queryKey({
conversationId,
}),
});
/** Optimistically update the affected queries in react-query's cache: */
const previousFacts = await queryClient.getQueryData(
trpc.chat.facts.fetchByConversationId.queryKey({
conversationId,
})
);
if (!previousFacts) {
return {
previousFacts: [],
newFacts: [],
};
}
const newFacts = previousFacts.filter((f) => f.id !== factIdToDelete);
queryClient.setQueryData(
trpc.chat.facts.fetchByConversationId.queryKey({
conversationId,
}),
newFacts
);
return { previousFacts, newFacts };
},
onSettled: async (data, variables, context) => {
await queryClient.invalidateQueries({
queryKey: trpc.chat.facts.fetchByConversationId.queryKey({
conversationId,
}),
});
},
onError: async (error, variables, context) => {
console.error(error);
if (!context) return;
queryClient.setQueryData(
trpc.chat.facts.fetchByConversationId.queryKey({
conversationId,
}),
context.previousFacts
);
},
})
);
const updateFact = useMutation(
trpc.chat.facts.update.mutationOptions({
onMutate: async ({ factId, content }) => {
/** Cancel affected queries that may be in-flight: */
await queryClient.cancelQueries({
queryKey: trpc.chat.facts.fetchByConversationId.queryKey({
conversationId,
}),
});
/** Optimistically update the affected queries in react-query's cache: */
const previousFacts = await queryClient.getQueryData(
trpc.chat.facts.fetchByConversationId.queryKey({
conversationId,
})
);
if (!previousFacts) {
return {
previousFacts: [],
newFacts: [],
};
}
const newFacts = previousFacts.map((f) =>
f.id === factId ? { ...f, content } : f
);
queryClient.setQueryData(
trpc.chat.facts.fetchByConversationId.queryKey({
conversationId,
}),
newFacts
);
return { previousFacts, newFacts };
},
onSettled: async (data, variables, context) => {
await queryClient.invalidateQueries({
queryKey: trpc.chat.facts.fetchByConversationId.queryKey({
conversationId,
}),
});
},
onError: async (error, variables, context) => {
console.error(error);
if (!context) return;
queryClient.setQueryData(
trpc.chat.facts.fetchByConversationId.queryKey({
conversationId,
}),
context.previousFacts
);
},
})
);
const deleteFactTrigger = useMutation(
trpc.chat.factTriggers.deleteOne.mutationOptions({
onMutate: async ({ factTriggerId: factTriggerIdToDelete }) => {
/** Cancel affected queries that may be in-flight: */
await queryClient.cancelQueries({
queryKey: trpc.chat.factTriggers.fetchByConversationId.queryKey({
conversationId,
}),
});
/** Optimistically update the affected queries in react-query's cache: */
const previousFactTriggers = await queryClient.getQueryData(
trpc.chat.factTriggers.fetchByConversationId.queryKey({
conversationId,
})
);
if (!previousFactTriggers) {
return {
previousFactTriggers: [],
newFactTriggers: [],
};
}
const newFactTriggers = previousFactTriggers.filter(
(ft) => ft.id !== factTriggerIdToDelete
);
queryClient.setQueryData(
trpc.chat.factTriggers.fetchByConversationId.queryKey({
conversationId,
}),
newFactTriggers
);
return { previousFactTriggers, newFactTriggers };
},
onSettled: async (data, variables, context) => {
await queryClient.invalidateQueries({
queryKey: trpc.chat.factTriggers.fetchByConversationId.queryKey({
conversationId,
}),
});
},
onError: async (error, variables, context) => {
console.error(error);
if (!context) return;
queryClient.setQueryData(
trpc.chat.factTriggers.fetchByConversationId.queryKey({
conversationId,
}),
context.previousFactTriggers
);
},
})
);
const updateFactTrigger = useMutation(
trpc.chat.factTriggers.update.mutationOptions({
onMutate: async ({ factTriggerId, content }) => {
/** Cancel affected queries that may be in-flight: */
await queryClient.cancelQueries({
queryKey: trpc.chat.factTriggers.fetchByConversationId.queryKey({
conversationId,
}),
});
/** Optimistically update the affected queries in react-query's cache: */
const previousFactTriggers = await queryClient.getQueryData(
trpc.chat.factTriggers.fetchByConversationId.queryKey({
conversationId,
})
);
if (!previousFactTriggers) {
return {
previousFactTriggers: [],
newFactTriggers: [],
};
}
const newFactTriggers = previousFactTriggers.map((ft) =>
ft.id === factTriggerId ? { ...ft, content } : ft
);
queryClient.setQueryData(
trpc.chat.factTriggers.fetchByConversationId.queryKey({
conversationId,
}),
newFactTriggers
);
return { previousFactTriggers, newFactTriggers };
},
onSettled: async (data, variables, context) => {
await queryClient.invalidateQueries({
queryKey: trpc.chat.factTriggers.fetchByConversationId.queryKey({
conversationId,
}),
});
},
onError: async (error, variables, context) => {
console.error(error);
if (!context) return;
queryClient.setQueryData(
trpc.chat.factTriggers.fetchByConversationId.queryKey({
conversationId,
}),
context.previousFactTriggers
);
},
})
);
const updateConversationTitle = useMutation(
trpc.chat.conversations.updateTitle.mutationOptions({
onMutate: async ({ id, title }) => {
/** Cancel affected queries that may be in-flight: */
await queryClient.cancelQueries({
queryKey: trpc.chat.conversations.fetchAll.queryKey(),
});
/** Optimistically update the affected queries in react-query's cache: */
const previousConversations = await queryClient.getQueryData(
trpc.chat.conversations.fetchAll.queryKey()
);
if (!previousConversations) {
return {
previousConversations: [],
newConversations: null,
};
}
const newConversations: Array<Conversation> = [
...previousConversations,
{
...conversation,
title,
} as Conversation,
];
queryClient.setQueryData(
trpc.chat.conversations.fetchAll.queryKey(),
newConversations
);
return { previousConversations, newConversations };
},
onSettled: async (data, variables, context) => {
await queryClient.invalidateQueries({
queryKey: trpc.chat.conversations.fetchOne.queryKey({
id: conversationId,
}),
});
await queryClient.invalidateQueries({
queryKey: trpc.chat.conversations.fetchAll.queryKey(),
});
},
onError: async (error, variables, context) => {
console.error(error);
if (!context) return;
queryClient.setQueryData(
trpc.chat.conversations.fetchAll.queryKey(),
context.previousConversations
);
},
})
);
const sendMessage = useMutation(
trpc.chat.sendMessage.mutationOptions({
onMutate: async ({
conversationId,
messages,
systemPrompt,
parameters,
}) => {
/** Cancel affected queries that may be in-flight: */
await queryClient.cancelQueries({
queryKey: trpc.chat.messages.fetchByConversationId.queryKey({
conversationId,
}),
});
/** Optimistically update the affected queries in react-query's cache: */
const previousMessages: Array<CommittedMessage> | undefined =
await queryClient.getQueryData(
trpc.chat.messages.fetchByConversationId.queryKey({
conversationId,
})
);
if (!previousMessages) {
return {
previousMessages: [],
newMessages: [],
};
}
const newMessages: Array<CommittedMessage> = [
...previousMessages,
{
/** placeholder id; will be overwritten when we get the true id from the backend */
id: nanoid(),
conversationId,
// content: messages[messages.length - 1].content,
// role: "user" as const,
...messages[messages.length - 1],
index: previousMessages.length,
createdAt: new Date().toISOString(),
} as CommittedMessage,
];
queryClient.setQueryData(
trpc.chat.messages.fetchByConversationId.queryKey({
conversationId,
}),
newMessages
);
return { previousMessages, newMessages };
},
onSettled: async (data, variables, context) => {
await queryClient.invalidateQueries({
queryKey: trpc.chat.messages.fetchByConversationId.queryKey({
conversationId,
}),
});
await queryClient.invalidateQueries({
queryKey: trpc.chat.facts.fetchByConversationId.queryKey({
conversationId,
}),
});
await queryClient.invalidateQueries({
queryKey: trpc.chat.factTriggers.fetchByConversationId.queryKey({
conversationId,
}),
});
},
onError: async (error, variables, context) => {
console.error(error);
if (!context) return;
queryClient.setQueryData(
trpc.chat.messages.fetchByConversationId.queryKey({
conversationId,
}),
context.previousMessages
);
},
})
);
// State for editing facts // State for editing facts
const [editingFactId, setEditingFactId] = useState<string | null>(null); const [editingFactId, setEditingFactId] = useState<string | null>(null);
@ -91,69 +447,23 @@ export default function ChatPage() {
}; };
}, [editingFactId, editingFactTriggerId]); }, [editingFactId, editingFactTriggerId]);
useEffect(() => {
setConversationId(conversationId);
}, [conversationId, setConversationId]);
useEffect(() => {
if (conversation?.id && conversation?.title) {
setConversationId(conversation.id);
setConversationTitle(conversation.title);
}
}, [
conversation?.id,
conversation?.title,
setConversationId,
setConversationTitle,
]);
useEffect(() => {
setMessages(initialMessages);
}, [initialMessages, setMessages]);
useEffect(() => {
setFacts(initialFacts);
}, [initialFacts, setFacts]);
useEffect(() => {
setFactTriggers(initialFactTriggers);
}, [initialFactTriggers, setFactTriggers]);
async function handleDeleteFact(factId: string) { async function handleDeleteFact(factId: string) {
removeFact(factId); await deleteFact.mutateAsync({ factId });
await trpc.chat.facts.deleteOne.mutate({ factId });
} }
async function handleUpdateFact(factId: string, content: string) { async function handleUpdateFact(factId: string, content: string) {
// Update the local state first await updateFact.mutateAsync({ factId, content });
setFacts(
facts.map((fact) => (fact.id === factId ? { ...fact, content } : fact))
);
// Then update the database
await trpc.chat.facts.update.mutate({ factId, content });
} }
async function handleDeleteFactTrigger(factTriggerId: string) { async function handleDeleteFactTrigger(factTriggerId: string) {
removeFactTrigger(factTriggerId); await deleteFactTrigger.mutateAsync({ factTriggerId });
await trpc.chat.factTriggers.deleteOne.mutate({ factTriggerId });
} }
async function handleUpdateFactTrigger( async function handleUpdateFactTrigger(
factTriggerId: string, factTriggerId: string,
content: string content: string
) { ) {
// Update the local state first await updateFactTrigger.mutateAsync({ factTriggerId, content });
setFactTriggers(
factTriggers.map((factTrigger) =>
factTrigger.id === factTriggerId
? { ...factTrigger, content }
: factTrigger
)
);
// Then update the database
await trpc.chat.factTriggers.update.mutate({ factTriggerId, content });
} }
return ( return (
@ -163,11 +473,11 @@ export default function ChatPage() {
<input <input
type="text" type="text"
defaultValue={conversationTitle || ""} defaultValue={conversationTitle || ""}
onChange={(e) => { // onChange={(e) => {
setConversationTitle(e.target.value); // setConversationTitle(e.target.value);
}} // }}
onBlur={(e) => { onBlur={(e) => {
trpc.chat.conversations.updateTitle.mutate({ updateConversationTitle.mutateAsync({
id: conversationId, id: conversationId,
title: e.target.value, title: e.target.value,
}); });
@ -183,7 +493,7 @@ export default function ChatPage() {
<Tabs.Tab value="fact-triggers">Fact Triggers</Tabs.Tab> <Tabs.Tab value="fact-triggers">Fact Triggers</Tabs.Tab>
</Tabs.List> </Tabs.List>
<Tabs.Panel value="message"> <Tabs.Panel value="message">
<Messages messages={messages} /> <Messages messages={messages || []} />
<Textarea <Textarea
resize="vertical" resize="vertical"
placeholder="Type your message here..." placeholder="Type your message here..."
@ -193,48 +503,20 @@ export default function ChatPage() {
onKeyDown={async (e) => { onKeyDown={async (e) => {
if (e.key === "Enter") { if (e.key === "Enter") {
e.preventDefault(); e.preventDefault();
const messagesWithNewUserMessage = [
...messages,
{
role: "user" as const,
parts: [{ type: "text", text: message }],
} as DraftMessage,
];
setMessages(messagesWithNewUserMessage);
setLoading(true); setLoading(true);
const response = await trpc.chat.sendMessage.mutate({ await sendMessage.mutateAsync({
conversationId, conversationId,
messages: messagesWithNewUserMessage, messages: [
...(messages || []),
{
role: "user" as const,
parts: [{ type: "text", text: message }],
} as DraftMessage,
],
systemPrompt, systemPrompt,
parameters, parameters,
}); });
const messagesWithAssistantMessage = [
...messages,
{
id: response.insertedUserMessage?.id,
conversationId,
role: "user" as const,
// content: message,
parts: [{ type: "text", text: message }],
index: response.insertedUserMessage?.index,
runningSummary: undefined,
} as CommittedMessage,
{
id: response.insertedAssistantMessage?.id,
conversationId,
role: "assistant" as const,
// content: response.insertedAssistantMessage?.content,
// parts: [{ type: "text", text: response.insertedAssistantMessage?.content }],
parts: response.insertedAssistantMessage?.parts,
index: response.insertedAssistantMessage?.index,
runningSummary:
response.insertedAssistantMessage?.runningSummary ||
undefined,
} as CommittedMessage,
];
setMessages(messagesWithAssistantMessage);
setMessage(""); setMessage("");
setFacts(response.insertedFacts);
setLoading(false); setLoading(false);
} }
}} }}
@ -259,7 +541,7 @@ export default function ChatPage() {
</Tabs.Panel> </Tabs.Panel>
<Tabs.Panel value="facts"> <Tabs.Panel value="facts">
<List> <List>
{facts.map((fact) => ( {facts.data?.map((fact) => (
<List.Item key={fact.id}> <List.Item key={fact.id}>
{editingFactId === fact.id ? ( {editingFactId === fact.id ? (
<Group wrap="nowrap" className="editing-fact"> <Group wrap="nowrap" className="editing-fact">
@ -318,7 +600,7 @@ export default function ChatPage() {
</Tabs.Panel> </Tabs.Panel>
<Tabs.Panel value="fact-triggers"> <Tabs.Panel value="fact-triggers">
<List> <List>
{factTriggers.map((factTrigger) => ( {factTriggers.data?.map((factTrigger) => (
<List.Item key={factTrigger.id}> <List.Item key={factTrigger.id}>
{editingFactTriggerId === factTrigger.id ? ( {editingFactTriggerId === factTrigger.id ? (
<Group wrap="nowrap" className="editing-fact-trigger"> <Group wrap="nowrap" className="editing-fact-trigger">

@ -6,22 +6,23 @@ export type Data = Awaited<ReturnType<typeof data>>;
export const data = async (pageContext: PageContextServer) => { export const data = async (pageContext: PageContextServer) => {
const { id } = pageContext.routeParams; const { id } = pageContext.routeParams;
const caller = createCaller({}); const caller = createCaller({});
const conversation = await caller.conversations.fetchOne({
id, const [
}); conversation,
const messages = await caller.conversations.fetchMessages({ // messages,
conversationId: id, // facts,
}); // factTriggers
const facts = await caller.facts.fetchByConversationId({ ] = await Promise.all([
conversationId: id, caller.conversations.fetchOne({ id }),
}); // caller.conversations.fetchMessages({ conversationId: id }),
// caller.facts.fetchByConversationId({ conversationId: id }),
// Fetch all fact triggers for the conversation's facts // caller.factTriggers.fetchByConversationId({ conversationId: id }),
const factTriggerPromises = facts.map(fact => ]);
caller.factTriggers.fetchByFactId({ factId: fact.id })
); return {
const factTriggersArrays = await Promise.all(factTriggerPromises); conversation,
const factTriggers = factTriggersArrays.flat(); // messages,
// facts,
return { conversation, messages, facts, factTriggers }; // factTriggers
};
}; };

@ -64,6 +64,11 @@ export const factTriggers = router({
.query(async ({ input: { factId } }) => { .query(async ({ input: { factId } }) => {
return db.factTriggers.findByFactId(factId); return db.factTriggers.findByFactId(factId);
}), }),
fetchByConversationId: publicProcedure
.input((x) => x as { conversationId: string })
.query(async ({ input: { conversationId } }) => {
return await db.factTriggers.findByConversationId(conversationId);
}),
deleteOne: publicProcedure deleteOne: publicProcedure
.input( .input(
(x) => (x) =>

Loading…
Cancel
Save