diff --git a/.changeset/twenty-dolphins-marry.md b/.changeset/twenty-dolphins-marry.md new file mode 100644 index 0000000000000000000000000000000000000000..ce79f4316b1ae1eaf3b83ad807dcabbdd5fe89c5 --- /dev/null +++ b/.changeset/twenty-dolphins-marry.md @@ -0,0 +1,9 @@ +--- +"llamaindex": patch +--- + +feat: update `PGVectorStore` + +- move constructor parameter `config.user` | `config.database` | `config.password` | `config.connectionString` into `config.clientConfig` +- if you pass `pg.Client` or `pg.Pool` instance to `PGVectorStore`, move it to `config.client`, setting `config.shouldConnect` to false if it's already connected +- default value of `PGVectorStore.collection` is now `"data"` instead of `""` (empty string) diff --git a/examples/pg-vector-store/load-docs.ts b/examples/pg-vector-store/load-docs.ts index 32e140041d832b8d4b27608767366d6a0f42b5b8..0c411b3a66f2ddc7ea3f7db58d7bebaf191ac2b1 100755 --- a/examples/pg-vector-store/load-docs.ts +++ b/examples/pg-vector-store/load-docs.ts @@ -40,7 +40,11 @@ async function main(args: any) { const rdr = new SimpleDirectoryReader(callback); const docs = await rdr.loadData({ directoryPath: sourceDir }); - const pgvs = new PGVectorStore(); + const pgvs = new PGVectorStore({ + clientConfig: { + connectionString: process.env.PG_CONNECTION_STRING, + }, + }); pgvs.setCollection(sourceDir); await pgvs.clearCollection(); diff --git a/examples/pg-vector-store/query.ts b/examples/pg-vector-store/query.ts index 96d6ed9bd8c45a8cddc6dec8cb4ab677d7a75654..66cddbe4ff49d94c3fca0079f6a541cb0a08a641 100755 --- a/examples/pg-vector-store/query.ts +++ b/examples/pg-vector-store/query.ts @@ -7,7 +7,11 @@ async function main() { }); try { - const pgvs = new PGVectorStore(); + const pgvs = new PGVectorStore({ + clientConfig: { + connectionString: process.env.PG_CONNECTION_STRING, + }, + }); // Optional - set your collection name, default is no filter on this field. // pgvs.setCollection(); diff --git a/packages/llamaindex/e2e/node/vector-store/pg-vector-store.e2e.ts b/packages/llamaindex/e2e/node/vector-store/pg-vector-store.e2e.ts index 4bb78b37820dee7daf5e18333160c5a740134076..5d05435abc4748247f9f198a2dfb7594ef7913d9 100644 --- a/packages/llamaindex/e2e/node/vector-store/pg-vector-store.e2e.ts +++ b/packages/llamaindex/e2e/node/vector-store/pg-vector-store.e2e.ts @@ -9,43 +9,54 @@ import { registerTypes } from "pgvector/pg"; config({ path: [".env.local", ".env", ".env.ci"] }); -let pgClient: pg.Client | pg.Pool; -test.afterEach(async () => { - await pgClient.end(); -}); - const pgConfig = { user: process.env.POSTGRES_USER ?? "user", password: process.env.POSTGRES_PASSWORD ?? "password", database: "llamaindex_node_test", }; -await test("init with client", async () => { - pgClient = new pg.Client(pgConfig); +await test("init with client", async (t) => { + const pgClient = new pg.Client(pgConfig); await pgClient.connect(); await pgClient.query("CREATE EXTENSION IF NOT EXISTS vector"); await registerTypes(pgClient); - const vectorStore = new PGVectorStore(pgClient); + t.after(async () => { + await pgClient.end(); + }); + const vectorStore = new PGVectorStore({ + client: pgClient, + shouldConnect: false, + }); assert.deepStrictEqual(await vectorStore.client(), pgClient); }); -await test("init with pool", async () => { - pgClient = new pg.Pool(pgConfig); +await test("init with pool", async (t) => { + const pgClient = new pg.Pool(pgConfig); await pgClient.query("CREATE EXTENSION IF NOT EXISTS vector"); const client = await pgClient.connect(); + await client.query("CREATE EXTENSION IF NOT EXISTS vector"); await registerTypes(client); - const vectorStore = new PGVectorStore(client); + t.after(async () => { + client.release(); + await pgClient.end(); + }); + const vectorStore = new PGVectorStore({ + shouldConnect: false, + client, + }); assert.deepStrictEqual(await vectorStore.client(), client); - client.release(); }); -await test("init without client", async () => { - const vectorStore = new PGVectorStore(pgConfig); - pgClient = (await vectorStore.client()) as pg.Client; +await test("init without client", async (t) => { + const vectorStore = new PGVectorStore({ clientConfig: pgConfig }); + const pgClient = (await vectorStore.client()) as pg.Client; + t.after(async () => { + await pgClient.end(); + }); assert.notDeepStrictEqual(pgClient, undefined); }); -await test("simple node", async () => { +await test("simple node", async (t) => { const dimensions = 3; const schemaName = "llamaindex_vector_store_test_" + Math.random().toString(36).substring(7); @@ -56,10 +67,14 @@ await test("simple node", async () => { embedding: [0.1, 0.2, 0.3], }); const vectorStore = new PGVectorStore({ - ...pgConfig, + clientConfig: pgConfig, dimensions, schemaName, }); + const pgClient = (await vectorStore.client()) as pg.Client; + t.after(async () => { + await pgClient.end(); + }); await vectorStore.add([node]); @@ -89,6 +104,4 @@ await test("simple node", async () => { }); assert.deepStrictEqual(result.nodes, []); } - - pgClient = (await vectorStore.client()) as pg.Client; }); diff --git a/packages/llamaindex/src/vector-store/PGVectorStore.ts b/packages/llamaindex/src/vector-store/PGVectorStore.ts index 70eef0112485bf0e2b62ef0cb239b12dabc69c87..d2ac5bfc0bb8e9efb365f96d20a9430ce943fca0 100644 --- a/packages/llamaindex/src/vector-store/PGVectorStore.ts +++ b/packages/llamaindex/src/vector-store/PGVectorStore.ts @@ -14,25 +14,44 @@ import { import { escapeLikeString } from "./utils.js"; import type { BaseEmbedding } from "@llamaindex/core/embeddings"; +import { DEFAULT_COLLECTION } from "@llamaindex/core/global"; import type { BaseNode, Metadata } from "@llamaindex/core/schema"; import { Document, MetadataMode } from "@llamaindex/core/schema"; export const PGVECTOR_SCHEMA = "public"; export const PGVECTOR_TABLE = "llamaindex_embedding"; +export const DEFAULT_DIMENSIONS = 1536; -export type PGVectorStoreConfig = Pick< - pg.ClientConfig, - "user" | "database" | "password" | "connectionString" -> & { +type PGVectorStoreBaseConfig = { schemaName?: string | undefined; tableName?: string | undefined; dimensions?: number | undefined; embedModel?: BaseEmbedding | undefined; }; +export type PGVectorStoreConfig = PGVectorStoreBaseConfig & + ( + | { + /** + * Client configuration options for the pg client. + * + * {@link https://node-postgres.com/apis/client#new-client PostgresSQL Client API} + */ + clientConfig: pg.ClientConfig; + } + | { + /** + * A pg client or pool client instance. + * If provided, make sure it is not connected to the database yet, or it will throw an error. + */ + shouldConnect?: boolean | undefined; + client: pg.Client | pg.PoolClient; + } + ); + /** * Provides support for writing and querying vector data in Postgres. - * Note: Can't be used with data created using the Python version of the vector store (https://docs.llamaindex.ai/en/stable/examples/vector_stores/postgres.html) + * Note: Can't be used with data created using the Python version of the vector store (https://docs.llamaindex.ai/en/stable/examples/vector_stores/postgres/) */ export class PGVectorStore extends VectorStoreBase @@ -40,52 +59,26 @@ export class PGVectorStore { storesText: boolean = true; - private collection: string = ""; - private schemaName: string = PGVECTOR_SCHEMA; - private tableName: string = PGVECTOR_TABLE; - - private user: pg.ClientConfig["user"] | undefined = undefined; - private password: pg.ClientConfig["password"] | undefined = undefined; - private database: pg.ClientConfig["database"] | undefined = undefined; - private connectionString: pg.ClientConfig["connectionString"] | undefined = - undefined; - - private dimensions: number = 1536; - - private db?: pg.ClientBase; - - /** - * Constructs a new instance of the PGVectorStore - * - * If the `connectionString` is not provided the following env variables are - * used to connect to the DB: - * PGHOST=your database host - * PGUSER=your database user - * PGPASSWORD=your database password - * PGDATABASE=your database name - * PGPORT=your database port - */ - constructor(configOrClient?: PGVectorStoreConfig | pg.ClientBase) { - // We cannot import pg from top level, it might have side effects - // so we only check if the config.connect function exists - if ( - configOrClient && - "connect" in configOrClient && - typeof configOrClient.connect === "function" - ) { - const db = configOrClient as pg.ClientBase; - super(); - this.db = db; + private collection: string = DEFAULT_COLLECTION; + private readonly schemaName: string = PGVECTOR_SCHEMA; + private readonly tableName: string = PGVECTOR_TABLE; + private readonly dimensions: number = DEFAULT_DIMENSIONS; + + private isDBConnected: boolean = false; + private db: pg.ClientBase | null = null; + private readonly clientConfig: pg.ClientConfig | null = null; + + constructor(config: PGVectorStoreConfig) { + super(config?.embedModel); + this.schemaName = config?.schemaName ?? PGVECTOR_SCHEMA; + this.tableName = config?.tableName ?? PGVECTOR_TABLE; + this.dimensions = config?.dimensions ?? DEFAULT_DIMENSIONS; + if ("clientConfig" in config) { + this.clientConfig = config.clientConfig; } else { - const config = configOrClient as PGVectorStoreConfig; - super(config?.embedModel); - this.schemaName = config?.schemaName ?? PGVECTOR_SCHEMA; - this.tableName = config?.tableName ?? PGVECTOR_TABLE; - this.user = config?.user; - this.password = config?.password; - this.database = config?.database; - this.connectionString = config?.connectionString; - this.dimensions = config?.dimensions ?? 1536; + this.isDBConnected = + config.shouldConnect !== undefined ? !config.shouldConnect : false; + this.db = config.client; } } @@ -113,39 +106,41 @@ export class PGVectorStore private async getDb(): Promise<pg.ClientBase> { if (!this.db) { - try { - const pg = await import("pg"); - const { Client } = pg.default ? pg.default : pg; - - const { registerType } = await import("pgvector/pg"); - // Create DB connection - // Read connection params from env - see comment block above - const db = new Client({ - user: this.user, - password: this.password, - database: this.database, - connectionString: this.connectionString, - }); - await db.connect(); - - // Check vector extension - await db.query("CREATE EXTENSION IF NOT EXISTS vector"); - await registerType(db); - - // All good? Keep the connection reference - this.db = db; - } catch (err) { - console.error(err); - return Promise.reject(err instanceof Error ? err : new Error(`${err}`)); - } + const pg = await import("pg"); + const { Client } = pg.default ? pg.default : pg; + + const { registerTypes } = await import("pgvector/pg"); + // Create DB connection + // Read connection params from env - see comment block above + const db = new Client({ + ...this.clientConfig, + }); + + await db.connect(); + this.isDBConnected = true; + + // Check vector extension + await db.query("CREATE EXTENSION IF NOT EXISTS vector"); + await registerTypes(db); + + // All good? Keep the connection reference + this.db = db; } - const db = this.db; + if (this.db && !this.isDBConnected) { + await this.db.connect(); + this.isDBConnected = true; + } + + this.db.on("end", () => { + // Connection closed + this.isDBConnected = false; + }); // Check schema, table(s), index(es) - await this.checkSchema(db); + await this.checkSchema(this.db); - return Promise.resolve(this.db); + return this.db; } private async checkSchema(db: pg.ClientBase) {