diff --git a/pages/chat/@id/+Page.tsx b/pages/chat/@id/+Page.tsx index 624ebed..929704f 100644 --- a/pages/chat/@id/+Page.tsx +++ b/pages/chat/@id/+Page.tsx @@ -20,6 +20,7 @@ import { useData } from "vike-react/useData"; import type { Data } from "./+data"; import type { CommittedMessage, DraftMessage } from "../../../types"; import Markdown from "react-markdown"; +import { IconTrash } from "@tabler/icons-react"; export default function ChatPage() { const pageContext = usePageContext(); @@ -40,6 +41,7 @@ export default function ChatPage() { const setSystemPrompt = useStore((state) => state.setSystemPrompt); const setParameters = useStore((state) => state.setParameters); const setFacts = useStore((state) => state.setFacts); + const removeFact = useStore((state) => state.removeFact); const setLoading = useStore((state) => state.setLoading); const { @@ -72,6 +74,11 @@ export default function ChatPage() { setFacts(initialFacts); }, [initialFacts, setFacts]); + async function handleDeleteFact(factId: string) { + removeFact(factId); + await trpc.chat.deleteFact.mutate({ factId }); + } + return ( <>
@@ -169,7 +176,18 @@ export default function ChatPage() { {facts.map((fact) => ( - {fact.content} + + {fact.content}{" "} + { + e.stopPropagation(); + e.preventDefault(); + handleDeleteFact(fact.id); + }} + /> + ))} diff --git a/pages/chat/@id/+data.ts b/pages/chat/@id/+data.ts index a88b058..9d06411 100644 --- a/pages/chat/@id/+data.ts +++ b/pages/chat/@id/+data.ts @@ -12,5 +12,8 @@ export const data = async (pageContext: PageContextServer) => { const messages = await caller.fetchMessages({ conversationId: id, }); - return { conversation, messages, facts: [] }; + const facts = await caller.fetchFacts({ + conversationId: id, + }); + return { conversation, messages, facts }; }; diff --git a/pages/chat/trpc.ts b/pages/chat/trpc.ts index 72a5089..40deec8 100644 --- a/pages/chat/trpc.ts +++ b/pages/chat/trpc.ts @@ -395,6 +395,34 @@ export const chat = router({ }; }, ), + fetchFacts: publicProcedure + .input((x) => x as { conversationId: string }) + .query(async ({ input: { conversationId } }) => { + const conversationMessageIds = db.data.messages + .filter((m) => m.conversationId === conversationId) + .map((m) => m.id); + const rows = await db.data.facts.filter((f) => + conversationMessageIds.includes(f.sourceMessageId), + ); + return rows as Array; + }), + deleteFact: publicProcedure + .input( + (x) => + x as { + factId: string; + }, + ) + .mutation(async ({ input: { factId } }) => { + const deletedFact = db.data.facts.find((fact) => fact.id === factId); + if (!deletedFact) throw new Error("Fact not found"); + db.data.facts.splice( + db.data.facts.findIndex((fact) => fact.id === factId), + 1, + ); + db.write(); + return { ok: true }; + }), }); export const createCaller = createCallerFactory(chat); diff --git a/state.ts b/state.ts index f27d872..24a62f2 100644 --- a/state.ts +++ b/state.ts @@ -69,6 +69,13 @@ export const useStore = create()( set((stateDraft) => { stateDraft.facts = facts; }), + removeFact: (factId) => + set((stateDraft) => { + stateDraft.facts.splice( + stateDraft.facts.findIndex((fact) => fact.id === factId), + 1, + ); + }), setLoading: (loading) => set((stateDraft) => { stateDraft.loading = loading; diff --git a/types.ts b/types.ts index af8ba38..3510618 100644 --- a/types.ts +++ b/types.ts @@ -30,6 +30,7 @@ export type Store = { setSystemPrompt: (systemPrompt: string) => void; setParameters: (parameters: OtherParameters) => void; setFacts: (facts: Array) => void; + removeFact: (factId: string) => void; setLoading: (loading: boolean) => void; };