Skip to content
Snippets Groups Projects
Unverified Commit 16f00681 authored by Aman Rao's avatar Aman Rao Committed by GitHub
Browse files

chore: update azure cosmos db no sql vector store (#1503)

parent 1054c338
No related branches found
No related tags found
No related merge requests found
...@@ -83,14 +83,6 @@ async function query() { ...@@ -83,14 +83,6 @@ async function query() {
}); });
} }
// configure the Azure CosmosDB NoSQL Vector Store
const dbConfig: AzureCosmosDBNoSQLConfig = {
client: cosmosClient,
databaseName,
containerName,
flatMetadata: false,
};
// use Azure CosmosDB as a vectorStore, docStore, and indexStore // use Azure CosmosDB as a vectorStore, docStore, and indexStore
const { vectorStore, docStore, indexStore } = await initializeStores(); const { vectorStore, docStore, indexStore } = await initializeStores();
......
...@@ -55,6 +55,16 @@ export interface AzureCosmosDBNoSQLConfig ...@@ -55,6 +55,16 @@ export interface AzureCosmosDBNoSQLConfig
readonly flatMetadata?: boolean; readonly flatMetadata?: boolean;
readonly idKey?: string; readonly idKey?: string;
} }
/**
* Query options for the `AzureCosmosDBNoSQLVectorStore.query` method.
* @property includeEmbeddings - Whether to include the embeddings in the result. Default false
* @property includeVectorDistance - Whether to include the vector distance in the result. Default true
* @property whereClause - The where clause to use in the query. While writing this clause, use `c` as the alias for the container and do not include the `WHERE` keyword.
*/
export interface AzureCosmosQueryOptions {
includeVectorDistance?: boolean;
whereClause?: string;
}
const USER_AGENT_SUFFIX = "llamaindex-cdbnosql-vectorstore-javascript"; const USER_AGENT_SUFFIX = "llamaindex-cdbnosql-vectorstore-javascript";
...@@ -98,6 +108,22 @@ function parseConnectionString(connectionString: string): { ...@@ -98,6 +108,22 @@ function parseConnectionString(connectionString: string): {
return { endpoint, key: accountKey }; return { endpoint, key: accountKey };
} }
/**
* utility function to build the query string for the CosmosDB query
*/
function queryBuilder(options: AzureCosmosQueryOptions): string {
let initialQuery =
"SELECT TOP @k c[@id] as id, c[@text] as text, c[@metadata] as metadata";
if (options.includeVectorDistance !== false) {
initialQuery += `, VectorDistance(c[@embeddingKey],@embedding) AS SimilarityScore`;
}
initialQuery += ` FROM c`;
if (options.whereClause) {
initialQuery += ` WHERE ${options.whereClause}`;
}
initialQuery += ` ORDER BY VectorDistance(c[@embeddingKey],@embedding)`;
return initialQuery;
}
export class AzureCosmosDBNoSqlVectorStore extends BaseVectorStore { export class AzureCosmosDBNoSqlVectorStore extends BaseVectorStore {
storesText: boolean = true; storesText: boolean = true;
...@@ -334,21 +360,25 @@ export class AzureCosmosDBNoSqlVectorStore extends BaseVectorStore { ...@@ -334,21 +360,25 @@ export class AzureCosmosDBNoSqlVectorStore extends BaseVectorStore {
*/ */
async query( async query(
query: VectorStoreQuery, query: VectorStoreQuery,
options?: object, options: AzureCosmosQueryOptions = {},
): Promise<VectorStoreQueryResult> { ): Promise<VectorStoreQueryResult> {
await this.initialize(); await this.initialize();
if (!query.queryEmbedding || query.queryEmbedding.length === 0) {
throw new Error(
"queryEmbedding is required for AzureCosmosDBNoSqlVectorStore query",
);
}
const params = { const params = {
vector: query.queryEmbedding!, vector: query.queryEmbedding!,
k: query.similarityTopK, k: query.similarityTopK,
}; };
const builtQuery = queryBuilder(options);
const nodes: BaseNode[] = []; const nodes: BaseNode[] = [];
const ids: string[] = []; const ids: string[] = [];
const similarities: number[] = []; const similarities: number[] = [];
const queryResults = await this.container.items const queryResults = await this.container.items
.query({ .query({
query: query: builtQuery,
"SELECT TOP @k c[@id] as id, c[@text] as text, c[@metadata] as metadata, VectorDistance(c[@embeddingKey],@embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c[@embeddingKey],@embedding)",
parameters: [ parameters: [
{ name: "@k", value: params.k }, { name: "@k", value: params.k },
{ name: "@id", value: this.idKey }, { name: "@id", value: this.idKey },
......
...@@ -14,9 +14,10 @@ import { ...@@ -14,9 +14,10 @@ import {
Settings, Settings,
VectorStoreQueryMode, VectorStoreQueryMode,
type AzureCosmosDBNoSQLConfig, type AzureCosmosDBNoSQLConfig,
type AzureCosmosQueryOptions,
type VectorStoreQueryResult, type VectorStoreQueryResult,
} from "llamaindex"; } from "llamaindex";
import { beforeEach, describe, expect, it } from "vitest"; import { beforeAll, describe, expect, it } from "vitest";
dotenv.config(); dotenv.config();
/* /*
* To run this test, you need have an Azure Cosmos DB for NoSQL instance * To run this test, you need have an Azure Cosmos DB for NoSQL instance
...@@ -64,7 +65,10 @@ Settings.llm = new OpenAI(llmInit); ...@@ -64,7 +65,10 @@ Settings.llm = new OpenAI(llmInit);
Settings.embedModel = new OpenAIEmbedding(embedModelInit); Settings.embedModel = new OpenAIEmbedding(embedModelInit);
// This test is skipped because it requires an Azure Cosmos DB instance and OpenAI API keys // This test is skipped because it requires an Azure Cosmos DB instance and OpenAI API keys
describe.skip("AzureCosmosDBNoSQLVectorStore", () => { describe.skip("AzureCosmosDBNoSQLVectorStore", () => {
beforeEach(async () => { let vectorStore: AzureCosmosDBNoSqlVectorStore;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let embeddings: any = [];
beforeAll(async () => {
if (process.env.AZURE_COSMOSDB_NOSQL_CONNECTION_STRING) { if (process.env.AZURE_COSMOSDB_NOSQL_CONNECTION_STRING) {
client = new CosmosClient( client = new CosmosClient(
process.env.AZURE_COSMOSDB_NOSQL_CONNECTION_STRING, process.env.AZURE_COSMOSDB_NOSQL_CONNECTION_STRING,
...@@ -79,15 +83,12 @@ describe.skip("AzureCosmosDBNoSQLVectorStore", () => { ...@@ -79,15 +83,12 @@ describe.skip("AzureCosmosDBNoSQLVectorStore", () => {
"Please set the environment variable AZURE_COSMOSDB_NOSQL_CONNECTION_STRING or AZURE_COSMOSDB_NOSQL_ENDPOINT", "Please set the environment variable AZURE_COSMOSDB_NOSQL_CONNECTION_STRING or AZURE_COSMOSDB_NOSQL_ENDPOINT",
); );
} }
// Make sure the database does not exists // Make sure the database does not exists
try { try {
await client.database(DATABASE_NAME).delete(); await client.database(DATABASE_NAME).delete();
} catch { } catch {
// Ignore error if the database does not exist // Ignore error if the database does not exist
} }
});
it("perform query", async () => {
const config: AzureCosmosDBNoSQLConfig = { const config: AzureCosmosDBNoSQLConfig = {
idKey: "name", idKey: "name",
textKey: "customText", textKey: "customText",
...@@ -134,9 +135,9 @@ describe.skip("AzureCosmosDBNoSQLVectorStore", () => { ...@@ -134,9 +135,9 @@ describe.skip("AzureCosmosDBNoSQLVectorStore", () => {
}, },
}; };
const vectorStore = new AzureCosmosDBNoSqlVectorStore(config); vectorStore = new AzureCosmosDBNoSqlVectorStore(config);
const embeddings = await Settings.embedModel.getTextEmbeddings([ embeddings = await Settings.embedModel.getTextEmbeddings([
"This book is about politics", "This book is about politics",
"Cats sleeps a lot.", "Cats sleeps a lot.",
"Sandwiches taste good.", "Sandwiches taste good.",
...@@ -150,28 +151,29 @@ describe.skip("AzureCosmosDBNoSQLVectorStore", () => { ...@@ -150,28 +151,29 @@ describe.skip("AzureCosmosDBNoSQLVectorStore", () => {
id_: "1", id_: "1",
text: "This book is about politics", text: "This book is about politics",
embedding: embeddings[0], embedding: embeddings[0],
metadata: { key: "politics" }, metadata: { key: "politics", number: 1 },
}), }),
new Document({ new Document({
id_: "2", id_: "2",
text: "Cats sleeps a lot.", text: "Cats sleeps a lot.",
embedding: embeddings[1], embedding: embeddings[1],
metadata: { key: "cats" }, metadata: { key: "cats", number: 2 },
}), }),
new Document({ new Document({
id_: "3", id_: "3",
text: "Sandwiches taste good.", text: "Sandwiches taste good.",
embedding: embeddings[2], embedding: embeddings[2],
metadata: { key: "sandwiches" }, metadata: { key: "sandwiches", number: 3 },
}), }),
new Document({ new Document({
id_: "4", id_: "4",
text: "The house is open", text: "The house is open",
embedding: embeddings[3], embedding: embeddings[3],
metadata: { key: "house" }, metadata: { key: "house", number: 4 },
}), }),
]); ]);
});
it("perform query", async () => {
const results: VectorStoreQueryResult = await vectorStore.query({ const results: VectorStoreQueryResult = await vectorStore.query({
queryEmbedding: embeddings[4] || [], queryEmbedding: embeddings[4] || [],
similarityTopK: 1, similarityTopK: 1,
...@@ -179,5 +181,62 @@ describe.skip("AzureCosmosDBNoSQLVectorStore", () => { ...@@ -179,5 +181,62 @@ describe.skip("AzureCosmosDBNoSQLVectorStore", () => {
}); });
expect(results.ids.length).toEqual(1); expect(results.ids.length).toEqual(1);
expect(results.ids[0]).toEqual("3"); expect(results.ids[0]).toEqual("3");
expect(results.similarities).toBeDefined();
expect(results.similarities[0]).toBeDefined();
}, 1000000);
it("perform query with where clause", async () => {
const options: AzureCosmosQueryOptions = {
whereClause: "c.customMetadata.number > 3",
};
const results: VectorStoreQueryResult = await vectorStore.query(
{
queryEmbedding: embeddings[4] || [],
similarityTopK: 1,
mode: VectorStoreQueryMode.DEFAULT,
},
options,
);
expect(results.ids.length).toEqual(1);
expect(results.ids[0]).toEqual("4");
expect(results.similarities).toBeDefined();
expect(results.similarities[0]).toBeDefined();
}, 1000000);
it("perform query with includeVectorDistance false", async () => {
const options: AzureCosmosQueryOptions = {
includeVectorDistance: false,
};
const results: VectorStoreQueryResult = await vectorStore.query(
{
queryEmbedding: embeddings[4] || [],
similarityTopK: 1,
mode: VectorStoreQueryMode.DEFAULT,
},
options,
);
expect(results.ids.length).toEqual(1);
expect(results.ids[0]).toEqual("3");
expect(results.similarities).toBeDefined();
expect(results.similarities[0]).toBeUndefined();
}, 1000000);
it("perform query with includeVectorDistance false and whereClause", async () => {
const options: AzureCosmosQueryOptions = {
includeVectorDistance: false,
whereClause: "c.customMetadata.number > 3",
};
const results: VectorStoreQueryResult = await vectorStore.query(
{
queryEmbedding: embeddings[4] || [],
similarityTopK: 1,
mode: VectorStoreQueryMode.DEFAULT,
},
options,
);
expect(results.ids.length).toEqual(1);
expect(results.ids[0]).toEqual("4");
expect(results.similarities).toBeDefined();
expect(results.similarities[0]).toBeUndefined();
}, 1000000); }, 1000000);
}); });
/* eslint-disable @typescript-eslint/no-explicit-any */ /* eslint-disable @typescript-eslint/no-explicit-any */
import type { BaseNode } from "@llamaindex/core/schema"; import type { BaseNode } from "@llamaindex/core/schema";
import { beforeEach, describe, expect, it, vi } from "vitest"; import { beforeEach, describe, expect, it, vi } from "vitest";
import { VectorStoreQueryMode } from "../../src/vector-store.js";
import { TestableAzureCosmosDBNoSqlVectorStore } from "../mocks/TestableAzureCosmosDBNoSqlVectorStore.js"; import { TestableAzureCosmosDBNoSqlVectorStore } from "../mocks/TestableAzureCosmosDBNoSqlVectorStore.js";
import { createMockClient } from "../utility/mockCosmosClient.js"; // Import the mock client import { createMockClient } from "../utility/mockCosmosClient.js"; // Import the mock client
...@@ -95,4 +96,27 @@ describe("AzureCosmosDBNoSqlVectorStore Tests", () => { ...@@ -95,4 +96,27 @@ describe("AzureCosmosDBNoSqlVectorStore Tests", () => {
expect(client.databases.containers.items.create).toHaveBeenCalledTimes(2); expect(client.databases.containers.items.create).toHaveBeenCalledTimes(2);
expect(result).toEqual(["node-0", "node-1"]); expect(result).toEqual(["node-0", "node-1"]);
}); });
it("should throw error if no query embedding is provided", async () => {
const client = createMockClient();
const store = new TestableAzureCosmosDBNoSqlVectorStore({
client: client as any,
endpoint: "https://example.com",
idKey: "id",
textKey: "text",
metadataKey: "metadata",
});
expect(store).toBeDefined();
await expect(
store.query({
queryEmbedding: [],
similarityTopK: 4,
mode: VectorStoreQueryMode.DEFAULT,
}),
).rejects.toThrowError(
"queryEmbedding is required for AzureCosmosDBNoSqlVectorStore query",
);
});
}); });
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment