From 768318647079fca9bec5cd1aef1ee10a007fc1fb Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Tue, 21 Nov 2023 17:10:39 +0700
Subject: [PATCH] feat: added MultiModalVectorStoreIndex

---
 packages/core/src/Node.ts                     | 15 +--
 packages/core/src/embeddings/ClipEmbedding.ts |  3 +-
 .../src/embeddings/MultiModalEmbedding.ts     |  2 +-
 packages/core/src/embeddings/utils.ts         |  3 +-
 .../multiModal/MultiModalVectorStoreIndex.ts  | 94 +++++++++++++++++++
 .../indices/vectorStore/VectorStoreIndex.ts   | 24 +++--
 6 files changed, 123 insertions(+), 18 deletions(-)
 create mode 100644 packages/core/src/indices/multiModal/MultiModalVectorStoreIndex.ts

diff --git a/packages/core/src/Node.ts b/packages/core/src/Node.ts
index d60f358e4..157f3b7cd 100644
--- a/packages/core/src/Node.ts
+++ b/packages/core/src/Node.ts
@@ -229,13 +229,16 @@ export class TextNode<T extends Metadata = Metadata> extends BaseNode<T> {
   }
 }
 
-// export class ImageNode extends TextNode {
-//   image: string = "";
+export type ImageType = string | Blob | URL;
 
-//   getType(): ObjectType {
-//     return ObjectType.IMAGE;
-//   }
-// }
+export class ImageNode<T extends Metadata = Metadata> extends TextNode<T> {
+  image?: ImageType; // image source: a string, Blob or URL (see ImageType)
+  textEmbedding?: number[]; // Assuming text embedding is an array of numbers
+
+  getType(): ObjectType {
+    return ObjectType.IMAGE;
+  }
+}
 
 export class IndexNode<T extends Metadata = Metadata> extends TextNode<T> {
   indexId: string = "";
diff --git a/packages/core/src/embeddings/ClipEmbedding.ts b/packages/core/src/embeddings/ClipEmbedding.ts
index b75b4b879..989914dc5 100644
--- a/packages/core/src/embeddings/ClipEmbedding.ts
+++ b/packages/core/src/embeddings/ClipEmbedding.ts
@@ -1,5 +1,6 @@
+import { ImageType } from "../Node";
 import { MultiModalEmbedding } from "./MultiModalEmbedding";
-import { ImageType, readImage } from "./utils";
+import { readImage } from "./utils";
 
 export enum ClipEmbeddingModelType {
   XENOVA_CLIP_VIT_BASE_PATCH32 = "Xenova/clip-vit-base-patch32",
diff --git a/packages/core/src/embeddings/MultiModalEmbedding.ts b/packages/core/src/embeddings/MultiModalEmbedding.ts
index c86ba0721..43bb854a4 100644
--- a/packages/core/src/embeddings/MultiModalEmbedding.ts
+++ b/packages/core/src/embeddings/MultiModalEmbedding.ts
@@ -1,5 +1,5 @@
+import { ImageType } from "../Node";
 import { BaseEmbedding } from "./types";
-import { ImageType } from "./utils";
 
 /*
  * Base class for Multi Modal embeddings.
diff --git a/packages/core/src/embeddings/utils.ts b/packages/core/src/embeddings/utils.ts
index cd192c3d4..cfdacf087 100644
--- a/packages/core/src/embeddings/utils.ts
+++ b/packages/core/src/embeddings/utils.ts
@@ -1,4 +1,5 @@
 import _ from "lodash";
+import { ImageType } from "../Node";
 import { DEFAULT_SIMILARITY_TOP_K } from "../constants";
 import { VectorStoreQueryMode } from "../storage";
 import { SimilarityType } from "./types";
@@ -183,6 +184,7 @@ export function getTopKMMREmbeddings(
 
   return [resultSimilarities, resultIds];
 }
+
 export async function readImage(input: ImageType) {
   const { RawImage } = await import("@xenova/transformers");
   if (input instanceof Blob) {
@@ -193,4 +195,3 @@ export async function readImage(input: ImageType) {
     throw new Error(`Unsupported input type: ${typeof input}`);
   }
 }
-export type ImageType = string | Blob | URL;
diff --git a/packages/core/src/indices/multiModal/MultiModalVectorStoreIndex.ts b/packages/core/src/indices/multiModal/MultiModalVectorStoreIndex.ts
new file mode 100644
index 000000000..0f6a20b49
--- /dev/null
+++ b/packages/core/src/indices/multiModal/MultiModalVectorStoreIndex.ts
@@ -0,0 +1,94 @@
+import _ from "lodash";
+import { BaseNode, ImageNode, MetadataMode, TextNode } from "../../Node";
+import { ClipEmbedding, MultiModalEmbedding } from "../../embeddings";
+import { VectorStore } from "../../storage";
+import { VectorStoreIndex } from "../vectorStore";
+import { VectorIndexConstructorProps } from "../vectorStore/VectorStoreIndex";
+
+export interface MultiModalVectorIndexConstructorProps
+  extends VectorIndexConstructorProps {
+  imageVectorStore: VectorStore;
+  imageEmbedModel?: MultiModalEmbedding;
+}
+
+export class MultiModalVectorStoreIndex extends VectorStoreIndex {
+  imageVectorStore: VectorStore;
+  imageEmbedModel: MultiModalEmbedding;
+
+  constructor(init: MultiModalVectorIndexConstructorProps) {
+    super(init);
+    this.imageVectorStore = init.imageVectorStore;
+    this.imageEmbedModel = init.imageEmbedModel ?? new ClipEmbedding();
+  }
+
+  /**
+   * Get the embeddings for image nodes. If every node has a string `text`,
+   * the text embedding model is used; otherwise the image embedding model.
+   * @param nodes image nodes to embed
+   * @param logProgress log progress to console (useful for debugging)
+   * @returns
+   */
+  async getImageNodeEmbeddingResults(
+    nodes: ImageNode[],
+    logProgress: boolean = false,
+  ) {
+    const isImageToText = nodes.every((node) => _.isString(node.text));
+    if (isImageToText) {
+      // image nodes have a text, use the text embedding model
+      return VectorStoreIndex.getNodeEmbeddingResults(
+        nodes,
+        this.serviceContext,
+        logProgress,
+      );
+    }
+
+    const nodesWithEmbeddings: ImageNode[] = [];
+
+    for (let i = 0; i < nodes.length; ++i) {
+      const node = nodes[i];
+      if (logProgress) {
+        console.log(`getting embedding for node ${i}/${nodes.length}`);
+      }
+      node.embedding = await this.imageEmbedModel.getImageEmbedding(
+        node.getContent(MetadataMode.EMBED),
+      );
+      nodesWithEmbeddings.push(node);
+    }
+
+    return nodesWithEmbeddings;
+  }
+
+  // Partition nodes into image and text nodes by their runtime class.
+  private splitNodes(nodes: BaseNode[]): {
+    imageNodes: ImageNode[];
+    textNodes: TextNode[];
+  } {
+    const imageNodes: ImageNode[] = [];
+    const textNodes: TextNode[] = [];
+
+    for (const node of nodes) {
+      if (node instanceof ImageNode) {
+        imageNodes.push(node);
+      }
+      // NOTE(review): ImageNode extends TextNode, so image nodes also match
+      // the check below and land in both lists - confirm this is intended.
+      if (node instanceof TextNode) {
+        textNodes.push(node);
+      }
+    }
+    return { imageNodes, textNodes };
+  }
+
+  async insertNodes(nodes: BaseNode[]): Promise<void> {
+    if (!nodes || nodes.length === 0) {
+      return;
+    }
+    const { imageNodes, textNodes } = this.splitNodes(nodes);
+
+    // Await both inserts so failures propagate and index updates are ordered.
+    await super.insertNodes(textNodes);
+
+    const embedded = await this.getImageNodeEmbeddingResults(imageNodes);
+    await super.insertNodesToStore(this.imageVectorStore, embedded);
+  }
+}
diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
index fe1c0d95d..a2c099910 100644
--- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
+++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
@@ -39,7 +39,7 @@ export interface VectorIndexConstructorProps extends BaseIndexInit<IndexDict> {
 export class VectorStoreIndex extends BaseIndex<IndexDict> {
   vectorStore: VectorStore;
 
-  private constructor(init: VectorIndexConstructorProps) {
+  protected constructor(init: VectorIndexConstructorProps) {
     super(init);
     this.vectorStore = init.vectorStore;
   }
@@ -259,15 +259,13 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
     );
   }
 
-  async insertNodes(nodes: BaseNode[]): Promise<void> {
-    const embeddingResults = await VectorStoreIndex.getNodeEmbeddingResults(
-      nodes,
-      this.serviceContext,
-    );
-
-    const newIds = await this.vectorStore.add(embeddingResults);
+  async insertNodesToStore(
+    vectorStore: VectorStore,
+    nodes: BaseNode[],
+  ): Promise<void> {
+    const newIds = await vectorStore.add(nodes);
 
-    if (!this.vectorStore.storesText) {
+    if (!vectorStore.storesText) {
       for (let i = 0; i < nodes.length; ++i) {
         this.indexStruct.addNode(nodes[i], newIds[i]);
         this.docStore.addDocuments([nodes[i]], true);
@@ -284,6 +282,14 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
     await this.storageContext.indexStore.addIndexStruct(this.indexStruct);
   }
 
+  async insertNodes(nodes: BaseNode[]): Promise<void> {
+    const embeddingResults = await VectorStoreIndex.getNodeEmbeddingResults(
+      nodes,
+      this.serviceContext,
+    );
+    await this.insertNodesToStore(this.vectorStore, embeddingResults);
+  }
+
   async deleteRefDoc(
     refDocId: string,
     deleteFromDocStore: boolean = true,
-- 
GitLab