From 630b4255454b38e2431906556cd0457e75150cdf Mon Sep 17 00:00:00 2001
From: Aman Rao <54672230+amanrao23@users.noreply.github.com>
Date: Fri, 15 Nov 2024 11:00:55 +0530
Subject: [PATCH] feat: add Azure CosmosDB NoSql Chat store (#1490)

Co-authored-by: Alex Yang <himself65@outlook.com>
---
 .changeset/swift-dodos-develop.md             |   5 +
 .../chatStore/AzureCosmosNoSqlChatStore.ts    | 383 ++++++++++++++++++
 packages/llamaindex/src/storage/index.ts      |   1 +
 3 files changed, 389 insertions(+)
 create mode 100644 .changeset/swift-dodos-develop.md
 create mode 100644 packages/llamaindex/src/storage/chatStore/AzureCosmosNoSqlChatStore.ts

diff --git a/.changeset/swift-dodos-develop.md b/.changeset/swift-dodos-develop.md
new file mode 100644
index 000000000..28d1767c8
--- /dev/null
+++ b/.changeset/swift-dodos-develop.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+feat: add Azure CosmosDB NoSql Chat store
diff --git a/packages/llamaindex/src/storage/chatStore/AzureCosmosNoSqlChatStore.ts b/packages/llamaindex/src/storage/chatStore/AzureCosmosNoSqlChatStore.ts
new file mode 100644
index 000000000..d6ff7bbae
--- /dev/null
+++ b/packages/llamaindex/src/storage/chatStore/AzureCosmosNoSqlChatStore.ts
@@ -0,0 +1,383 @@
+import { CosmosClient, type Container, type Database } from "@azure/cosmos";
+import { DefaultAzureCredential, type TokenCredential } from "@azure/identity";
+import type {
+  ChatMessage,
+  MessageContent,
+  MessageType,
+} from "@llamaindex/core/llms";
+import { BaseChatStore } from "@llamaindex/core/storage/chat-store";
+import { getEnv } from "@llamaindex/env";
+
+const USER_AGENT_SUFFIX = "llamaindex-cdbnosql-chatstore-javascript";
+const DEFAULT_CHAT_DATABASE = "ChatMessagesDB";
+const DEFAULT_CHAT_CONTAINER = "ChatMessagesContainer";
+const DEFAULT_OFFER_THROUGHPUT = 400;
+
+/**
+ * Extracts the account endpoint and key from a Cosmos DB connection string
+ * made of semicolon-separated `Name=value` pairs.
+ *
+ * @param connectionString - e.g. `AccountEndpoint=<url>;AccountKey=<key>;`
+ * @returns The trimmed `AccountEndpoint` and `AccountKey` values.
+ * @throws Error if either value is missing or empty.
+ */
+function parseConnectionString(connectionString: string): {
+  endpoint: string;
+  key: string;
+} {
+  let endpoint = "";
+  let accountKey = "";
+
+  for (const part of connectionString.split(";")) {
+    // Split on the first "=" only: the value may itself contain "=" (for
+    // example base64 padding at the end of an account key).
+    const separatorIndex = part.indexOf("=");
+    if (separatorIndex === -1) {
+      continue;
+    }
+    const name = part.slice(0, separatorIndex).trim();
+    const value = part.slice(separatorIndex + 1).trim();
+    if (name === "AccountEndpoint") {
+      endpoint = value;
+    } else if (name === "AccountKey") {
+      accountKey = value;
+    }
+  }
+
+  if (!endpoint || !accountKey) {
+    throw new Error(
+      "Invalid connection string: missing AccountEndpoint or AccountKey.",
+    );
+  }
+  return { endpoint, key: accountKey };
+}
+
+export interface AzureCosmosChatDatabaseProperties {
+  throughput?: number;
+}
+
+export interface AzureCosmosChatContainerProperties {
+  // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  [key: string]: any;
+}
+
+export interface AzureCosmosNoSqlChatStoreConfig {
+  cosmosClient?: CosmosClient;
+  dbName?: string;
+  containerName?: string;
+  userId?: string;
+  sessionId?: string;
+  cosmosContainerProperties?: AzureCosmosChatContainerProperties;
+  cosmosDatabaseProperties?: AzureCosmosChatDatabaseProperties;
+  ttlInSeconds?: number;
+}
+
+/**
+ * Chat store backed by an Azure Cosmos DB NoSQL container. Each key is stored
+ * as one item (`id` = key) holding the message list, addressed by the store's
+ * user id as partition key value.
+ */
+export class AzureCosmosNoSqlChatStore<
+  AdditionalMessageOptions extends object = object,
+> extends BaseChatStore<AdditionalMessageOptions> {
+  private userId: string;
+  private ttl: number;
+  private cosmosClient: CosmosClient;
+  private database!: Database;
+  private container!: Container;
+  private initPromise?: Promise<void>;
+
+  private dbName: string;
+  private containerName: string;
+  private cosmosContainerProperties: AzureCosmosChatContainerProperties;
+  private cosmosDatabaseProperties: AzureCosmosChatDatabaseProperties;
+  private initialize: () => Promise<void>;
+
+  /**
+   * @throws Error if `cosmosClient` is not provided.
+   */
+  constructor({
+    cosmosClient,
+    dbName = DEFAULT_CHAT_DATABASE,
+    containerName = DEFAULT_CHAT_CONTAINER,
+    userId,
+    cosmosContainerProperties = { partitionKey: "/userId" },
+    cosmosDatabaseProperties = {},
+    ttlInSeconds = -1,
+  }: AzureCosmosNoSqlChatStoreConfig) {
+    super();
+    if (!cosmosClient) {
+      throw new Error(
+        "CosmosClient is required for AzureCosmosDBNoSQLChatStore initialization",
+      );
+    }
+    this.ttl = ttlInSeconds;
+    // Prefer the `userId` config option; keep the previous lookup in
+    // `cosmosContainerProperties` as a fallback for existing callers.
+    this.userId = userId || cosmosContainerProperties.userId || "anonymous";
+    this.cosmosClient = cosmosClient;
+    this.dbName = dbName;
+    this.containerName = containerName;
+    this.cosmosContainerProperties = cosmosContainerProperties;
+    this.cosmosDatabaseProperties = cosmosDatabaseProperties;
+
+    // Share one in-flight/successful init() between callers; every public
+    // method awaits this before touching the container.
+    this.initialize = () => {
+      if (this.initPromise === undefined) {
+        this.initPromise = this.init().catch((error) => {
+          console.error(
+            "Error during AzureCosmosDBNoSQLChatStore initialization",
+            error,
+          );
+          // Forget the failed attempt so a later call can retry, and rethrow
+          // so callers see the real cause instead of an undefined container.
+          this.initPromise = undefined;
+          throw error;
+        });
+      }
+      return this.initPromise;
+    };
+  }
+
+  client(): CosmosClient {
+    return this.cosmosClient;
+  }
+
+  // Asynchronous initialization method to create database and container
+  private async init(): Promise<void> {
+    // Set default throughput if not provided
+    const throughput =
+      this.cosmosDatabaseProperties?.throughput || DEFAULT_OFFER_THROUGHPUT;
+
+    // Create the database if it doesn't exist
+    const { database } = await this.cosmosClient.databases.createIfNotExists({
+      id: this.dbName,
+      throughput,
+    });
+    this.database = database;
+
+    // Create the container if it doesn't exist
+    const { container } = await this.database.containers.createIfNotExists({
+      id: this.containerName,
+      throughput: this.cosmosContainerProperties?.throughput,
+      partitionKey: "/userId",
+      indexingPolicy: this.cosmosContainerProperties?.indexingPolicy,
+      defaultTtl: this.ttl,
+      uniqueKeyPolicy: this.cosmosContainerProperties?.uniqueKeyPolicy,
+      conflictResolutionPolicy:
+        this.cosmosContainerProperties?.conflictResolutionPolicy,
+      computedProperties: this.cosmosContainerProperties?.computedProperties,
+    });
+    this.container = container;
+  }
+
+  /**
+   * Static method for creating an instance using a connection string.
+   * If no connection string is provided, it will attempt to use the env variable `AZURE_COSMOSDB_NOSQL_CONNECTION_STRING` as connection string.
+   * @returns Instance of AzureCosmosNoSqlChatStore
+   */
+  static fromConnectionString(
+    config: {
+      connectionString?: string;
+    } & AzureCosmosNoSqlChatStoreConfig = {},
+  ): AzureCosmosNoSqlChatStore {
+    const cosmosConnectionString =
+      config.connectionString ||
+      (getEnv("AZURE_COSMOSDB_NOSQL_CONNECTION_STRING") as string);
+    if (!cosmosConnectionString) {
+      throw new Error("Azure CosmosDB connection string must be provided");
+    }
+    const { endpoint, key } = parseConnectionString(cosmosConnectionString);
+    const cosmosClient = new CosmosClient({
+      endpoint,
+      key,
+      userAgentSuffix: USER_AGENT_SUFFIX,
+    });
+    return new AzureCosmosNoSqlChatStore({
+      ...config,
+      cosmosClient,
+    });
+  }
+
+  /**
+   * Static method for creating an instance using an account endpoint and key.
+   * If no endpoint and key are provided, it will attempt to use the env variable `AZURE_COSMOSDB_NOSQL_ACCOUNT_ENDPOINT` as endpoint and `AZURE_COSMOSDB_NOSQL_ACCOUNT_KEY` as key.
+   * @returns Instance of AzureCosmosNoSqlChatStore
+   */
+  static fromAccountAndKey(
+    config: {
+      endpoint?: string;
+      key?: string;
+    } & AzureCosmosNoSqlChatStoreConfig = {},
+  ): AzureCosmosNoSqlChatStore {
+    const cosmosEndpoint =
+      config.endpoint ||
+      (getEnv("AZURE_COSMOSDB_NOSQL_ACCOUNT_ENDPOINT") as string);
+    const cosmosKey =
+      config.key || (getEnv("AZURE_COSMOSDB_NOSQL_ACCOUNT_KEY") as string);
+
+    if (!cosmosEndpoint || !cosmosKey) {
+      throw new Error(
+        "Azure CosmosDB account endpoint and key must be provided",
+      );
+    }
+    const cosmosClient = new CosmosClient({
+      endpoint: cosmosEndpoint,
+      key: cosmosKey,
+      userAgentSuffix: USER_AGENT_SUFFIX,
+    });
+    return new AzureCosmosNoSqlChatStore({
+      ...config,
+      cosmosClient,
+    });
+  }
+
+  /**
+   * Static method for creating an instance using AAD token.
+   * If no endpoint and credentials are provided, it will attempt to use the env variable `AZURE_COSMOSDB_NOSQL_ACCOUNT_ENDPOINT` as endpoint and use DefaultAzureCredential() as credentials.
+   * @returns Instance of AzureCosmosNoSqlChatStore
+   */
+  static fromAadToken(
+    config: {
+      endpoint?: string;
+      credential?: TokenCredential;
+    } & AzureCosmosNoSqlChatStoreConfig = {},
+  ): AzureCosmosNoSqlChatStore {
+    const cosmosEndpoint =
+      config.endpoint ||
+      (getEnv("AZURE_COSMOSDB_NOSQL_ACCOUNT_ENDPOINT") as string);
+
+    if (!cosmosEndpoint) {
+      throw new Error("Azure CosmosDB account endpoint must be provided");
+    }
+    const credentials = config.credential ?? new DefaultAzureCredential();
+    const cosmosClient = new CosmosClient({
+      endpoint: cosmosEndpoint,
+      aadCredentials: credentials,
+      userAgentSuffix: USER_AGENT_SUFFIX,
+    });
+    return new AzureCosmosNoSqlChatStore({
+      ...config,
+      cosmosClient,
+    });
+  }
+
+  private convertToChatMessage(
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    message: any,
+  ): ChatMessage<AdditionalMessageOptions> {
+    return {
+      content: message.content as MessageContent,
+      role: message.role as MessageType,
+      options: message.options as AdditionalMessageOptions,
+    } as ChatMessage<AdditionalMessageOptions>;
+  }
+
+  private convertToCosmosMessage(
+    message: ChatMessage<AdditionalMessageOptions>,
+  ): // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  any {
+    return {
+      content: message.content,
+      role: message.role,
+      options: message.options,
+    };
+  }
+
+  /**
+   * Set messages for a given key.
+   */
+  async setMessages(
+    key: string,
+    messages: ChatMessage<AdditionalMessageOptions>[],
+  ): Promise<void> {
+    await this.initialize();
+    const inputMessages = messages.map(this.convertToCosmosMessage);
+    await this.container.items.upsert({
+      id: key,
+      messages: inputMessages,
+      userId: this.userId,
+    });
+  }
+
+  /**
+   * Get messages for a given key.
+   */
+  async getMessages(
+    key: string,
+  ): Promise<ChatMessage<AdditionalMessageOptions>[]> {
+    await this.initialize();
+    const res = await this.container.item(key, this.userId).read();
+    const messageHistory = res?.resource?.messages ?? [];
+    const result = messageHistory.map(this.convertToChatMessage);
+    return result;
+  }
+
+  /**
+   * Add a message for a given key.
+   */
+  async addMessage(
+    key: string,
+    message: ChatMessage<AdditionalMessageOptions>,
+    idx?: number,
+  ): Promise<void> {
+    await this.initialize();
+    const res = await this.container.item(key, this.userId).read();
+    const messageHistory = res?.resource?.messages ?? [];
+    if (idx === undefined) {
+      messageHistory.push(this.convertToCosmosMessage(message));
+    } else {
+      messageHistory.splice(idx, 0, this.convertToCosmosMessage(message));
+    }
+    await this.setMessages(key, messageHistory);
+  }
+
+  /**
+   * Deletes all messages for a given key.
+   * A key that does not exist is ignored; any other failure is rethrown.
+   */
+  async deleteMessages(key: string): Promise<void> {
+    await this.initialize();
+    try {
+      await this.container.item(key, this.userId).delete();
+    } catch (e) {
+      // A missing item is reported with status code 404; deleting a key that
+      // is already gone is not an error for callers.
+      if ((e as { code?: number | string })?.code !== 404) {
+        throw e;
+      }
+    }
+  }
+
+  /**
+   * Deletes one message at idx index for a given key.
+   */
+  async deleteMessage(key: string, idx: number): Promise<void> {
+    await this.initialize();
+    const res = await this.container.item(key, this.userId).read();
+    const messageHistory = res?.resource?.messages ?? [];
+    if (idx >= 0 && idx < messageHistory.length) {
+      messageHistory.splice(idx, 1);
+      await this.setMessages(key, messageHistory);
+    }
+  }
+
+  /**
+   * Get all keys.
+   */
+  async getKeys(): Promise<IterableIterator<string>> {
+    await this.initialize();
+    const result = await this.container.items
+      .query("Select c.id from c")
+      .fetchAll();
+    const keys = result.resources.map((res: { id: string }) => res.id);
+
+    function* keyGenerator(): IterableIterator<string> {
+      for (const key of keys) {
+        yield key;
+      }
+    }
+    return keyGenerator();
+  }
+}
diff --git a/packages/llamaindex/src/storage/index.ts b/packages/llamaindex/src/storage/index.ts
index 546eaefab..bb08b4980 100644
--- a/packages/llamaindex/src/storage/index.ts
+++ b/packages/llamaindex/src/storage/index.ts
@@ -2,6 +2,7 @@ export * from "@llamaindex/core/storage/chat-store";
 export * from "@llamaindex/core/storage/doc-store";
 export * from "@llamaindex/core/storage/index-store";
 export * from "@llamaindex/core/storage/kv-store";
+export * from "./chatStore/AzureCosmosNoSqlChatStore.js";
 export * from "./docStore/AzureCosmosNoSqlDocumentStore.js";
 export { PostgresDocumentStore } from "./docStore/PostgresDocumentStore.js";
 export { SimpleDocumentStore } from "./docStore/SimpleDocumentStore.js";
-- 
GitLab