diff --git a/examples/cosmosdb/queryVectorData.ts b/examples/cosmosdb/queryVectorData.ts index 29f875b5167db165a98968e0fbd585edd006efe9..13e9604831369c567af7781d9ccc49001bc66f63 100644 --- a/examples/cosmosdb/queryVectorData.ts +++ b/examples/cosmosdb/queryVectorData.ts @@ -83,14 +83,6 @@ async function query() { }); } - // configure the Azure CosmosDB NoSQL Vector Store - const dbConfig: AzureCosmosDBNoSQLConfig = { - client: cosmosClient, - databaseName, - containerName, - flatMetadata: false, - }; - // use Azure CosmosDB as a vectorStore, docStore, and indexStore const { vectorStore, docStore, indexStore } = await initializeStores(); diff --git a/packages/llamaindex/src/vector-store/AzureCosmosDBNoSqlVectorStore.ts b/packages/llamaindex/src/vector-store/AzureCosmosDBNoSqlVectorStore.ts index 8aef9b45988611943aa966b5b476b19c5bfc869e..29d01de50469b2d6ff79da44a1f27f5b76ba54a2 100644 --- a/packages/llamaindex/src/vector-store/AzureCosmosDBNoSqlVectorStore.ts +++ b/packages/llamaindex/src/vector-store/AzureCosmosDBNoSqlVectorStore.ts @@ -55,6 +55,16 @@ export interface AzureCosmosDBNoSQLConfig readonly flatMetadata?: boolean; readonly idKey?: string; } +/** + * Query options for the `AzureCosmosDBNoSQLVectorStore.query` method. + * @property includeEmbeddings - Whether to include the embeddings in the result. Default false + * @property includeVectorDistance - Whether to include the vector distance in the result. Default true + * @property whereClause - The where clause to use in the query. While writing this clause, use `c` as the alias for the container and do not include the `WHERE` keyword. + */ +export interface AzureCosmosQueryOptions { + includeVectorDistance?: boolean; + whereClause?: string; +} const USER_AGENT_SUFFIX = "llamaindex-cdbnosql-vectorstore-javascript"; @@ -98,6 +108,22 @@ function parseConnectionString(connectionString: string): { return { endpoint, key: accountKey }; } +/** + * utility function to build the query string for the CosmosDB query + */ +function queryBuilder(options: AzureCosmosQueryOptions): string { + let initialQuery = + "SELECT TOP @k c[@id] as id, c[@text] as text, c[@metadata] as metadata"; + if (options.includeVectorDistance !== false) { + initialQuery += `, VectorDistance(c[@embeddingKey],@embedding) AS SimilarityScore`; + } + initialQuery += ` FROM c`; + if (options.whereClause) { + initialQuery += ` WHERE ${options.whereClause}`; + } + initialQuery += ` ORDER BY VectorDistance(c[@embeddingKey],@embedding)`; + return initialQuery; +} export class AzureCosmosDBNoSqlVectorStore extends BaseVectorStore { storesText: boolean = true; @@ -334,21 +360,25 @@ export class AzureCosmosDBNoSqlVectorStore extends BaseVectorStore { */ async query( query: VectorStoreQuery, - options?: object, + options: AzureCosmosQueryOptions = {}, ): Promise<VectorStoreQueryResult> { await this.initialize(); + if (!query.queryEmbedding || query.queryEmbedding.length === 0) { + throw new Error( + "queryEmbedding is required for AzureCosmosDBNoSqlVectorStore query", + ); + } const params = { vector: query.queryEmbedding!, k: query.similarityTopK, }; - + const builtQuery = queryBuilder(options); const nodes: BaseNode[] = []; const ids: string[] = []; const similarities: number[] = []; const queryResults = await this.container.items .query({ - query: - "SELECT TOP @k c[@id] as id, c[@text] as text, c[@metadata] as metadata, VectorDistance(c[@embeddingKey],@embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c[@embeddingKey],@embedding)", + query: builtQuery, parameters: [ { name: "@k", value: params.k }, { name: "@id", value: this.idKey }, diff --git a/packages/llamaindex/tests/vector-stores/AzureCosmosDBNoSqlVectorStore.int.test.ts b/packages/llamaindex/tests/vector-stores/AzureCosmosDBNoSqlVectorStore.int.test.ts index c93433781a2b6d3b5461409614458ceffb6c2d1e..a60cf1c09c91d2641f69a6659225ab3e235a0037 100644 --- a/packages/llamaindex/tests/vector-stores/AzureCosmosDBNoSqlVectorStore.int.test.ts +++ b/packages/llamaindex/tests/vector-stores/AzureCosmosDBNoSqlVectorStore.int.test.ts @@ -14,9 +14,10 @@ import { Settings, VectorStoreQueryMode, type AzureCosmosDBNoSQLConfig, + type AzureCosmosQueryOptions, type VectorStoreQueryResult, } from "llamaindex"; -import { beforeEach, describe, expect, it } from "vitest"; +import { beforeAll, describe, expect, it } from "vitest"; dotenv.config(); /* * To run this test, you need have an Azure Cosmos DB for NoSQL instance @@ -64,7 +65,10 @@ Settings.llm = new OpenAI(llmInit); Settings.embedModel = new OpenAIEmbedding(embedModelInit); // This test is skipped because it requires an Azure Cosmos DB instance and OpenAI API keys describe.skip("AzureCosmosDBNoSQLVectorStore", () => { - beforeEach(async () => { + let vectorStore: AzureCosmosDBNoSqlVectorStore; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + let embeddings: any = []; + beforeAll(async () => { if (process.env.AZURE_COSMOSDB_NOSQL_CONNECTION_STRING) { client = new CosmosClient( process.env.AZURE_COSMOSDB_NOSQL_CONNECTION_STRING, @@ -79,15 +83,12 @@ describe.skip("AzureCosmosDBNoSQLVectorStore", () => { "Please set the environment variable AZURE_COSMOSDB_NOSQL_CONNECTION_STRING or AZURE_COSMOSDB_NOSQL_ENDPOINT", ); } - // Make sure the database does not exists try { await client.database(DATABASE_NAME).delete(); } catch { // Ignore error if the database does not exist } - }); - it("perform query", async () => { const config: AzureCosmosDBNoSQLConfig = { idKey: "name", textKey: "customText", @@ -134,9 +135,9 @@ describe.skip("AzureCosmosDBNoSQLVectorStore", () => { }, }; - const vectorStore = new AzureCosmosDBNoSqlVectorStore(config); + vectorStore = new AzureCosmosDBNoSqlVectorStore(config); - const embeddings = await Settings.embedModel.getTextEmbeddings([ + embeddings = await Settings.embedModel.getTextEmbeddings([ "This book is about politics", "Cats sleeps a lot.", "Sandwiches taste good.", @@ -150,28 +151,29 @@ describe.skip("AzureCosmosDBNoSQLVectorStore", () => { id_: "1", text: "This book is about politics", embedding: embeddings[0], - metadata: { key: "politics" }, + metadata: { key: "politics", number: 1 }, }), new Document({ id_: "2", text: "Cats sleeps a lot.", embedding: embeddings[1], - metadata: { key: "cats" }, + metadata: { key: "cats", number: 2 }, }), new Document({ id_: "3", text: "Sandwiches taste good.", embedding: embeddings[2], - metadata: { key: "sandwiches" }, + metadata: { key: "sandwiches", number: 3 }, }), new Document({ id_: "4", text: "The house is open", embedding: embeddings[3], - metadata: { key: "house" }, + metadata: { key: "house", number: 4 }, }), ]); - + }); + it("perform query", async () => { const results: VectorStoreQueryResult = await vectorStore.query({ queryEmbedding: embeddings[4] || [], similarityTopK: 1, @@ -179,5 +181,62 @@ describe.skip("AzureCosmosDBNoSQLVectorStore", () => { }); expect(results.ids.length).toEqual(1); expect(results.ids[0]).toEqual("3"); + expect(results.similarities).toBeDefined(); + expect(results.similarities[0]).toBeDefined(); + }, 1000000); + + it("perform query with where clause", async () => { + const options: AzureCosmosQueryOptions = { + whereClause: "c.customMetadata.number > 3", + }; + const results: VectorStoreQueryResult = await vectorStore.query( + { + queryEmbedding: embeddings[4] || [], + similarityTopK: 1, + mode: VectorStoreQueryMode.DEFAULT, + }, + options, + ); + expect(results.ids.length).toEqual(1); + expect(results.ids[0]).toEqual("4"); + expect(results.similarities).toBeDefined(); + expect(results.similarities[0]).toBeDefined(); + }, 1000000); + + it("perform query with includeVectorDistance false", async () => { + const options: AzureCosmosQueryOptions = { + includeVectorDistance: false, + }; + const results: VectorStoreQueryResult = await vectorStore.query( + { + queryEmbedding: embeddings[4] || [], + similarityTopK: 1, + mode: VectorStoreQueryMode.DEFAULT, + }, + options, + ); + expect(results.ids.length).toEqual(1); + expect(results.ids[0]).toEqual("3"); + expect(results.similarities).toBeDefined(); + expect(results.similarities[0]).toBeUndefined(); + }, 1000000); + + it("perform query with includeVectorDistance false and whereClause", async () => { + const options: AzureCosmosQueryOptions = { + includeVectorDistance: false, + whereClause: "c.customMetadata.number > 3", + }; + const results: VectorStoreQueryResult = await vectorStore.query( + { + queryEmbedding: embeddings[4] || [], + similarityTopK: 1, + mode: VectorStoreQueryMode.DEFAULT, + }, + options, + ); + expect(results.ids.length).toEqual(1); + expect(results.ids[0]).toEqual("4"); + expect(results.similarities).toBeDefined(); + expect(results.similarities[0]).toBeUndefined(); }, 1000000); }); diff --git a/packages/llamaindex/tests/vector-stores/AzureCosmosDBNoSqlVectorStore.test.ts b/packages/llamaindex/tests/vector-stores/AzureCosmosDBNoSqlVectorStore.test.ts index 88a90802e5537dd623a766dc6a947f149035072c..20ba57705d89538b13d81091a6721f60fe4ce81e 100644 --- a/packages/llamaindex/tests/vector-stores/AzureCosmosDBNoSqlVectorStore.test.ts +++ b/packages/llamaindex/tests/vector-stores/AzureCosmosDBNoSqlVectorStore.test.ts @@ -1,6 +1,7 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import type { BaseNode } from "@llamaindex/core/schema"; import { beforeEach, describe, expect, it, vi } from "vitest"; +import { VectorStoreQueryMode } from "../../src/vector-store.js"; import { TestableAzureCosmosDBNoSqlVectorStore } from "../mocks/TestableAzureCosmosDBNoSqlVectorStore.js"; import { createMockClient } from "../utility/mockCosmosClient.js"; // Import the mock client @@ -95,4 +96,27 @@ describe("AzureCosmosDBNoSqlVectorStore Tests", () => { expect(client.databases.containers.items.create).toHaveBeenCalledTimes(2); expect(result).toEqual(["node-0", "node-1"]); }); + + it("should throw error if no query embedding is provided", async () => { + const client = createMockClient(); + const store = new TestableAzureCosmosDBNoSqlVectorStore({ + client: client as any, + endpoint: "https://example.com", + idKey: "id", + textKey: "text", + metadataKey: "metadata", + }); + + expect(store).toBeDefined(); + + await expect( + store.query({ + queryEmbedding: [], + similarityTopK: 4, + mode: VectorStoreQueryMode.DEFAULT, + }), + ).rejects.toThrowError( + "queryEmbedding is required for AzureCosmosDBNoSqlVectorStore query", + ); + }); });