streaming response

master
Avraham Sakal 4 weeks ago
parent 1dddae6a05
commit ebfbb22525

@ -19,10 +19,20 @@ import {
import { usePageContext } from "vike-react/usePageContext"; import { usePageContext } from "vike-react/usePageContext";
import { useData } from "vike-react/useData"; import { useData } from "vike-react/useData";
import type { Data } from "./+data"; import type { Data } from "./+data";
import type { CommittedMessage, DraftMessage } from "../../../types"; import type {
CommittedMessage,
DraftMessage,
OtherParameters,
} from "../../../types";
import Markdown from "react-markdown"; import Markdown from "react-markdown";
import { IconTrash, IconEdit, IconCheck, IconX } from "@tabler/icons-react"; import {
import { useTRPC } from "../../../trpc/client"; IconTrash,
IconEdit,
IconCheck,
IconX,
IconLoaderQuarter,
} from "@tabler/icons-react";
import { useTRPC, useTRPCClient } from "../../../trpc/client";
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
import { nanoid } from "nanoid"; import { nanoid } from "nanoid";
import type { Conversation } from "../../../database/common"; import type { Conversation } from "../../../database/common";
@ -49,6 +59,7 @@ export default function ChatPage() {
const setParameters = useStore((state) => state.setParameters); const setParameters = useStore((state) => state.setParameters);
const setLoading = useStore((state) => state.setLoading); const setLoading = useStore((state) => state.setLoading);
const trpc = useTRPC(); const trpc = useTRPC();
const trpcClient = useTRPCClient();
const queryClient = useQueryClient(); const queryClient = useQueryClient();
const messagesResult = useQuery( const messagesResult = useQuery(
@ -334,84 +345,95 @@ export default function ChatPage() {
}) })
); );
const sendMessage = useMutation( // Get state from Zustand store
trpc.chat.sendMessage.mutationOptions({ const sendMessageStatus = useStore((state) => state.sendMessageStatus);
onMutate: async ({ const isSendingMessage = useStore((state) => state.isSendingMessage);
const setSendMessageStatus = useStore((state) => state.setSendMessageStatus);
const setIsSendingMessage = useStore((state) => state.setIsSendingMessage);
// Function to send message using subscription
const sendSubscriptionMessage = async ({
conversationId, conversationId,
messages, messages,
systemPrompt, systemPrompt,
parameters, parameters,
}: {
conversationId: string;
messages: Array<DraftMessage | CommittedMessage>;
systemPrompt: string;
parameters: OtherParameters;
}) => { }) => {
/** Cancel affected queries that may be in-flight: */ setIsSendingMessage(true);
await queryClient.cancelQueries({ setSendMessageStatus(null);
queryKey: trpc.chat.messages.fetchByConversationId.queryKey({
conversationId, try {
}), // Create an abort controller for the subscription
}); const abortController = new AbortController();
/** Optimistically update the affected queries in react-query's cache: */
const previousMessages: Array<CommittedMessage> | undefined = // Start the subscription
await queryClient.getQueryData( const subscription = trpcClient.chat.sendMessage.subscribe(
trpc.chat.messages.fetchByConversationId.queryKey({
conversationId,
})
);
if (!previousMessages) {
return {
previousMessages: [],
newMessages: [],
};
}
const newMessages: Array<CommittedMessage> = [
...previousMessages,
{ {
/** placeholder id; will be overwritten when we get the true id from the backend */
id: nanoid(),
conversationId,
// content: messages[messages.length - 1].content,
// role: "user" as const,
...messages[messages.length - 1],
index: previousMessages.length,
createdAt: new Date().toISOString(),
} as CommittedMessage,
];
queryClient.setQueryData(
trpc.chat.messages.fetchByConversationId.queryKey({
conversationId, conversationId,
}), messages,
newMessages systemPrompt,
); parameters,
return { previousMessages, newMessages };
}, },
onSettled: async (data, variables, context) => { {
await queryClient.invalidateQueries({ signal: abortController.signal,
onData: (data) => {
setSendMessageStatus(data);
// If we've completed, update the UI and invalidate queries
if (data.status === "completed") {
setIsSendingMessage(false);
// Invalidate queries to refresh the data
queryClient.invalidateQueries({
queryKey: trpc.chat.messages.fetchByConversationId.queryKey({ queryKey: trpc.chat.messages.fetchByConversationId.queryKey({
conversationId, conversationId,
}), }),
}); });
await queryClient.invalidateQueries({ queryClient.invalidateQueries({
queryKey: trpc.chat.facts.fetchByConversationId.queryKey({ queryKey: trpc.chat.facts.fetchByConversationId.queryKey({
conversationId, conversationId,
}), }),
}); });
await queryClient.invalidateQueries({ queryClient.invalidateQueries({
queryKey: trpc.chat.factTriggers.fetchByConversationId.queryKey({ queryKey: trpc.chat.factTriggers.fetchByConversationId.queryKey(
{
conversationId, conversationId,
}), }
),
}); });
} else {
setSendMessageStatus(data);
}
}, },
onError: async (error, variables, context) => { onError: (error) => {
console.error(error); console.error("Subscription error:", error);
if (!context) return; setIsSendingMessage(false);
queryClient.setQueryData( setSendMessageStatus({
trpc.chat.messages.fetchByConversationId.queryKey({ status: "error",
conversationId, message: "An error occurred while sending the message",
}), });
context.previousMessages
);
}, },
}) }
); );
// Return a function to unsubscribe if needed
return () => {
abortController.abort();
subscription.unsubscribe();
};
} catch (error) {
console.error("Failed to start subscription:", error);
setIsSendingMessage(false);
setSendMessageStatus({
status: "error",
message: "Failed to start message sending process",
});
}
};
// State for editing facts // State for editing facts
const [editingFactId, setEditingFactId] = useState<string | null>(null); const [editingFactId, setEditingFactId] = useState<string | null>(null);
const [editingFactContent, setEditingFactContent] = useState(""); const [editingFactContent, setEditingFactContent] = useState("");
@ -483,6 +505,8 @@ export default function ChatPage() {
}); });
}} }}
/> />
{isSendingMessage && <IconLoaderQuarter size={16} stroke={1.5} />}
{sendMessageStatus && <span>{sendMessageStatus.message}</span>}
</div> </div>
<Tabs defaultValue="message"> <Tabs defaultValue="message">
<Tabs.List> <Tabs.List>
@ -504,7 +528,7 @@ export default function ChatPage() {
if (e.key === "Enter") { if (e.key === "Enter") {
e.preventDefault(); e.preventDefault();
setLoading(true); setLoading(true);
await sendMessage.mutateAsync({ await sendSubscriptionMessage({
conversationId, conversationId,
messages: [ messages: [
...(messages || []), ...(messages || []),

@ -74,10 +74,9 @@ export const chat = router({
parameters: OtherParameters; parameters: OtherParameters;
} }
) )
.mutation( .subscription(async function* ({
async ({
input: { conversationId, messages, systemPrompt, parameters }, input: { conversationId, messages, systemPrompt, parameters },
}) => { }) {
/** TODO: Save all unsaved messages (i.e. those without an `id`) to the /** TODO: Save all unsaved messages (i.e. those without an `id`) to the
* database. Is this dangerous? Can an attacker just send a bunch of * database. Is this dangerous? Can an attacker just send a bunch of
* messages, omitting the ids, causing me to save a bunch of them to the * messages, omitting the ids, causing me to save a bunch of them to the
@ -95,6 +94,13 @@ export const chat = router({
const messagesSincePreviousRunningSummary = messages.slice( const messagesSincePreviousRunningSummary = messages.slice(
previousRunningSummaryIndex + 1 previousRunningSummaryIndex + 1
); );
// Emit status update
yield {
status: "saving_user_message",
message: "Saving user message...",
} as const;
/** Save the incoming message to the database. */ /** Save the incoming message to the database. */
const insertedUserMessage = await db.messages.create({ const insertedUserMessage = await db.messages.create({
conversationId, conversationId,
@ -105,6 +111,12 @@ export const chat = router({
createdAt: new Date().toISOString(), createdAt: new Date().toISOString(),
}); });
// Emit status update
yield {
status: "generating_response",
message: "Generating AI response...",
} as const;
/** Generate a new message from the model, but hold-off on adding it to /** Generate a new message from the model, but hold-off on adding it to
* the database until we produce the associated running-summary, below. * the database until we produce the associated running-summary, below.
* The model should be given the conversation summary thus far, and of * The model should be given the conversation summary thus far, and of
@ -138,6 +150,13 @@ export const chat = router({
tools: undefined, tools: undefined,
...parameters, ...parameters,
}); });
// Emit status update
yield {
status: "extracting_facts_from_user",
message: "Extracting facts from user message...",
} as const;
/** Extract Facts from the user's message, and add them to the database, /** Extract Facts from the user's message, and add them to the database,
* linking the Facts with the messages they came from. (Yes, this should * linking the Facts with the messages they came from. (Yes, this should
* be done *after* the model response, not before; because when we run a * be done *after* the model response, not before; because when we run a
@ -160,6 +179,12 @@ export const chat = router({
})) }))
); );
// Emit status update
yield {
status: "generating_summary",
message: "Generating conversation summary...",
} as const;
/** Produce a running summary of the conversation, and save that along /** Produce a running summary of the conversation, and save that along
* with the model's response to the database. The new running summary is * with the model's response to the database. The new running summary is
* based on the previous running summary combined with the all messages * based on the previous running summary combined with the all messages
@ -179,6 +204,13 @@ export const chat = router({
index: messages.length, index: messages.length,
createdAt: new Date().toISOString(), createdAt: new Date().toISOString(),
}); });
// Emit status update
yield {
status: "extracting_facts_from_assistant",
message: "Extracting facts from assistant response...",
} as const;
/** Extract Facts from the model's response, and add them to the database, /** Extract Facts from the model's response, and add them to the database,
* linking the Facts with the messages they came from. */ * linking the Facts with the messages they came from. */
const factsFromAssistantMessageResponse = const factsFromAssistantMessageResponse =
@ -208,6 +240,12 @@ export const chat = router({
...insertedFactsFromAssistantMessage, ...insertedFactsFromAssistantMessage,
]; ];
// Emit status update
yield {
status: "generating_fact_triggers",
message: "Generating fact triggers...",
} as const;
/** For each Fact produced in the two fact-extraction steps, generate /** For each Fact produced in the two fact-extraction steps, generate
* FactTriggers and add them to the database, linking the FactTriggers * FactTriggers and add them to the database, linking the FactTriggers
* with the Facts they came from. A FactTrigger is a natural language * with the Facts they came from. A FactTrigger is a natural language
@ -229,18 +267,20 @@ export const chat = router({
scopeConversationId: conversationId, scopeConversationId: conversationId,
createdAt: new Date().toISOString(), createdAt: new Date().toISOString(),
})); }));
db.factTriggers.createMany(insertedFactTriggers); await db.factTriggers.createMany(insertedFactTriggers);
} }
// await db.write(); // Emit final result
yield {
return { status: "completed",
message: "Completed!",
result: {
insertedAssistantMessage, insertedAssistantMessage,
insertedUserMessage, insertedUserMessage,
insertedFacts, insertedFacts,
}; },
} } as const;
), }),
}); });
export const createCaller = createCallerFactory(chat); export const createCaller = createCallerFactory(chat);

@ -19,6 +19,8 @@ export const useStore = create<Store>()(
facts: [], facts: [],
factTriggers: [], factTriggers: [],
loading: false, loading: false,
sendMessageStatus: null,
isSendingMessage: false,
setConversationId: (conversationId) => setConversationId: (conversationId) =>
set((stateDraft) => { set((stateDraft) => {
stateDraft.selectedConversationId = conversationId; stateDraft.selectedConversationId = conversationId;
@ -92,5 +94,13 @@ export const useStore = create<Store>()(
set((stateDraft) => { set((stateDraft) => {
stateDraft.loading = loading; stateDraft.loading = loading;
}), }),
setSendMessageStatus: (status) =>
set((stateDraft) => {
stateDraft.sendMessageStatus = status;
}),
setIsSendingMessage: (isSending) =>
set((stateDraft) => {
stateDraft.isSendingMessage = isSending;
}),
})), })),
); );

@ -9,6 +9,12 @@ export type OtherParameters = Omit<
export type ConversationUI = Conversation & {}; export type ConversationUI = Conversation & {};
export type SendMessageStatus = {
status: string;
message: string;
result?: any;
};
export type Store = { export type Store = {
/** This is a string because Milvus sends it as a string, and the value /** This is a string because Milvus sends it as a string, and the value
* overflows the JS integer anyway. */ * overflows the JS integer anyway. */
@ -21,6 +27,8 @@ export type Store = {
facts: Array<Fact>; facts: Array<Fact>;
factTriggers: Array<FactTrigger>; factTriggers: Array<FactTrigger>;
loading: boolean; loading: boolean;
sendMessageStatus: SendMessageStatus | null;
isSendingMessage: boolean;
setConversationId: (conversationId: string) => void; setConversationId: (conversationId: string) => void;
setConversationTitle: (conversationTitle: string) => void; setConversationTitle: (conversationTitle: string) => void;
setConversations: (conversations: Array<ConversationUI>) => void; setConversations: (conversations: Array<ConversationUI>) => void;
@ -35,6 +43,8 @@ export type Store = {
removeFact: (factId: string) => void; removeFact: (factId: string) => void;
removeFactTrigger: (factTriggerId: string) => void; removeFactTrigger: (factTriggerId: string) => void;
setLoading: (loading: boolean) => void; setLoading: (loading: boolean) => void;
setSendMessageStatus: (status: SendMessageStatus | null) => void;
setIsSendingMessage: (isSending: boolean) => void;
}; };
/** The message while it's being typed in the input box. */ /** The message while it's being typed in the input box. */

Loading…
Cancel
Save