diff --git a/packages/core/src/Node.ts b/packages/core/src/Node.ts index d60f358e4bc364e52d9a15b8ee27adaac0535e12..157f3b7cdf814698bcd0c6c03164cae8b3820647 100644 --- a/packages/core/src/Node.ts +++ b/packages/core/src/Node.ts @@ -229,13 +229,16 @@ export class TextNode<T extends Metadata = Metadata> extends BaseNode<T> { } } -// export class ImageNode extends TextNode { -// image: string = ""; +export type ImageType = string | Blob | URL; -// getType(): ObjectType { -// return ObjectType.IMAGE; -// } -// } +export class ImageNode<T extends Metadata = Metadata> extends TextNode<T> { + image?: ImageType; // image source: string (originally noted as base64 encoded), Blob, or URL + textEmbedding?: number[]; // optional numeric vector; presumably the embedding of this node's text (TODO confirm) + + static getType(): string { /* NOTE(review): static here, while the commented-out draft removed above declared an instance method; confirm instances still report ObjectType.IMAGE */ + return ObjectType.IMAGE; + } +} export class IndexNode<T extends Metadata = Metadata> extends TextNode<T> { indexId: string = ""; diff --git a/packages/core/src/embeddings/ClipEmbedding.ts b/packages/core/src/embeddings/ClipEmbedding.ts index b75b4b879c502b9f255b719985f3c0392e49f027..989914dc5ed2c957eed1cb461a24e7d675fc97bf 100644 --- a/packages/core/src/embeddings/ClipEmbedding.ts +++ b/packages/core/src/embeddings/ClipEmbedding.ts @@ -1,5 +1,6 @@ +import { ImageType } from "../Node"; import { MultiModalEmbedding } from "./MultiModalEmbedding"; -import { ImageType, readImage } from "./utils"; +import { readImage } from "./utils"; export enum ClipEmbeddingModelType { XENOVA_CLIP_VIT_BASE_PATCH32 = "Xenova/clip-vit-base-patch32", diff --git a/packages/core/src/embeddings/MultiModalEmbedding.ts b/packages/core/src/embeddings/MultiModalEmbedding.ts index c86ba07211ba54b157b043252324270549889502..43bb854a4c92a3af321d223026442bfb9082fd01 100644 --- a/packages/core/src/embeddings/MultiModalEmbedding.ts +++ b/packages/core/src/embeddings/MultiModalEmbedding.ts @@ -1,5 +1,5 @@ +import { ImageType } from "../Node"; import { BaseEmbedding } from "./types"; -import { ImageType } from "./utils"; /* * Base class for Multi Modal embeddings. 
diff --git a/packages/core/src/embeddings/utils.ts b/packages/core/src/embeddings/utils.ts index cd192c3d4faa08a5b4f81c6a796de1e89233ca45..cfdacf0871096efae8c231dcfb29b7a424fedeaf 100644 --- a/packages/core/src/embeddings/utils.ts +++ b/packages/core/src/embeddings/utils.ts @@ -1,4 +1,5 @@ import _ from "lodash"; +import { ImageType } from "../Node"; import { DEFAULT_SIMILARITY_TOP_K } from "../constants"; import { VectorStoreQueryMode } from "../storage"; import { SimilarityType } from "./types"; @@ -183,6 +184,7 @@ export function getTopKMMREmbeddings( return [resultSimilarities, resultIds]; } + export async function readImage(input: ImageType) { const { RawImage } = await import("@xenova/transformers"); if (input instanceof Blob) { @@ -193,4 +195,3 @@ export async function readImage(input: ImageType) { throw new Error(`Unsupported input type: ${typeof input}`); } } -export type ImageType = string | Blob | URL; diff --git a/packages/core/src/indices/multiModal/MultiModalVectorStoreIndex.ts b/packages/core/src/indices/multiModal/MultiModalVectorStoreIndex.ts new file mode 100644 index 0000000000000000000000000000000000000000..0f6a20b4946fcf20bb2cd6dc6904134d20dc6f2d --- /dev/null +++ b/packages/core/src/indices/multiModal/MultiModalVectorStoreIndex.ts @@ -0,0 +1,109 @@ +import _ from "lodash"; +import { BaseNode, ImageNode, MetadataMode, TextNode } from "../../Node"; +import { ClipEmbedding, MultiModalEmbedding } from "../../embeddings"; +import { VectorStore } from "../../storage"; +import { VectorStoreIndex } from "../vectorStore"; +import { VectorIndexConstructorProps } from "../vectorStore/VectorStoreIndex"; + +export interface MultiModalVectorIndexConstructorProps + extends VectorIndexConstructorProps { + imageVectorStore: VectorStore; + imageEmbedModel?: MultiModalEmbedding; +} + +export class MultiModalVectorStoreIndex extends VectorStoreIndex { + imageVectorStore: VectorStore; + imageEmbedModel: MultiModalEmbedding; + + constructor(init: 
MultiModalVectorIndexConstructorProps) { + super(init); + this.imageVectorStore = init.imageVectorStore; + this.imageEmbedModel = init.imageEmbedModel ?? new ClipEmbedding(); + } + + /** + * Get the embeddings for image nodes. + * Uses the text embedding path when every node has a string `text`, otherwise `imageEmbedModel`. + * @param nodes image nodes to embed + * @param logProgress log progress to console (useful for debugging) + * @returns the result of `VectorStoreIndex.getNodeEmbeddingResults` on the text path, otherwise the input nodes with `embedding` set + */ + async getImageNodeEmbeddingResults( + nodes: ImageNode[], + logProgress: boolean = false, + ) { + const isImageToText = nodes.every((node) => _.isString(node.text)); + if (isImageToText) { + // every node has a string `text`: use the text embedding model (NOTE(review): ImageNode extends TextNode; confirm this check can be false, otherwise the image-model path below never runs) + return VectorStoreIndex.getNodeEmbeddingResults( + nodes, + this.serviceContext, + logProgress, + ); + } + + const nodesWithEmbeddings: ImageNode[] = []; + + for (let i = 0; i < nodes.length; ++i) { + const node = nodes[i]; + if (logProgress) { + console.log(`getting embedding for node ${i}/${nodes.length}`); + } + node.embedding = await this.imageEmbedModel.getImageEmbedding( + node.getContent(MetadataMode.EMBED), /* NOTE(review): passes the node content rather than `node.image`; confirm this is the intended image input */ + ); + nodesWithEmbeddings.push(node); + } + + return nodesWithEmbeddings; + } + + /** + * Partition nodes into image nodes and text nodes; nodes that are neither are left out. + */ + private splitNodes(nodes: BaseNode[]): { + imageNodes: ImageNode[]; + textNodes: TextNode[]; + } { + const imageNodes: ImageNode[] = []; + const textNodes: TextNode[] = []; + + for (const node of nodes) { + if (node instanceof ImageNode) { + imageNodes.push(node); + } else if (node instanceof TextNode) { + /* ImageNode extends TextNode, so the `else` keeps image nodes out of textNodes */ + textNodes.push(node); + } + } + return { + imageNodes, + textNodes, + }; + } + + /** + * Insert nodes: text nodes go through the inherited `insertNodes`, image nodes are embedded and added to `imageVectorStore`. + * @param nodes nodes to insert; an empty or missing list is a no-op + */ + async insertNodes(nodes: BaseNode[]): Promise<void> { + if (!nodes || nodes.length === 0) { + return; + } + const { imageNodes, textNodes } = this.splitNodes(nodes); + + /* Both inserts are awaited so failures reach the caller and the two stores are updated one after the other. */ + if (textNodes.length > 0) { + await super.insertNodes(textNodes); + } + + if (imageNodes.length > 0) { + const imageNodesWithEmbedding = + await this.getImageNodeEmbeddingResults(imageNodes); + await super.insertNodesToStore( + this.imageVectorStore, + imageNodesWithEmbedding, + ); + } + } +} diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts 
index fe1c0d95d6ab759d14d3ad55379fd71635039e04..a2c0999106b34cf08f5461efcd59f2fd4914c3e4 100644 --- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts +++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts @@ -39,7 +39,7 @@ export interface VectorIndexConstructorProps extends BaseIndexInit<IndexDict> { export class VectorStoreIndex extends BaseIndex<IndexDict> { vectorStore: VectorStore; - private constructor(init: VectorIndexConstructorProps) { + protected constructor(init: VectorIndexConstructorProps) { /* protected: callable from subclasses, not from outside code */ super(init); this.vectorStore = init.vectorStore; } @@ -259,15 +259,13 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { ); } - async insertNodes(nodes: BaseNode[]): Promise<void> { - const embeddingResults = await VectorStoreIndex.getNodeEmbeddingResults( - nodes, - this.serviceContext, - ); - - const newIds = await this.vectorStore.add(embeddingResults); + async insertNodesToStore( + vectorStore: VectorStore, + nodes: BaseNode[], + ): Promise<void> { /* Adds `nodes` to the given `vectorStore`; when that store does not store text, each node is also recorded in `indexStruct` and `docStore`. */ + const newIds = await vectorStore.add(nodes); - if (!this.vectorStore.storesText) { + if (!vectorStore.storesText) { for (let i = 0; i < nodes.length; ++i) { this.indexStruct.addNode(nodes[i], newIds[i]); this.docStore.addDocuments([nodes[i]], true); @@ -284,6 +282,14 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { await this.storageContext.indexStore.addIndexStruct(this.indexStruct); } + async insertNodes(nodes: BaseNode[]): Promise<void> { /* Embeds `nodes` via `VectorStoreIndex.getNodeEmbeddingResults`, then stores the results in `this.vectorStore`. */ + const embeddingResults = await VectorStoreIndex.getNodeEmbeddingResults( + nodes, + this.serviceContext, + ); + await this.insertNodesToStore(this.vectorStore, embeddingResults); + } + async deleteRefDoc( refDocId: string, deleteFromDocStore: boolean = true,