diff --git a/layouts/LayoutDefault.tsx b/layouts/LayoutDefault.tsx index 2f9f0e3..515bbae 100644 --- a/layouts/LayoutDefault.tsx +++ b/layouts/LayoutDefault.tsx @@ -1,4 +1,5 @@ import "@mantine/core/styles.css"; +import { navigate } from "vike/client/router"; import { AppShell, Burger, @@ -15,6 +16,7 @@ import { IconCircle, IconCircleFilled, IconTrashFilled, + IconPlus, } from "@tabler/icons-react"; import { useDisclosure } from "@mantine/hooks"; import theme from "./theme.js"; @@ -36,7 +38,7 @@ export default function LayoutDefault({ const setConversations = useStore((state) => state.setConversations); const addConversation = useStore((state) => state.addConversation); const removeConversation = useStore((state) => state.removeConversation); - const conversationId = useStore((state) => state.conversationId); + const conversationId = useStore((state) => state.selectedConversationId); useEffect(() => { trpc.chat.listConversations.query().then((res) => { @@ -96,11 +98,25 @@ export default function LayoutDefault({ label="Chats" leftSection={} rightSection={ - + <> + { + trpc.chat.createConversation.mutate().then((res) => { + if (!res?.id) return; + addConversation(res); + navigate(`/chat/${res.id}`); + }); + }} + /> + + } variant="subtle" active={urlPathname.startsWith("/chat")} diff --git a/package.json b/package.json index ef03e0d..1cd1163 100644 --- a/package.json +++ b/package.json @@ -29,6 +29,7 @@ "ai": "^4.3.16", "dotenv": "^17.0.0", "hono": "^4.8.2", + "immer": "^10.1.1", "kysely": "^0.28.2", "pg": "^8.16.3", "react": "^19.1.0", diff --git a/pages/chat/@id/+Page.tsx b/pages/chat/@id/+Page.tsx index d0e80e2..6920b4b 100644 --- a/pages/chat/@id/+Page.tsx +++ b/pages/chat/@id/+Page.tsx @@ -15,7 +15,9 @@ import type { ConversationsId } from "../../../database/generated/public/Convers export default function ChatPage() { const pageContext = usePageContext(); const conversationId = Number(pageContext.routeParams.id) as ConversationsId; - const conversationTitle = useStore((state) => state.conversationTitle); + const conversationTitle = useStore( + (state) => state.conversations.find((c) => c.id === conversationId)?.title, + ); const messages = useStore((state) => state.messages); const message = useStore((state) => state.message); const systemPrompt = useStore((state) => state.systemPrompt); @@ -53,7 +55,7 @@ export default function ChatPage() { Conversation #{conversationId} - { setConversationTitle(e.target.value); }} @@ -88,7 +90,7 @@ export default function ChatPage() { ]; setMessages(messagesWithNewUserMessage); setLoading(true); - const response = await trpc.chat.sendMessage.query({ + const response = await trpc.chat.sendMessage.mutate({ messages: messagesWithNewUserMessage, systemPrompt, parameters, diff --git a/pages/chat/trpc.ts b/pages/chat/trpc.ts index 11d4699..e744014 100644 --- a/pages/chat/trpc.ts +++ b/pages/chat/trpc.ts @@ -84,7 +84,7 @@ export const chat = router({ parameters: OtherParameters; }, ) - .query(async ({ input: { messages, systemPrompt, parameters } }) => { + .mutation(async ({ input: { messages, systemPrompt, parameters } }) => { const response = await generateText({ model: openrouter("mistralai/mistral-nemo"), messages: [ diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 1ae3858..03171b9 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -62,6 +62,9 @@ importers: hono: specifier: ^4.8.2 version: 4.8.3 + immer: + specifier: ^10.1.1 + version: 10.1.1 kysely: specifier: ^0.28.2 version: 0.28.2 @@ -88,7 +91,7 @@ importers: version: 3.25.67 zustand: specifier: ^5.0.6 - version: 5.0.6(@types/react@19.1.8)(react@19.1.0)(use-sync-external-store@1.5.0(react@19.1.0)) + version: 5.0.6(@types/react@19.1.8)(immer@10.1.1)(react@19.1.0)(use-sync-external-store@1.5.0(react@19.1.0)) devDependencies: '@biomejs/biome': specifier: 1.9.4 @@ -1652,6 +1655,9 @@ packages: resolution: {integrity: sha512-jYZ6ZtfWjzBdh8H/0CIFfCBHaFL75k+KMzaM177hrWWm2TWL39YMYaJgB74uK/niRc866NMlH9B8uCvIo284WQ==} engines: {node: '>=16.9.0'} + immer@10.1.1: + resolution: {integrity: sha512-s2MPrmjovJcoMaHtx6K11Ra7oD05NT97w1IC5zpMkT6Atjr7H8LjaDd81iIxUYpMKSRRNMJE703M1Fhr/TctHw==} + inherits@2.0.4: resolution: {integrity: sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==} @@ -4202,6 +4208,8 @@ snapshots: hono@4.8.3: {} + immer@10.1.1: {} + inherits@2.0.4: {} interpret@2.2.0: {} @@ -5320,8 +5328,9 @@ snapshots: zod@3.25.67: {} - zustand@5.0.6(@types/react@19.1.8)(react@19.1.0)(use-sync-external-store@1.5.0(react@19.1.0)): + zustand@5.0.6(@types/react@19.1.8)(immer@10.1.1)(react@19.1.0)(use-sync-external-store@1.5.0(react@19.1.0)): optionalDependencies: '@types/react': 19.1.8 + immer: 10.1.1 react: 19.1.0 use-sync-external-store: 1.5.0(react@19.1.0) diff --git a/state.ts b/state.ts index 9d96a17..9ccdf0c 100644 --- a/state.ts +++ b/state.ts @@ -1,6 +1,7 @@ import { create } from "zustand"; import type { OtherParameters, Store } from "./types.js"; import type { ConversationsId } from "./database/generated/public/Conversations.js"; +import { immer } from "zustand/middleware/immer"; export const defaultSystemPrompt = `You are a helpful assistant that answers questions based on the provided context. If you don't know the answer, just say that you don't know, don't try to make up an answer.`; export const defaultParameters = { @@ -8,31 +9,65 @@ export const defaultParameters = { max_tokens: 100, } as OtherParameters; -export const useStore = create()((set) => ({ - conversationId: 0 as ConversationsId, - conversationTitle: "", - conversations: [], - messages: [], - message: "", - systemPrompt: defaultSystemPrompt, - parameters: defaultParameters, - loading: false, - setConversationId: (conversationId) => set({ conversationId }), - setConversationTitle: (conversationTitle) => set({ conversationTitle }), - setConversations: (conversations) => set({ conversations }), - addConversation: (conversation) => - set((state) => ({ - conversations: [...state.conversations, conversation], - })), - removeConversation: (conversationId) => - set((state) => ({ - conversations: state.conversations.filter( - (conversation) => conversation.id !== conversationId, - ), - })), - setMessages: (messages) => set({ messages }), - setMessage: (message) => set({ message }), - setSystemPrompt: (systemPrompt) => set({ systemPrompt }), - setParameters: (parameters) => set({ parameters }), - setLoading: (loading) => set({ loading }), -})); +export const useStore = create()( + immer((set, get) => ({ + selectedConversationId: 0 as ConversationsId, + conversations: [], + messages: [], + message: "", + systemPrompt: defaultSystemPrompt, + parameters: defaultParameters, + loading: false, + setConversationId: (conversationId) => + set((stateDraft) => { + stateDraft.selectedConversationId = conversationId; + }), + setConversationTitle: (conversationTitle) => + set((stateDraft) => { + const conversation = stateDraft.conversations.find( + (c) => c.id === stateDraft.selectedConversationId, + ); + if (conversation) { + conversation.title = conversationTitle; + } + }), + setConversations: (conversations) => + set((stateDraft) => { + stateDraft.conversations = conversations; + }), + addConversation: (conversation) => + set((stateDraft) => { + stateDraft.conversations.push(conversation); + }), + removeConversation: (conversationId) => + set((stateDraft) => { + stateDraft.conversations.splice( + stateDraft.conversations.findIndex( + (conversation) => conversation.id === conversationId, + ), + 1, + ); + }), + setMessages: (messages) => + set((stateDraft) => { + //@ts-ignore + stateDraft.messages = messages; + }), + setMessage: (message) => + set((stateDraft) => { + stateDraft.message = message; + }), + setSystemPrompt: (systemPrompt) => + set((stateDraft) => { + stateDraft.systemPrompt = systemPrompt; + }), + setParameters: (parameters) => + set((stateDraft) => { + stateDraft.parameters = parameters; + }), + setLoading: (loading) => + set((stateDraft) => { + stateDraft.loading = loading; + }), + })), +); diff --git a/types.ts b/types.ts index a25ce31..ff5ebb6 100644 --- a/types.ts +++ b/types.ts @@ -15,8 +15,7 @@ export type ConversationUI = Conversations & {}; export type Store = { /** This is a string because Milvus sends it as a string, and the value * overflows the JS integer anyway. */ - conversationId: ConversationsId; - conversationTitle: string; + selectedConversationId: ConversationsId; conversations: Array; messages: Array; message: string;