From 0207e4fc4710fab23358a2d0b426562febf1046b Mon Sep 17 00:00:00 2001 From: Avraham Sakal Date: Sun, 21 Sep 2025 19:12:15 -0400 Subject: [PATCH] authz/authn on all trpc procedures --- pages/chat/@id/+data.ts | 3 +- server/authjs-handler.ts | 2 + server/trpc/conversations.ts | 84 ++++++++++++++++++------------------ server/trpc/fact-triggers.ts | 70 ++++++++++++++++++++---------- server/trpc/facts.ts | 58 ++++++++++++++++--------- server/trpc/messages.ts | 47 +++++++++++++++++--- server/trpc/server.ts | 14 ++++++ 7 files changed, 185 insertions(+), 93 deletions(-) diff --git a/pages/chat/@id/+data.ts b/pages/chat/@id/+data.ts index 8553fe2..4f1b419 100644 --- a/pages/chat/@id/+data.ts +++ b/pages/chat/@id/+data.ts @@ -12,13 +12,12 @@ export const data = async (pageContext: PageContextServer) => { openrouter: getOpenrouter( (pageContext.env?.OPENROUTER_API_KEY || env.OPENROUTER_API_KEY) as string ), - // jwt: pageContext., + jwt: pageContext.session?.jwt, dbClient: getDbClient( (pageContext.env?.POSTGRES_CONNECTION_STRING || env.POSTGRES_CONNECTION_STRING) as string ), }); - const [ conversation, // messages, diff --git a/server/authjs-handler.ts b/server/authjs-handler.ts index 5f4d0c3..e174995 100644 --- a/server/authjs-handler.ts +++ b/server/authjs-handler.ts @@ -15,6 +15,7 @@ import type { } from "@universal-middleware/core"; import { env } from "./env.js"; import { getDbClient } from "../database/index.js"; +import { JWT } from "@auth/core/jwt"; const POSTGRES_CONNECTION_STRING = "postgres://neondb_owner:npg_sOVmj8vWq2zG@ep-withered-king-adiz9gpi-pooler.c-2.us-east-1.aws.neon.tech:5432/neondb?sslmode=require&channel_binding=true"; @@ -125,6 +126,7 @@ const authjsConfig = { ...session.user, id: token.id as string, }, + jwt: token, }; }, }, diff --git a/server/trpc/conversations.ts b/server/trpc/conversations.ts index 1b18ee1..99265f6 100644 --- a/server/trpc/conversations.ts +++ b/server/trpc/conversations.ts @@ -1,33 +1,41 @@ +import { TRPCError } from "@trpc/server"; import type { CommittedMessage } from "../../types"; -import { router, publicProcedure, createCallerFactory } from "./server"; +import { router, createCallerFactory, authProcedure } from "./server"; +import { z } from "zod"; + +const authConversationProcedure = authProcedure + .input(z.object({ id: z.string() })) + .use(async ({ input: { id }, ctx: { dbClient, jwt }, next }) => { + const rows = await dbClient + .selectFrom("conversations") + .selectAll() + .where("id", "=", id) + .execute(); + if (rows[0].userId !== jwt.id) { + throw new TRPCError({ code: "UNAUTHORIZED" }); + } + return next({ + ctx: { + conversationRow: rows[0], + }, + }); + }); export const conversations = router({ - fetchAll: publicProcedure.query(async ({ ctx: { dbClient, jwt } }) => { - const userId = jwt?.id as string | null; - if (!userId) return []; + fetchAll: authProcedure.query(async ({ ctx: { dbClient, jwt } }) => { const rows = await dbClient .selectFrom("conversations") - .where("userId", "=", userId) + .where("userId", "=", jwt.id as string) .selectAll() .execute(); return rows; }), - fetchOne: publicProcedure - .input((x) => x as { id: string }) - .query(async ({ input: { id }, ctx: { dbClient, jwt } }) => { - const userId = jwt?.id as string | null; - if (!userId) return null; - const row = await dbClient - .selectFrom("conversations") - .selectAll() - .where("id", "=", id) - .where("userId", "=", userId) - .execute(); - return row[0]; - }), - start: publicProcedure.mutation(async ({ ctx: { dbClient, jwt } }) => { - const userId = jwt?.id as string | null; - if (!userId) return null; + fetchOne: authConversationProcedure.query( + async ({ ctx: { conversationRow } }) => { + return conversationRow; + } + ), + start: authProcedure.mutation(async ({ ctx: { dbClient, jwt } }) => { const insertedRows = await dbClient .insertInto("conversations") .values({ @@ -38,42 +46,34 @@ export const conversations = router({ .execute(); return insertedRows[0]; }), - deleteOne: publicProcedure - .input((x) => x as { id: string }) - .mutation(async ({ input: { id }, ctx: { dbClient, jwt } }) => { - const userId = jwt?.id as string | null; - if (!userId) return { ok: false }; + deleteOne: authConversationProcedure.mutation( + async ({ input: { id }, ctx: { dbClient, jwt } }) => { await dbClient .deleteFrom("conversations") .where("id", "=", id) - .where("userId", "=", userId) + .where("userId", "=", jwt.id as string) .execute(); return { ok: true }; - }), - updateTitle: publicProcedure + } + ), + updateTitle: authConversationProcedure .input( - (x) => - x as { - id: string; - title: string; - } + z.object({ + title: z.string(), + }) ) .mutation(async ({ input: { id, title }, ctx: { dbClient, jwt } }) => { - const userId = jwt?.id as string | null; - if (!userId) return { ok: false }; await dbClient .updateTable("conversations") .set({ title }) .where("id", "=", id) - .where("userId", "=", userId) + .where("userId", "=", jwt.id as string) .execute(); return { ok: true }; }), - fetchMessages: publicProcedure - .input((x) => x as { conversationId: string }) + fetchMessages: authProcedure + .input(z.object({ conversationId: z.string() })) .query(async ({ input: { conversationId }, ctx: { dbClient, jwt } }) => { - const userId = jwt?.id as string | null; - if (!userId) return []; const rows = await dbClient .selectFrom("messages") .innerJoin( @@ -83,7 +83,7 @@ export const conversations = router({ ) .selectAll("messages") .where("conversationId", "=", conversationId) - .where("conversations.userId", "=", userId) + .where("conversations.userId", "=", jwt.id as string) .execute(); return rows as Array; }), diff --git a/server/trpc/fact-triggers.ts b/server/trpc/fact-triggers.ts index 8dd0b8a..7a596c0 100644 --- a/server/trpc/fact-triggers.ts +++ b/server/trpc/fact-triggers.ts @@ -1,8 +1,15 @@ -import { router, publicProcedure, createCallerFactory } from "./server.js"; +import { + router, + publicProcedure, + createCallerFactory, + authProcedure, +} from "./server.js"; import type { DraftMessage } from "../../types.js"; import { MODEL_NAME } from "../provider.js"; import { generateObject, generateText, jsonSchema } from "ai"; import type { Fact } from "@database/common.js"; +import { TRPCError } from "@trpc/server"; +import { z } from "zod"; const factTriggersSystemPrompt = ({ previousRunningSummary, @@ -53,52 +60,71 @@ ${factContent} Generate a list of situations in which the fact is useful.`; +const authFactTriggerProcedure = authProcedure + .input(z.object({ factTriggerId: z.string() })) + .use(async ({ input, ctx: { dbClient, jwt }, next }) => { + const factTriggerRows = await dbClient + .selectFrom("fact_triggers") + .innerJoin("facts", "facts.id", "fact_triggers.sourceFactId") + .innerJoin("messages", "messages.id", "facts.sourceMessageId") + .innerJoin("conversations", "conversations.id", "messages.conversationId") + .where("fact_triggers.id", "=", input.factTriggerId) + .where("conversations.userId", "=", jwt.id as string) + .execute(); + if (!factTriggerRows.length) { + throw new TRPCError({ code: "UNAUTHORIZED" }); + } + return await next(); + }); + export const factTriggers = router({ - fetchByFactId: publicProcedure + fetchByFactId: authProcedure .input((x) => x as { factId: string }) - .query(async ({ input: { factId }, ctx: { dbClient } }) => { + .query(async ({ input: { factId }, ctx: { dbClient, jwt } }) => { const rows = await dbClient .selectFrom("fact_triggers") .innerJoin("facts", "facts.id", "fact_triggers.sourceFactId") + .innerJoin("messages", "messages.id", "facts.sourceMessageId") + .innerJoin( + "conversations", + "conversations.id", + "messages.conversationId" + ) .selectAll("fact_triggers") .where("sourceFactId", "=", factId) + .where("conversations.userId", "=", jwt.id as string) .execute(); return rows; }), - fetchByConversationId: publicProcedure + fetchByConversationId: authProcedure .input((x) => x as { conversationId: string }) - .query(async ({ input: { conversationId }, ctx: { dbClient } }) => { + .query(async ({ input: { conversationId }, ctx: { dbClient, jwt } }) => { const rows = await dbClient .selectFrom("fact_triggers") .innerJoin("facts", "facts.id", "fact_triggers.sourceFactId") .innerJoin("messages", "messages.id", "facts.sourceMessageId") + .innerJoin( + "conversations", + "conversations.id", + "messages.conversationId" + ) .selectAll("fact_triggers") .where("messages.conversationId", "=", conversationId) + .where("conversations.userId", "=", jwt.id as string) .execute(); return rows; }), - deleteOne: publicProcedure - .input( - (x) => - x as { - factTriggerId: string; - } - ) - .mutation(async ({ input: { factTriggerId }, ctx: { dbClient } }) => { + deleteOne: authFactTriggerProcedure.mutation( + async ({ input: { factTriggerId }, ctx: { dbClient, jwt } }) => { await dbClient .deleteFrom("fact_triggers") .where("id", "=", factTriggerId) .execute(); return { ok: true }; - }), - update: publicProcedure - .input( - (x) => - x as { - factTriggerId: string; - content: string; - } - ) + } + ), + update: authFactTriggerProcedure + .input(z.object({ content: z.string() })) .mutation( async ({ input: { factTriggerId, content }, ctx: { dbClient } }) => { await dbClient diff --git a/server/trpc/facts.ts b/server/trpc/facts.ts index 8fa9283..0bd16fc 100644 --- a/server/trpc/facts.ts +++ b/server/trpc/facts.ts @@ -1,7 +1,9 @@ -import { router, publicProcedure, createCallerFactory } from "./server.js"; +import { router, createCallerFactory, authProcedure } from "./server.js"; import type { DraftMessage } from "../../types.js"; import { MODEL_NAME, openrouter } from "../provider.js"; import { generateObject, generateText, jsonSchema } from "ai"; +import { TRPCError } from "@trpc/server"; +import { z } from "zod"; const factsFromNewMessagesSystemPrompt = ({ previousRunningSummary, @@ -48,9 +50,36 @@ const factsFromNewMessagesUserPrompt = ({ Extract new facts from these messages.`; +const authFactProcedure = authProcedure + .input(z.object({ factId: z.string() })) + .use(async ({ input, ctx: { dbClient, jwt }, next }) => { + const factRows = await dbClient + .selectFrom("facts") + .innerJoin("messages", "messages.id", "facts.sourceMessageId") + .innerJoin("conversations", "conversations.id", "messages.conversationId") + .where("facts.id", "=", input.factId) + .where("conversations.userId", "=", jwt.id as string) + .execute(); + if (!factRows.length) { + throw new TRPCError({ code: "UNAUTHORIZED" }); + } + return await next(); + }); + export const facts = router({ - fetchByConversationId: publicProcedure + fetchByConversationId: authProcedure .input((x) => x as { conversationId: string }) + .use(async ({ input, ctx: { dbClient, jwt }, next }) => { + const conversationRows = await dbClient + .selectFrom("conversations") + .where("id", "=", input.conversationId) + .where("userId", "=", jwt.id as string) + .execute(); + if (!conversationRows.length) { + throw new TRPCError({ code: "UNAUTHORIZED" }); + } + return await next(); + }) .query(async ({ input: { conversationId }, ctx: { dbClient } }) => { const rows = await dbClient .selectFrom("facts") @@ -60,25 +89,14 @@ export const facts = router({ .execute(); return rows; }), - deleteOne: publicProcedure - .input( - (x) => - x as { - factId: string; - } - ) - .mutation(async ({ input: { factId }, ctx: { dbClient } }) => { + deleteOne: authFactProcedure.mutation( + async ({ input: { factId }, ctx: { dbClient } }) => { await dbClient.deleteFrom("facts").where("id", "=", factId).execute(); return { ok: true }; - }), - update: publicProcedure - .input( - (x) => - x as { - factId: string; - content: string; - } - ) + } + ), + update: authFactProcedure + .input(z.object({ content: z.string() })) .mutation(async ({ input: { factId, content }, ctx: { dbClient } }) => { await dbClient .updateTable("facts") @@ -87,7 +105,7 @@ export const facts = router({ .execute(); return { ok: true }; }), - extractFromNewMessages: publicProcedure + extractFromNewMessages: authProcedure .input( (x) => x as { diff --git a/server/trpc/messages.ts b/server/trpc/messages.ts index b89ece7..3425c4b 100644 --- a/server/trpc/messages.ts +++ b/server/trpc/messages.ts @@ -1,7 +1,14 @@ -import { router, publicProcedure, createCallerFactory } from "./server"; +import { + router, + publicProcedure, + createCallerFactory, + authProcedure, +} from "./server"; import { MODEL_NAME } from "../provider.js"; import { generateObject, generateText, jsonSchema } from "ai"; import type { CommittedMessage, DraftMessage } from "../../types.js"; +import { TRPCError } from "@trpc/server"; +import { z } from "zod"; const runningSummarySystemPrompt = ({ previousRunningSummary, @@ -43,9 +50,35 @@ ${mainResponseContent} Generate a new running summary of the conversation.`; +const authMessageProcedure = authProcedure + .input(z.object({ id: z.string() })) + .use(async ({ input, ctx: { dbClient, jwt }, next }) => { + const messageRows = await dbClient + .selectFrom("messages") + .innerJoin("conversations", "conversations.id", "messages.conversationId") + .where("messages.id", "=", input.id) + .where("conversations.userId", "=", jwt.id as string) + .execute(); + if (!messageRows.length) { + throw new TRPCError({ code: "UNAUTHORIZED" }); + } + return await next(); + }); + export const messages = router({ - fetchByConversationId: publicProcedure + fetchByConversationId: authProcedure .input((x) => x as { conversationId: string }) + .use(async ({ input, ctx: { dbClient, jwt }, next }) => { + const conversationRows = await dbClient + .selectFrom("conversations") + .where("id", "=", input.conversationId) + .where("userId", "=", jwt.id as string) + .execute(); + if (!conversationRows.length) { + throw new TRPCError({ code: "UNAUTHORIZED" }); + } + return await next(); + }) .query(async ({ input: { conversationId }, ctx: { dbClient } }) => { const rows = (await dbClient .selectFrom("messages") @@ -54,13 +87,13 @@ export const messages = router({ .execute()) as Array; return rows; }), - deleteOne: publicProcedure - .input((x) => x as { id: string }) - .mutation(async ({ input: { id }, ctx: { dbClient } }) => { + deleteOne: authMessageProcedure.mutation( + async ({ input: { id }, ctx: { dbClient } }) => { await dbClient.deleteFrom("messages").where("id", "=", id).execute(); return { success: true }; - }), - generateRunningSummary: publicProcedure + } + ), + generateRunningSummary: authProcedure .input( (x) => x as { diff --git a/server/trpc/server.ts b/server/trpc/server.ts index e7933f4..de2782b 100644 --- a/server/trpc/server.ts +++ b/server/trpc/server.ts @@ -36,6 +36,20 @@ const t = initTRPC */ export const router = t.router; export const publicProcedure = t.procedure; +export const authProcedure = publicProcedure.use( + async ({ ctx: { jwt }, next }) => { + if (!jwt) { + throw new TRPCError({ code: "UNAUTHORIZED" }); + } + if (!jwt.id) { + throw new TRPCError({ code: "UNAUTHORIZED" }); + } + jwt.email; + return await next({ + ctx: { jwt }, + }); + } +); /** * Generate a TRPC-compatible validator function given a Typebox schema.