authz/authn on all trpc procedures

master
Avraham Sakal 3 weeks ago
parent fc70806b10
commit 0207e4fc47

@ -12,13 +12,12 @@ export const data = async (pageContext: PageContextServer) => {
openrouter: getOpenrouter( openrouter: getOpenrouter(
(pageContext.env?.OPENROUTER_API_KEY || env.OPENROUTER_API_KEY) as string (pageContext.env?.OPENROUTER_API_KEY || env.OPENROUTER_API_KEY) as string
), ),
// jwt: pageContext., jwt: pageContext.session?.jwt,
dbClient: getDbClient( dbClient: getDbClient(
(pageContext.env?.POSTGRES_CONNECTION_STRING || (pageContext.env?.POSTGRES_CONNECTION_STRING ||
env.POSTGRES_CONNECTION_STRING) as string env.POSTGRES_CONNECTION_STRING) as string
), ),
}); });
const [ const [
conversation, conversation,
// messages, // messages,

@ -15,6 +15,7 @@ import type {
} from "@universal-middleware/core"; } from "@universal-middleware/core";
import { env } from "./env.js"; import { env } from "./env.js";
import { getDbClient } from "../database/index.js"; import { getDbClient } from "../database/index.js";
import { JWT } from "@auth/core/jwt";
const POSTGRES_CONNECTION_STRING = const POSTGRES_CONNECTION_STRING =
"postgres://neondb_owner:npg_sOVmj8vWq2zG@ep-withered-king-adiz9gpi-pooler.c-2.us-east-1.aws.neon.tech:5432/neondb?sslmode=require&channel_binding=true"; "postgres://neondb_owner:npg_sOVmj8vWq2zG@ep-withered-king-adiz9gpi-pooler.c-2.us-east-1.aws.neon.tech:5432/neondb?sslmode=require&channel_binding=true";
@ -125,6 +126,7 @@ const authjsConfig = {
...session.user, ...session.user,
id: token.id as string, id: token.id as string,
}, },
jwt: token,
}; };
}, },
}, },

@ -1,33 +1,41 @@
import { TRPCError } from "@trpc/server";
import type { CommittedMessage } from "../../types"; import type { CommittedMessage } from "../../types";
import { router, publicProcedure, createCallerFactory } from "./server"; import { router, createCallerFactory, authProcedure } from "./server";
import { z } from "zod";
export const conversations = router({ const authConversationProcedure = authProcedure
fetchAll: publicProcedure.query(async ({ ctx: { dbClient, jwt } }) => { .input(z.object({ id: z.string() }))
const userId = jwt?.id as string | null; .use(async ({ input: { id }, ctx: { dbClient, jwt }, next }) => {
if (!userId) return [];
const rows = await dbClient const rows = await dbClient
.selectFrom("conversations") .selectFrom("conversations")
.where("userId", "=", userId)
.selectAll() .selectAll()
.where("id", "=", id)
.execute(); .execute();
return rows; if (rows[0].userId !== jwt.id) {
}), throw new TRPCError({ code: "UNAUTHORIZED" });
fetchOne: publicProcedure }
.input((x) => x as { id: string }) return next({
.query(async ({ input: { id }, ctx: { dbClient, jwt } }) => { ctx: {
const userId = jwt?.id as string | null; conversationRow: rows[0],
if (!userId) return null; },
const row = await dbClient });
});
export const conversations = router({
fetchAll: authProcedure.query(async ({ ctx: { dbClient, jwt } }) => {
const rows = await dbClient
.selectFrom("conversations") .selectFrom("conversations")
.where("userId", "=", jwt.id as string)
.selectAll() .selectAll()
.where("id", "=", id)
.where("userId", "=", userId)
.execute(); .execute();
return row[0]; return rows;
}), }),
start: publicProcedure.mutation(async ({ ctx: { dbClient, jwt } }) => { fetchOne: authConversationProcedure.query(
const userId = jwt?.id as string | null; async ({ ctx: { conversationRow } }) => {
if (!userId) return null; return conversationRow;
}
),
start: authProcedure.mutation(async ({ ctx: { dbClient, jwt } }) => {
const insertedRows = await dbClient const insertedRows = await dbClient
.insertInto("conversations") .insertInto("conversations")
.values({ .values({
@ -38,42 +46,34 @@ export const conversations = router({
.execute(); .execute();
return insertedRows[0]; return insertedRows[0];
}), }),
deleteOne: publicProcedure deleteOne: authConversationProcedure.mutation(
.input((x) => x as { id: string }) async ({ input: { id }, ctx: { dbClient, jwt } }) => {
.mutation(async ({ input: { id }, ctx: { dbClient, jwt } }) => {
const userId = jwt?.id as string | null;
if (!userId) return { ok: false };
await dbClient await dbClient
.deleteFrom("conversations") .deleteFrom("conversations")
.where("id", "=", id) .where("id", "=", id)
.where("userId", "=", userId) .where("userId", "=", jwt.id as string)
.execute(); .execute();
return { ok: true }; return { ok: true };
}),
updateTitle: publicProcedure
.input(
(x) =>
x as {
id: string;
title: string;
} }
),
updateTitle: authConversationProcedure
.input(
z.object({
title: z.string(),
})
) )
.mutation(async ({ input: { id, title }, ctx: { dbClient, jwt } }) => { .mutation(async ({ input: { id, title }, ctx: { dbClient, jwt } }) => {
const userId = jwt?.id as string | null;
if (!userId) return { ok: false };
await dbClient await dbClient
.updateTable("conversations") .updateTable("conversations")
.set({ title }) .set({ title })
.where("id", "=", id) .where("id", "=", id)
.where("userId", "=", userId) .where("userId", "=", jwt.id as string)
.execute(); .execute();
return { ok: true }; return { ok: true };
}), }),
fetchMessages: publicProcedure fetchMessages: authProcedure
.input((x) => x as { conversationId: string }) .input(z.object({ conversationId: z.string() }))
.query(async ({ input: { conversationId }, ctx: { dbClient, jwt } }) => { .query(async ({ input: { conversationId }, ctx: { dbClient, jwt } }) => {
const userId = jwt?.id as string | null;
if (!userId) return [];
const rows = await dbClient const rows = await dbClient
.selectFrom("messages") .selectFrom("messages")
.innerJoin( .innerJoin(
@ -83,7 +83,7 @@ export const conversations = router({
) )
.selectAll("messages") .selectAll("messages")
.where("conversationId", "=", conversationId) .where("conversationId", "=", conversationId)
.where("conversations.userId", "=", userId) .where("conversations.userId", "=", jwt.id as string)
.execute(); .execute();
return rows as Array<CommittedMessage>; return rows as Array<CommittedMessage>;
}), }),

@ -1,8 +1,15 @@
import { router, publicProcedure, createCallerFactory } from "./server.js"; import {
router,
publicProcedure,
createCallerFactory,
authProcedure,
} from "./server.js";
import type { DraftMessage } from "../../types.js"; import type { DraftMessage } from "../../types.js";
import { MODEL_NAME } from "../provider.js"; import { MODEL_NAME } from "../provider.js";
import { generateObject, generateText, jsonSchema } from "ai"; import { generateObject, generateText, jsonSchema } from "ai";
import type { Fact } from "@database/common.js"; import type { Fact } from "@database/common.js";
import { TRPCError } from "@trpc/server";
import { z } from "zod";
const factTriggersSystemPrompt = ({ const factTriggersSystemPrompt = ({
previousRunningSummary, previousRunningSummary,
@ -53,52 +60,71 @@ ${factContent}
Generate a list of situations in which the fact is useful.`; Generate a list of situations in which the fact is useful.`;
const authFactTriggerProcedure = authProcedure
.input(z.object({ factTriggerId: z.string() }))
.use(async ({ input, ctx: { dbClient, jwt }, next }) => {
const factTriggerRows = await dbClient
.selectFrom("fact_triggers")
.innerJoin("facts", "facts.id", "fact_triggers.sourceFactId")
.innerJoin("messages", "messages.id", "facts.sourceMessageId")
.innerJoin("conversations", "conversations.id", "messages.conversationId")
.where("fact_triggers.id", "=", input.factTriggerId)
.where("conversations.userId", "=", jwt.id as string)
.execute();
if (!factTriggerRows.length) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
return await next();
});
export const factTriggers = router({ export const factTriggers = router({
fetchByFactId: publicProcedure fetchByFactId: authProcedure
.input((x) => x as { factId: string }) .input((x) => x as { factId: string })
.query(async ({ input: { factId }, ctx: { dbClient } }) => { .query(async ({ input: { factId }, ctx: { dbClient, jwt } }) => {
const rows = await dbClient const rows = await dbClient
.selectFrom("fact_triggers") .selectFrom("fact_triggers")
.innerJoin("facts", "facts.id", "fact_triggers.sourceFactId") .innerJoin("facts", "facts.id", "fact_triggers.sourceFactId")
.innerJoin("messages", "messages.id", "facts.sourceMessageId")
.innerJoin(
"conversations",
"conversations.id",
"messages.conversationId"
)
.selectAll("fact_triggers") .selectAll("fact_triggers")
.where("sourceFactId", "=", factId) .where("sourceFactId", "=", factId)
.where("conversations.userId", "=", jwt.id as string)
.execute(); .execute();
return rows; return rows;
}), }),
fetchByConversationId: publicProcedure fetchByConversationId: authProcedure
.input((x) => x as { conversationId: string }) .input((x) => x as { conversationId: string })
.query(async ({ input: { conversationId }, ctx: { dbClient } }) => { .query(async ({ input: { conversationId }, ctx: { dbClient, jwt } }) => {
const rows = await dbClient const rows = await dbClient
.selectFrom("fact_triggers") .selectFrom("fact_triggers")
.innerJoin("facts", "facts.id", "fact_triggers.sourceFactId") .innerJoin("facts", "facts.id", "fact_triggers.sourceFactId")
.innerJoin("messages", "messages.id", "facts.sourceMessageId") .innerJoin("messages", "messages.id", "facts.sourceMessageId")
.innerJoin(
"conversations",
"conversations.id",
"messages.conversationId"
)
.selectAll("fact_triggers") .selectAll("fact_triggers")
.where("messages.conversationId", "=", conversationId) .where("messages.conversationId", "=", conversationId)
.where("conversations.userId", "=", jwt.id as string)
.execute(); .execute();
return rows; return rows;
}), }),
deleteOne: publicProcedure deleteOne: authFactTriggerProcedure.mutation(
.input( async ({ input: { factTriggerId }, ctx: { dbClient, jwt } }) => {
(x) =>
x as {
factTriggerId: string;
}
)
.mutation(async ({ input: { factTriggerId }, ctx: { dbClient } }) => {
await dbClient await dbClient
.deleteFrom("fact_triggers") .deleteFrom("fact_triggers")
.where("id", "=", factTriggerId) .where("id", "=", factTriggerId)
.execute(); .execute();
return { ok: true }; return { ok: true };
}),
update: publicProcedure
.input(
(x) =>
x as {
factTriggerId: string;
content: string;
} }
) ),
update: authFactTriggerProcedure
.input(z.object({ content: z.string() }))
.mutation( .mutation(
async ({ input: { factTriggerId, content }, ctx: { dbClient } }) => { async ({ input: { factTriggerId, content }, ctx: { dbClient } }) => {
await dbClient await dbClient

@ -1,7 +1,9 @@
import { router, publicProcedure, createCallerFactory } from "./server.js"; import { router, createCallerFactory, authProcedure } from "./server.js";
import type { DraftMessage } from "../../types.js"; import type { DraftMessage } from "../../types.js";
import { MODEL_NAME, openrouter } from "../provider.js"; import { MODEL_NAME, openrouter } from "../provider.js";
import { generateObject, generateText, jsonSchema } from "ai"; import { generateObject, generateText, jsonSchema } from "ai";
import { TRPCError } from "@trpc/server";
import { z } from "zod";
const factsFromNewMessagesSystemPrompt = ({ const factsFromNewMessagesSystemPrompt = ({
previousRunningSummary, previousRunningSummary,
@ -48,9 +50,36 @@ const factsFromNewMessagesUserPrompt = ({
Extract new facts from these messages.`; Extract new facts from these messages.`;
const authFactProcedure = authProcedure
.input(z.object({ factId: z.string() }))
.use(async ({ input, ctx: { dbClient, jwt }, next }) => {
const factRows = await dbClient
.selectFrom("facts")
.innerJoin("messages", "messages.id", "facts.sourceMessageId")
.innerJoin("conversations", "conversations.id", "messages.conversationId")
.where("facts.id", "=", input.factId)
.where("conversations.userId", "=", jwt.id as string)
.execute();
if (!factRows.length) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
return await next();
});
export const facts = router({ export const facts = router({
fetchByConversationId: publicProcedure fetchByConversationId: authProcedure
.input((x) => x as { conversationId: string }) .input((x) => x as { conversationId: string })
.use(async ({ input, ctx: { dbClient, jwt }, next }) => {
const conversationRows = await dbClient
.selectFrom("conversations")
.where("id", "=", input.conversationId)
.where("userId", "=", jwt.id as string)
.execute();
if (!conversationRows.length) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
return await next();
})
.query(async ({ input: { conversationId }, ctx: { dbClient } }) => { .query(async ({ input: { conversationId }, ctx: { dbClient } }) => {
const rows = await dbClient const rows = await dbClient
.selectFrom("facts") .selectFrom("facts")
@ -60,25 +89,14 @@ export const facts = router({
.execute(); .execute();
return rows; return rows;
}), }),
deleteOne: publicProcedure deleteOne: authFactProcedure.mutation(
.input( async ({ input: { factId }, ctx: { dbClient } }) => {
(x) =>
x as {
factId: string;
}
)
.mutation(async ({ input: { factId }, ctx: { dbClient } }) => {
await dbClient.deleteFrom("facts").where("id", "=", factId).execute(); await dbClient.deleteFrom("facts").where("id", "=", factId).execute();
return { ok: true }; return { ok: true };
}),
update: publicProcedure
.input(
(x) =>
x as {
factId: string;
content: string;
} }
) ),
update: authFactProcedure
.input(z.object({ content: z.string() }))
.mutation(async ({ input: { factId, content }, ctx: { dbClient } }) => { .mutation(async ({ input: { factId, content }, ctx: { dbClient } }) => {
await dbClient await dbClient
.updateTable("facts") .updateTable("facts")
@ -87,7 +105,7 @@ export const facts = router({
.execute(); .execute();
return { ok: true }; return { ok: true };
}), }),
extractFromNewMessages: publicProcedure extractFromNewMessages: authProcedure
.input( .input(
(x) => (x) =>
x as { x as {

@ -1,7 +1,14 @@
import { router, publicProcedure, createCallerFactory } from "./server"; import {
router,
publicProcedure,
createCallerFactory,
authProcedure,
} from "./server";
import { MODEL_NAME } from "../provider.js"; import { MODEL_NAME } from "../provider.js";
import { generateObject, generateText, jsonSchema } from "ai"; import { generateObject, generateText, jsonSchema } from "ai";
import type { CommittedMessage, DraftMessage } from "../../types.js"; import type { CommittedMessage, DraftMessage } from "../../types.js";
import { TRPCError } from "@trpc/server";
import { z } from "zod";
const runningSummarySystemPrompt = ({ const runningSummarySystemPrompt = ({
previousRunningSummary, previousRunningSummary,
@ -43,9 +50,35 @@ ${mainResponseContent}
Generate a new running summary of the conversation.`; Generate a new running summary of the conversation.`;
const authMessageProcedure = authProcedure
.input(z.object({ id: z.string() }))
.use(async ({ input, ctx: { dbClient, jwt }, next }) => {
const messageRows = await dbClient
.selectFrom("messages")
.innerJoin("conversations", "conversations.id", "messages.conversationId")
.where("messages.id", "=", input.id)
.where("conversations.userId", "=", jwt.id as string)
.execute();
if (!messageRows.length) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
return await next();
});
export const messages = router({ export const messages = router({
fetchByConversationId: publicProcedure fetchByConversationId: authProcedure
.input((x) => x as { conversationId: string }) .input((x) => x as { conversationId: string })
.use(async ({ input, ctx: { dbClient, jwt }, next }) => {
const conversationRows = await dbClient
.selectFrom("conversations")
.where("id", "=", input.conversationId)
.where("userId", "=", jwt.id as string)
.execute();
if (!conversationRows.length) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
return await next();
})
.query(async ({ input: { conversationId }, ctx: { dbClient } }) => { .query(async ({ input: { conversationId }, ctx: { dbClient } }) => {
const rows = (await dbClient const rows = (await dbClient
.selectFrom("messages") .selectFrom("messages")
@ -54,13 +87,13 @@ export const messages = router({
.execute()) as Array<CommittedMessage>; .execute()) as Array<CommittedMessage>;
return rows; return rows;
}), }),
deleteOne: publicProcedure deleteOne: authMessageProcedure.mutation(
.input((x) => x as { id: string }) async ({ input: { id }, ctx: { dbClient } }) => {
.mutation(async ({ input: { id }, ctx: { dbClient } }) => {
await dbClient.deleteFrom("messages").where("id", "=", id).execute(); await dbClient.deleteFrom("messages").where("id", "=", id).execute();
return { success: true }; return { success: true };
}), }
generateRunningSummary: publicProcedure ),
generateRunningSummary: authProcedure
.input( .input(
(x) => (x) =>
x as { x as {

@ -36,6 +36,20 @@ const t = initTRPC
*/ */
export const router = t.router; export const router = t.router;
export const publicProcedure = t.procedure; export const publicProcedure = t.procedure;
export const authProcedure = publicProcedure.use(
async ({ ctx: { jwt }, next }) => {
if (!jwt) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
if (!jwt.id) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
jwt.email;
return await next({
ctx: { jwt },
});
}
);
/** /**
* Generate a TRPC-compatible validator function given a Typebox schema. * Generate a TRPC-compatible validator function given a Typebox schema.

Loading…
Cancel
Save