drop `db` interface, use kysely instead for more more flexibility (at the cost of more coupling)

master
Avraham Sakal 3 weeks ago
parent 21931a20cb
commit 360bfc6df3

@ -1,2 +1,2 @@
// export { db } from "./lowdb";
export { getDb, getDbClient } from "./postgres";
export { getDbClient } from "./postgres";

@ -36,292 +36,3 @@ export function getDbClient(POSTGRES_CONNECTION_STRING: string) {
return dbClient;
}
export function getDb(POSTGRES_CONNECTION_STRING: string) {
const dbClient = getDbClient(POSTGRES_CONNECTION_STRING);
const conversations: ConversationEntity = {
construct: (conversation) => conversation,
create: async (conversation) => {
const insertedRows = await dbClient
.insertInto("conversations")
.values(conversation)
.returningAll()
.execute();
return insertedRows[0];
},
createMany: async (conversations) => {
const insertedRows = await dbClient
.insertInto("conversations")
.values(conversations)
.returningAll()
.execute();
return insertedRows;
},
findAll: async (user) => {
const userId = user?.userId;
let query = await dbClient.selectFrom("conversations");
if (userId) {
query = query.where("userId", "=", userId);
}
const rows = query.selectAll().execute();
return rows;
},
findById: async (id) => {
const row = await dbClient
.selectFrom("conversations")
.selectAll()
.where("id", "=", id)
.execute();
return row[0];
},
update: async (id, data) => {
await dbClient
.updateTable("conversations")
.set(data)
.where("id", "=", id)
.execute();
},
delete: async (id) => {
await dbClient.deleteFrom("conversations").where("id", "=", id).execute();
},
fetchMessages: async (conversationId) => {
const rows = await dbClient
.selectFrom("messages")
.selectAll()
.where("conversationId", "=", conversationId)
.execute();
return rows as Array<CommittedMessage>;
},
};
const facts: FactEntity = {
construct: (fact) => fact,
create: async (fact) => {
const insertedRows = await dbClient
.insertInto("facts")
.values(fact)
.returningAll()
.execute();
return insertedRows[0];
},
createMany: async (facts) => {
const insertedRows = await dbClient
.insertInto("facts")
.values(facts)
.returningAll()
.execute();
return insertedRows;
},
findAll: async () => {
const rows = await dbClient.selectFrom("facts").selectAll().execute();
return rows;
},
findById: async (id) => {
const row = await dbClient
.selectFrom("facts")
.selectAll()
.where("id", "=", id)
.execute();
return row[0];
},
update: async (id, data) => {
await dbClient
.updateTable("facts")
.set(data)
.where("id", "=", id)
.execute();
},
delete: async (id) => {
await dbClient.deleteFrom("facts").where("id", "=", id).execute();
},
findByConversationId: async (conversationId) => {
const rows = await dbClient
.selectFrom("facts")
.innerJoin("messages", "messages.id", "facts.sourceMessageId")
.selectAll("facts")
.where("conversationId", "=", conversationId)
.execute();
return rows;
},
};
const factTriggers: FactTriggerEntity = {
construct: (factTrigger) => factTrigger,
create: async (factTrigger) => {
const insertedRows = await dbClient
.insertInto("fact_triggers")
.values(factTrigger)
.returningAll()
.execute();
return insertedRows[0];
},
createMany: async (factTriggers) => {
const insertedRows = await dbClient
.insertInto("fact_triggers")
.values(factTriggers)
.returningAll()
.execute();
return insertedRows;
},
findAll: async () => {
const rows = await dbClient
.selectFrom("fact_triggers")
.selectAll()
.execute();
return rows;
},
findById: async (id) => {
const row = await dbClient
.selectFrom("fact_triggers")
.selectAll()
.where("id", "=", id)
.execute();
return row[0];
},
update: async (id, data) => {
await dbClient
.updateTable("fact_triggers")
.set(data)
.where("id", "=", id)
.execute();
},
delete: async (id) => {
await dbClient.deleteFrom("fact_triggers").where("id", "=", id).execute();
},
findByFactId: async (factId) => {
const rows = await dbClient
.selectFrom("fact_triggers")
.innerJoin("facts", "facts.id", "fact_triggers.sourceFactId")
.selectAll("fact_triggers")
.where("sourceFactId", "=", factId)
.execute();
return rows;
},
findByConversationId: async (conversationId) => {
const rows = await dbClient
.selectFrom("fact_triggers")
.innerJoin("facts", "facts.id", "fact_triggers.sourceFactId")
.innerJoin("messages", "messages.id", "facts.sourceMessageId")
.selectAll("fact_triggers")
.where("messages.conversationId", "=", conversationId)
.execute();
return rows;
},
};
const messages: MessageEntity = {
construct: (message) => message,
create: async (message) => {
const insertedRows = await dbClient
.insertInto("messages")
.values({ ...message, parts: JSON.stringify(message.parts) })
.returningAll()
.execute();
return insertedRows[0] as CommittedMessage;
},
createMany: async (messages) => {
const insertedRows = await dbClient
.insertInto("messages")
.values(
messages.map((message) => ({
...message,
parts: JSON.stringify(message.parts),
}))
)
.returningAll()
.execute();
return insertedRows as Array<CommittedMessage>;
},
findAll: async () => {
const rows = await dbClient.selectFrom("messages").selectAll().execute();
return rows as Array<CommittedMessage>;
},
findById: async (id) => {
const row = await dbClient
.selectFrom("messages")
.selectAll()
.where("id", "=", id)
.execute();
return row[0] as CommittedMessage;
},
update: async (id, data) => {
await dbClient
.updateTable("messages")
.set(data)
.where("id", "=", id)
.execute();
},
delete: async (id) => {
await dbClient.deleteFrom("messages").where("id", "=", id).execute();
},
findByConversationId: async (conversationId) => {
const rows = (await dbClient
.selectFrom("messages")
.selectAll()
.where("conversationId", "=", conversationId)
.execute()) as Array<CommittedMessage>;
return rows;
},
};
const users: UserEntity = {
construct: (user) => user,
create: async (user) => {
const insertedRows = await dbClient
.insertInto("users")
.values(user)
.returningAll()
.execute();
return insertedRows[0];
},
createMany: async (users) => {
const insertedRows = await dbClient
.insertInto("users")
.values(users)
.returningAll()
.execute();
return insertedRows;
},
findAll: async () => {
const rows = await dbClient.selectFrom("users").selectAll().execute();
return rows;
},
findById: async (id) => {
const row = await dbClient
.selectFrom("users")
.selectAll()
.where("id", "=", id)
.execute();
return row[0];
},
update: async (id, data) => {
await dbClient
.updateTable("users")
.set(data)
.where("id", "=", id)
.execute();
},
delete: async (id) => {
await dbClient.deleteFrom("users").where("id", "=", id).execute();
},
findByEmailAddress: async (emailAddress) => {
const row = await dbClient
.selectFrom("users")
.selectAll()
.where("email", "=", emailAddress)
.executeTakeFirst();
return row;
},
};
const db = {
conversations,
facts,
factTriggers,
messages,
users,
};
return db;
}

@ -1,6 +1,6 @@
import type { PageContextServer } from "vike/types";
import { createCaller } from "../../../server/trpc/chat.js";
import { getDb } from "../../../database/postgres.js";
import { getDbClient } from "../../../database/postgres.js";
import { getOpenrouter } from "../../../server/provider.js";
import { env } from "../../../server/env.js";
@ -9,14 +9,14 @@ export type Data = Awaited<ReturnType<typeof data>>;
export const data = async (pageContext: PageContextServer) => {
const { id } = pageContext.routeParams;
const caller = createCaller({
db: getDb(
(pageContext.env?.POSTGRES_CONNECTION_STRING ||
env.POSTGRES_CONNECTION_STRING) as string
),
openrouter: getOpenrouter(
(pageContext.env?.OPENROUTER_API_KEY || env.OPENROUTER_API_KEY) as string
),
// jwt: pageContext.,
dbClient: getDbClient(
(pageContext.env?.POSTGRES_CONNECTION_STRING ||
env.POSTGRES_CONNECTION_STRING) as string
),
});
const [

@ -14,7 +14,7 @@ import type {
UniversalMiddleware,
} from "@universal-middleware/core";
import { env } from "./env.js";
import { getDb } from "../database/index.js";
import { getDbClient } from "../database/index.js";
const POSTGRES_CONNECTION_STRING =
"postgres://neondb_owner:npg_sOVmj8vWq2zG@ep-withered-king-adiz9gpi-pooler.c-2.us-east-1.aws.neon.tech:5432/neondb?sslmode=require&channel_binding=true";
@ -70,8 +70,12 @@ const authjsConfig = {
callbacks: {
async signIn({ user, account, profile }) {
if (typeof user?.email !== "string") return false;
const db = await getDb(POSTGRES_CONNECTION_STRING);
const userFromDb = await db.users.findByEmailAddress(user.email);
const dbClient = await getDbClient(POSTGRES_CONNECTION_STRING);
const userFromDb = await dbClient
.selectFrom("users")
.selectAll()
.where("email", "=", user.email)
.executeTakeFirst();
if (!userFromDb) {
return false;
}
@ -80,17 +84,34 @@ const authjsConfig = {
},
jwt: async ({ token }) => {
if (typeof token?.email !== "string") return token;
const db = await getDb(POSTGRES_CONNECTION_STRING);
let userFromDb = await db.users.findByEmailAddress(token.email || "");
const dbClient = await getDbClient(POSTGRES_CONNECTION_STRING);
let userFromDb = await dbClient
.selectFrom("users")
.selectAll()
.where("email", "=", token.email || "")
.executeTakeFirst();
if (!userFromDb) {
userFromDb = await db.users.create({
// id: token.id,
userFromDb = (
await dbClient
.insertInto("users")
.values({
email: token.email,
username: token.email,
password: null,
createdAt: null,
lastLogin: null,
});
})
.returningAll()
.execute()
)[0];
// await db.users.create({
// // id: token.id,
// email: token.email,
// username: token.email,
// password: null,
// createdAt: null,
// lastLogin: null,
// });
}
return {
...token,

@ -6,7 +6,7 @@ import {
env as getEnv,
} from "@universal-middleware/core";
import { fetchRequestHandler } from "@trpc/server/adapters/fetch";
import { getDb, getDbClient } from "../database/postgres";
import { getDbClient } from "../database/postgres";
import { getOpenrouter } from "./provider.js";
import { env as processEnv } from "./env.js";
import { getToken } from "@auth/core/jwt";
@ -22,10 +22,6 @@ export const trpcHandler = ((endpoint) => (request, context, runtime) => {
(env.POSTGRES_CONNECTION_STRING ||
processEnv.POSTGRES_CONNECTION_STRING) as string
);
const db = getDb(
(env.POSTGRES_CONNECTION_STRING ||
processEnv.POSTGRES_CONNECTION_STRING) as string
);
const openrouter = getOpenrouter(
(env.OPENROUTER_API_KEY || processEnv.OPENROUTER_API_KEY) as string
);
@ -36,7 +32,6 @@ export const trpcHandler = ((endpoint) => (request, context, runtime) => {
req,
resHeaders,
dbClient,
db,
openrouter,
jwt,
};

@ -69,7 +69,7 @@ export const chat = router({
input: { conversationId, messages, systemPrompt, parameters },
ctx,
}) {
const { db, openrouter, jwt } = ctx;
const { dbClient, openrouter, jwt } = ctx;
const factsCaller = createCallerFacts(ctx);
const messagesCaller = createCallerMessages(ctx);
const factTriggerCaller = createCallerFactTriggers(ctx);
@ -98,14 +98,24 @@ export const chat = router({
} as const;
/** Save the incoming message to the database. */
const insertedUserMessage = await db.messages.create({
const userMessageRowToInsert = {
conversationId,
// content: messages[messages.length - 1].content,
// role: "user" as const,
...messages[messages.length - 1],
index: messages.length - 1,
createdAt: new Date().toISOString(),
});
};
const insertedUserMessageRows = await dbClient
.insertInto("messages")
.values({
...userMessageRowToInsert,
parts: JSON.stringify(userMessageRowToInsert.parts),
})
.returningAll()
.execute();
const insertedUserMessage =
insertedUserMessageRows[0] as CommittedMessage;
// Emit status update
yield {
@ -167,13 +177,17 @@ export const chat = router({
messagesSincePreviousRunningSummary: [],
newMessages: messagesSincePreviousRunningSummary,
});
const insertedFactsFromUserMessage = await db.facts.createMany(
const insertedFactsFromUserMessage = await dbClient
.insertInto("facts")
.values(
factsFromUserMessageResponse.object.facts.map((fact) => ({
userId: jwt?.id as string,
sourceMessageId: insertedUserMessage.id,
content: fact,
}))
);
)
.returningAll()
.execute();
// Emit status update
yield {
@ -191,15 +205,21 @@ export const chat = router({
mainResponseContent: mainResponse.text,
previousRunningSummary,
});
const insertedAssistantMessage = await db.messages.create({
const insertedAssistantMessage = (
await dbClient
.insertInto("messages")
.values({
conversationId,
// content: mainResponse.text,
parts: [{ type: "text", text: mainResponse.text }],
parts: JSON.stringify([{ type: "text", text: mainResponse.text }]),
runningSummary: runningSummaryResponse.text,
role: "assistant" as const,
index: messages.length,
createdAt: new Date().toISOString(),
});
})
.returningAll()
.execute()
)[0];
// Emit status update
yield {
@ -222,14 +242,18 @@ export const chat = router({
],
});
const insertedFactsFromAssistantMessage = await db.facts.createMany(
const insertedFactsFromAssistantMessage = await dbClient
.insertInto("facts")
.values(
factsFromAssistantMessageResponse.object.facts.map((factContent) => ({
userId: jwt?.id as string,
sourceMessageId: insertedAssistantMessage.id,
content: factContent,
createdAt: new Date().toISOString(),
}))
);
)
.returningAll()
.execute();
const insertedFacts = [
...insertedFactsFromUserMessage,
@ -255,6 +279,9 @@ export const chat = router({
fact,
});
const insertedFactTriggers: Array<Omit<FactTrigger, "id">> =
await dbClient
.insertInto("fact_triggers")
.values(
factTriggers.object.factTriggers.map((factTrigger) => ({
sourceFactId: fact.id,
content: factTrigger,
@ -262,8 +289,10 @@ export const chat = router({
priorityMultiplierReason: "",
scopeConversationId: conversationId,
createdAt: new Date().toISOString(),
}));
await db.factTriggers.createMany(insertedFactTriggers);
}))
)
.returningAll()
.execute();
}
// Emit final result

@ -56,13 +56,26 @@ Generate a list of situations in which the fact is useful.`;
export const factTriggers = router({
fetchByFactId: publicProcedure
.input((x) => x as { factId: string })
.query(async ({ input: { factId }, ctx: { db } }) => {
return db.factTriggers.findByFactId(factId);
.query(async ({ input: { factId }, ctx: { dbClient } }) => {
const rows = await dbClient
.selectFrom("fact_triggers")
.innerJoin("facts", "facts.id", "fact_triggers.sourceFactId")
.selectAll("fact_triggers")
.where("sourceFactId", "=", factId)
.execute();
return rows;
}),
fetchByConversationId: publicProcedure
.input((x) => x as { conversationId: string })
.query(async ({ input: { conversationId }, ctx: { db } }) => {
return await db.factTriggers.findByConversationId(conversationId);
.query(async ({ input: { conversationId }, ctx: { dbClient } }) => {
const rows = await dbClient
.selectFrom("fact_triggers")
.innerJoin("facts", "facts.id", "fact_triggers.sourceFactId")
.innerJoin("messages", "messages.id", "facts.sourceMessageId")
.selectAll("fact_triggers")
.where("messages.conversationId", "=", conversationId)
.execute();
return rows;
}),
deleteOne: publicProcedure
.input(
@ -71,8 +84,11 @@ export const factTriggers = router({
factTriggerId: string;
}
)
.mutation(async ({ input: { factTriggerId }, ctx: { db } }) => {
await db.factTriggers.delete(factTriggerId);
.mutation(async ({ input: { factTriggerId }, ctx: { dbClient } }) => {
await dbClient
.deleteFrom("fact_triggers")
.where("id", "=", factTriggerId)
.execute();
return { ok: true };
}),
update: publicProcedure
@ -83,10 +99,16 @@ export const factTriggers = router({
content: string;
}
)
.mutation(async ({ input: { factTriggerId, content }, ctx: { db } }) => {
db.factTriggers.update(factTriggerId, { content });
.mutation(
async ({ input: { factTriggerId, content }, ctx: { dbClient } }) => {
await dbClient
.updateTable("fact_triggers")
.set({ content })
.where("id", "=", factTriggerId)
.execute();
return { ok: true };
}),
}
),
generateFromFact: publicProcedure
.input(
(x) =>

@ -1,7 +1,7 @@
import type { TSchema } from "@sinclair/typebox";
import { TypeCompiler } from "@sinclair/typebox/compiler";
import { initTRPC, TRPCError } from "@trpc/server";
import type { getDb, getDbClient } from "../../database/postgres";
import type { getDbClient } from "../../database/postgres";
import type { getOpenrouter } from "@server/provider.js";
import type { JWT } from "@auth/core/jwt";
@ -12,7 +12,6 @@ import type { JWT } from "@auth/core/jwt";
const t = initTRPC
.context<
object & {
db: ReturnType<typeof getDb>;
dbClient: ReturnType<typeof getDbClient>;
openrouter: ReturnType<typeof getOpenrouter>;
jwt?: JWT | null;

Loading…
Cancel
Save