From 768318647079fca9bec5cd1aef1ee10a007fc1fb Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Tue, 21 Nov 2023 17:10:39 +0700
Subject: [PATCH] feat: added MultiModalVectorStoreIndex

---
 packages/core/src/Node.ts                     | 15 +--
 packages/core/src/embeddings/ClipEmbedding.ts |  3 +-
 .../src/embeddings/MultiModalEmbedding.ts     |  2 +-
 packages/core/src/embeddings/utils.ts         |  3 +-
 .../multiModal/MultiModalVectorStoreIndex.ts  | 94 +++++++++++++++++++
 .../indices/vectorStore/VectorStoreIndex.ts   | 24 +++--
 6 files changed, 123 insertions(+), 18 deletions(-)
 create mode 100644 packages/core/src/indices/multiModal/MultiModalVectorStoreIndex.ts

diff --git a/packages/core/src/Node.ts b/packages/core/src/Node.ts
index d60f358e4..157f3b7cd 100644
--- a/packages/core/src/Node.ts
+++ b/packages/core/src/Node.ts
@@ -229,13 +229,16 @@ export class TextNode<T extends Metadata = Metadata> extends BaseNode<T> {
   }
 }
 
-// export class ImageNode extends TextNode {
-//   image: string = "";
+export type ImageType = string | Blob | URL;
 
-//   getType(): ObjectType {
-//     return ObjectType.IMAGE;
-//   }
-// }
+export class ImageNode<T extends Metadata = Metadata> extends TextNode<T> {
+  image?: ImageType; // image source: a string, Blob or URL (see ImageType)
+  textEmbedding?: number[]; // text embedding as an array of numbers
+
+  getType(): ObjectType {
+    return ObjectType.IMAGE;
+  }
+}
 
 export class IndexNode<T extends Metadata = Metadata> extends TextNode<T> {
   indexId: string = "";
diff --git a/packages/core/src/embeddings/ClipEmbedding.ts b/packages/core/src/embeddings/ClipEmbedding.ts
index b75b4b879..989914dc5 100644
--- a/packages/core/src/embeddings/ClipEmbedding.ts
+++ b/packages/core/src/embeddings/ClipEmbedding.ts
@@ -1,5 +1,6 @@
+import { ImageType } from "../Node";
 import { MultiModalEmbedding } from "./MultiModalEmbedding";
-import { ImageType, readImage } from "./utils";
+import { readImage } from "./utils";
 
 export enum ClipEmbeddingModelType {
   XENOVA_CLIP_VIT_BASE_PATCH32 = "Xenova/clip-vit-base-patch32",
diff --git a/packages/core/src/embeddings/MultiModalEmbedding.ts b/packages/core/src/embeddings/MultiModalEmbedding.ts
index c86ba0721..43bb854a4 100644
--- a/packages/core/src/embeddings/MultiModalEmbedding.ts
+++ b/packages/core/src/embeddings/MultiModalEmbedding.ts
@@ -1,5 +1,5 @@
+import { ImageType } from "../Node";
 import { BaseEmbedding } from "./types";
-import { ImageType } from "./utils";
 
 /*
  * Base class for Multi Modal embeddings.
diff --git a/packages/core/src/embeddings/utils.ts b/packages/core/src/embeddings/utils.ts
index cd192c3d4..cfdacf087 100644
--- a/packages/core/src/embeddings/utils.ts
+++ b/packages/core/src/embeddings/utils.ts
@@ -1,4 +1,5 @@
 import _ from "lodash";
+import { ImageType } from "../Node";
 import { DEFAULT_SIMILARITY_TOP_K } from "../constants";
 import { VectorStoreQueryMode } from "../storage";
 import { SimilarityType } from "./types";
@@ -183,6 +184,7 @@ export function getTopKMMREmbeddings(
 
   return [resultSimilarities, resultIds];
 }
+
 export async function readImage(input: ImageType) {
   const { RawImage } = await import("@xenova/transformers");
   if (input instanceof Blob) {
@@ -193,4 +195,3 @@ export async function readImage(input: ImageType) {
     throw new Error(`Unsupported input type: ${typeof input}`);
   }
 }
-export type ImageType = string | Blob | URL;
diff --git a/packages/core/src/indices/multiModal/MultiModalVectorStoreIndex.ts b/packages/core/src/indices/multiModal/MultiModalVectorStoreIndex.ts
new file mode 100644
index 000000000..0f6a20b49
--- /dev/null
+++ b/packages/core/src/indices/multiModal/MultiModalVectorStoreIndex.ts
@@ -0,0 +1,94 @@
+import _ from "lodash";
+import { BaseNode, ImageNode, MetadataMode, TextNode } from "../../Node";
+import { ClipEmbedding, MultiModalEmbedding } from "../../embeddings";
+import { VectorStore } from "../../storage";
+import { VectorStoreIndex } from "../vectorStore";
+import { VectorIndexConstructorProps } from "../vectorStore/VectorStoreIndex";
+
+export interface MultiModalVectorIndexConstructorProps
+  extends VectorIndexConstructorProps {
+  imageVectorStore: VectorStore; // assigned to the index's imageVectorStore
+  imageEmbedModel?: MultiModalEmbedding; // optional; index falls back to ClipEmbedding
+}
+
+export class MultiModalVectorStoreIndex extends VectorStoreIndex {
+  imageVectorStore: VectorStore; // store that receives embedded image nodes
+  imageEmbedModel: MultiModalEmbedding; // defaults to ClipEmbedding
+
+  constructor(init: MultiModalVectorIndexConstructorProps) {
+    super(init);
+    this.imageVectorStore = init.imageVectorStore;
+    this.imageEmbedModel = init.imageEmbedModel ?? new ClipEmbedding();
+  }
+
+  /**
+   * Get the embeddings for image nodes. If every node has a string `text`,
+   * VectorStoreIndex.getNodeEmbeddingResults is used instead.
+   * @param nodes image nodes to embed
+   * @param logProgress log progress to console (useful for debugging)
+   * @returns the nodes with their embeddings
+   */
+  async getImageNodeEmbeddingResults(
+    nodes: ImageNode[],
+    logProgress: boolean = false,
+  ) {
+    const isImageToText = nodes.every((node) => _.isString(node.text));
+    if (isImageToText) {
+      // image nodes have a text, use the text embedding model
+      return VectorStoreIndex.getNodeEmbeddingResults(
+        nodes,
+        this.serviceContext,
+        logProgress,
+      );
+    }
+
+    const nodesWithEmbeddings: ImageNode[] = [];
+    // NOTE(review): embeds node.getContent(...), not node.image — confirm
+    for (let i = 0; i < nodes.length; ++i) {
+      const node = nodes[i];
+      if (logProgress) {
+        console.log(`getting embedding for node ${i}/${nodes.length}`);
+      }
+      node.embedding = await this.imageEmbedModel.getImageEmbedding(
+        node.getContent(MetadataMode.EMBED),
+      );
+      nodesWithEmbeddings.push(node);
+    }
+
+    return nodesWithEmbeddings;
+  }
+
+  /** Split nodes into image nodes and the remaining text nodes. */
+  private splitNodes(nodes: BaseNode[]): {
+    imageNodes: ImageNode[];
+    textNodes: TextNode[];
+  } {
+    const imageNodes: ImageNode[] = [];
+    const textNodes: TextNode[] = [];
+
+    for (const node of nodes) {
+      // ImageNode extends TextNode: test it first, keep branches exclusive
+      if (node instanceof ImageNode) {
+        imageNodes.push(node);
+      } else if (node instanceof TextNode) {
+        textNodes.push(node);
+      }
+    }
+    return { imageNodes, textNodes };
+  }
+
+  /** Embed the nodes and add them to the text or image vector store. */
+  async insertNodes(nodes: BaseNode[]): Promise<void> {
+    if (!nodes || nodes.length === 0) {
+      return;
+    }
+    const { imageNodes, textNodes } = this.splitNodes(nodes);
+
+    // await both inserts so failures reject this promise instead of floating
+    await super.insertNodes(textNodes);
+
+    const embeddedImageNodes =
+      await this.getImageNodeEmbeddingResults(imageNodes);
+    await super.insertNodesToStore(this.imageVectorStore, embeddedImageNodes);
+  }
+}
diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
index fe1c0d95d..a2c099910 100644
--- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
+++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
@@ -39,7 +39,7 @@ export interface VectorIndexConstructorProps extends BaseIndexInit<IndexDict> {
 export class VectorStoreIndex extends BaseIndex<IndexDict> {
   vectorStore: VectorStore;
 
-  private constructor(init: VectorIndexConstructorProps) {
+  protected constructor(init: VectorIndexConstructorProps) {
     super(init);
     this.vectorStore = init.vectorStore;
   }
@@ -259,15 +259,13 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
     );
   }
 
-  async insertNodes(nodes: BaseNode[]): Promise<void> {
-    const embeddingResults = await VectorStoreIndex.getNodeEmbeddingResults(
-      nodes,
-      this.serviceContext,
-    );
-
-    const newIds = await this.vectorStore.add(embeddingResults);
+  async insertNodesToStore(
+    vectorStore: VectorStore,
+    nodes: BaseNode[],
+  ): Promise<void> {
+    const newIds = await vectorStore.add(nodes);
 
-    if (!this.vectorStore.storesText) {
+    if (!vectorStore.storesText) {
       for (let i = 0; i < nodes.length; ++i) {
         this.indexStruct.addNode(nodes[i], newIds[i]);
         this.docStore.addDocuments([nodes[i]], true);
@@ -284,6 +282,14 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
     await this.storageContext.indexStore.addIndexStruct(this.indexStruct);
   }
 
+  async insertNodes(nodes: BaseNode[]): Promise<void> {
+    const embeddingResults = await VectorStoreIndex.getNodeEmbeddingResults(
+      nodes,
+      this.serviceContext,
+    ); // embed first; insertNodesToStore is handed the embedded results
+    await this.insertNodesToStore(this.vectorStore, embeddingResults);
+  }
+
   async deleteRefDoc(
     refDocId: string,
     deleteFromDocStore: boolean = true,
-- 
GitLab