From dca02f7277bfa65e052263e02faa094742376548 Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Tue, 16 Apr 2024 11:01:26 +0800
Subject: [PATCH] refactor: VectorStoreIndex: use TransformerComponent to calc
 embeddings  (#721)

---
 README.md                                     |  2 +-
 examples/multimodal/load.ts                   |  2 +
 .../fixtures/embeddings/OpenAIEmbedding.ts    |  1 +
 .../src/embeddings/MultiModalEmbedding.ts     | 39 +++++++++-
 packages/core/src/embeddings/types.ts         | 65 ++++++++++-------
 .../core/src/indices/vectorStore/index.ts     | 71 +++++--------------
 6 files changed, 99 insertions(+), 81 deletions(-)

diff --git a/README.md b/README.md
index 1d1035c5e..3c9601600 100644
--- a/README.md
+++ b/README.md
@@ -154,7 +154,7 @@ If you need any of those classes, you have to import them instead directly. Here
 import { PineconeVectorStore } from "@llamaindex/edge/storage/vectorStore/PineconeVectorStore";
 ```
 
-As the `PDFReader` is not with the Edge runtime, here's how to use the `SimpleDirectoryReader` with the `LlamaParseReader` to load PDFs:
+As the `PDFReader` is not working with the Edge runtime, here's how to use the `SimpleDirectoryReader` with the `LlamaParseReader` to load PDFs:
 
 ```typescript
 import { SimpleDirectoryReader } from "@llamaindex/edge/readers/SimpleDirectoryReader";
diff --git a/examples/multimodal/load.ts b/examples/multimodal/load.ts
index 3ed94e30b..15c845b8f 100644
--- a/examples/multimodal/load.ts
+++ b/examples/multimodal/load.ts
@@ -4,6 +4,7 @@ import {
   VectorStoreIndex,
   storageContextFromDefaults,
 } from "llamaindex";
+import { DocStoreStrategy } from "llamaindex/ingestion/strategies/index";
 
 import * as path from "path";
 
@@ -31,6 +32,7 @@ async function generateDatasource() {
     });
     await VectorStoreIndex.fromDocuments(documents, {
       storageContext,
+      docStoreStrategy: DocStoreStrategy.NONE,
     });
   });
   console.log(`Storage successfully generated in ${ms / 1000}s.`);
diff --git a/packages/core/e2e/fixtures/embeddings/OpenAIEmbedding.ts b/packages/core/e2e/fixtures/embeddings/OpenAIEmbedding.ts
index bab896c70..eec0bdbed 100644
--- a/packages/core/e2e/fixtures/embeddings/OpenAIEmbedding.ts
+++ b/packages/core/e2e/fixtures/embeddings/OpenAIEmbedding.ts
@@ -28,6 +28,7 @@ export class OpenAIEmbedding implements BaseEmbedding {
   }
 
   async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> {
+    nodes.forEach((node) => (node.embedding = [0]));
     return nodes;
   }
 }
diff --git a/packages/core/src/embeddings/MultiModalEmbedding.ts b/packages/core/src/embeddings/MultiModalEmbedding.ts
index e2c7c1434..e220eede0 100644
--- a/packages/core/src/embeddings/MultiModalEmbedding.ts
+++ b/packages/core/src/embeddings/MultiModalEmbedding.ts
@@ -1,5 +1,10 @@
-import type { ImageType } from "../Node.js";
-import { BaseEmbedding } from "./types.js";
+import {
+  MetadataMode,
+  splitNodesByType,
+  type BaseNode,
+  type ImageType,
+} from "../Node.js";
+import { BaseEmbedding, batchEmbeddings } from "./types.js";
 
 /*
  * Base class for Multi Modal embeddings.
@@ -8,9 +13,39 @@ import { BaseEmbedding } from "./types.js";
 export abstract class MultiModalEmbedding extends BaseEmbedding {
   abstract getImageEmbedding(images: ImageType): Promise<number[]>;
 
+  /**
+   * Optionally override this method to retrieve multiple image embeddings in a single request
+   * @param texts
+   */
   async getImageEmbeddings(images: ImageType[]): Promise<number[][]> {
     return Promise.all(
       images.map((imgFilePath) => this.getImageEmbedding(imgFilePath)),
     );
   }
+
+  async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> {
+    const { imageNodes, textNodes } = splitNodesByType(nodes);
+
+    const embeddings = await batchEmbeddings(
+      textNodes.map((node) => node.getContent(MetadataMode.EMBED)),
+      this.getTextEmbeddings.bind(this),
+      this.embedBatchSize,
+      _options,
+    );
+    for (let i = 0; i < textNodes.length; i++) {
+      textNodes[i].embedding = embeddings[i];
+    }
+
+    const imageEmbeddings = await batchEmbeddings(
+      imageNodes.map((n) => n.image),
+      this.getImageEmbeddings.bind(this),
+      this.embedBatchSize,
+      _options,
+    );
+    for (let i = 0; i < imageNodes.length; i++) {
+      imageNodes[i].embedding = imageEmbeddings[i];
+    }
+
+    return nodes;
+  }
 }
diff --git a/packages/core/src/embeddings/types.ts b/packages/core/src/embeddings/types.ts
index 9c5892bdb..67d06940a 100644
--- a/packages/core/src/embeddings/types.ts
+++ b/packages/core/src/embeddings/types.ts
@@ -5,6 +5,8 @@ import { SimilarityType, similarity } from "./utils.js";
 
 const DEFAULT_EMBED_BATCH_SIZE = 10;
 
+type EmbedFunc<T> = (values: T[]) => Promise<Array<number[]>>;
+
 export abstract class BaseEmbedding implements TransformComponent {
   embedBatchSize = DEFAULT_EMBED_BATCH_SIZE;
 
@@ -45,35 +47,18 @@ export abstract class BaseEmbedding implements TransformComponent {
       logProgress?: boolean;
     },
   ): Promise<Array<number[]>> {
-    const resultEmbeddings: Array<number[]> = [];
-    const chunkSize = this.embedBatchSize;
-
-    const queue: string[] = texts;
-
-    const curBatch: string[] = [];
-
-    for (let i = 0; i < queue.length; i++) {
-      curBatch.push(queue[i]);
-      if (i == queue.length - 1 || curBatch.length == chunkSize) {
-        const embeddings = await this.getTextEmbeddings(curBatch);
-
-        resultEmbeddings.push(...embeddings);
-
-        if (options?.logProgress) {
-          console.log(`getting embedding progress: ${i} / ${queue.length}`);
-        }
-
-        curBatch.length = 0;
-      }
-    }
-
-    return resultEmbeddings;
+    return await batchEmbeddings(
+      texts,
+      this.getTextEmbeddings.bind(this),
+      this.embedBatchSize,
+      options,
+    );
   }
 
   async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> {
     const texts = nodes.map((node) => node.getContent(MetadataMode.EMBED));
 
-    const embeddings = await this.getTextEmbeddingsBatch(texts);
+    const embeddings = await this.getTextEmbeddingsBatch(texts, _options);
 
     for (let i = 0; i < nodes.length; i++) {
       nodes[i].embedding = embeddings[i];
@@ -82,3 +67,35 @@ export abstract class BaseEmbedding implements TransformComponent {
     return nodes;
   }
 }
+
+export async function batchEmbeddings<T>(
+  values: T[],
+  embedFunc: EmbedFunc<T>,
+  chunkSize: number,
+  options?: {
+    logProgress?: boolean;
+  },
+): Promise<Array<number[]>> {
+  const resultEmbeddings: Array<number[]> = [];
+
+  const queue: T[] = values;
+
+  const curBatch: T[] = [];
+
+  for (let i = 0; i < queue.length; i++) {
+    curBatch.push(queue[i]);
+    if (i == queue.length - 1 || curBatch.length == chunkSize) {
+      const embeddings = await embedFunc(curBatch);
+
+      resultEmbeddings.push(...embeddings);
+
+      if (options?.logProgress) {
+        console.log(`getting embedding progress: ${i} / ${queue.length}`);
+      }
+
+      curBatch.length = 0;
+    }
+  }
+
+  return resultEmbeddings;
+}
diff --git a/packages/core/src/indices/vectorStore/index.ts b/packages/core/src/indices/vectorStore/index.ts
index 5619f7cb8..c2818d3bd 100644
--- a/packages/core/src/indices/vectorStore/index.ts
+++ b/packages/core/src/indices/vectorStore/index.ts
@@ -4,12 +4,7 @@ import type {
   Metadata,
   NodeWithScore,
 } from "../../Node.js";
-import {
-  ImageNode,
-  MetadataMode,
-  ObjectType,
-  splitNodesByType,
-} from "../../Node.js";
+import { ImageNode, ObjectType, splitNodesByType } from "../../Node.js";
 import type { BaseRetriever, RetrieveParams } from "../../Retriever.js";
 import type { ServiceContext } from "../../ServiceContext.js";
 import {
@@ -179,14 +174,21 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
     nodes: BaseNode[],
     options?: { logProgress?: boolean },
   ): Promise<BaseNode[]> {
-    const texts = nodes.map((node) => node.getContent(MetadataMode.EMBED));
-    const embeddings = await this.embedModel.getTextEmbeddingsBatch(texts, {
+    const { imageNodes, textNodes } = splitNodesByType(nodes);
+    if (imageNodes.length > 0) {
+      if (!this.imageEmbedModel) {
+        throw new Error(
+          "Cannot calculate image nodes embedding without 'imageEmbedModel' set",
+        );
+      }
+      await this.imageEmbedModel.transform(imageNodes, {
+        logProgress: options?.logProgress,
+      });
+    }
+    await this.embedModel.transform(textNodes, {
       logProgress: options?.logProgress,
     });
-    return nodes.map((node, i) => {
-      node.embedding = embeddings[i];
-      return node;
-    });
+    return nodes;
   }
 
   /**
@@ -324,25 +326,15 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
     if (!nodes || nodes.length === 0) {
       return;
     }
+    nodes = await this.getNodeEmbeddingResults(nodes, options);
     const { imageNodes, textNodes } = splitNodesByType(nodes);
     if (imageNodes.length > 0) {
       if (!this.imageVectorStore) {
         throw new Error("Cannot insert image nodes without image vector store");
       }
-      const imageNodesWithEmbedding = await this.getImageNodeEmbeddingResults(
-        imageNodes,
-        options,
-      );
-      await this.insertNodesToStore(
-        this.imageVectorStore,
-        imageNodesWithEmbedding,
-      );
+      await this.insertNodesToStore(this.imageVectorStore, imageNodes);
     }
-    const embeddingResults = await this.getNodeEmbeddingResults(
-      textNodes,
-      options,
-    );
-    await this.insertNodesToStore(this.vectorStore, embeddingResults);
+    await this.insertNodesToStore(this.vectorStore, textNodes);
     await this.indexStore.addIndexStruct(this.indexStruct);
   }
 
@@ -378,35 +370,6 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
       await this.indexStore.addIndexStruct(this.indexStruct);
     }
   }
-
-  /**
-   * Calculates the embeddings for the given image nodes.
-   *
-   * @param nodes - An array of ImageNode objects representing the nodes for which embeddings are to be calculated.
-   * @param {Object} [options] - An optional object containing additional parameters.
-   *   @param {boolean} [options.logProgress] - A boolean indicating whether to log progress to the console (useful for debugging).
-   */
-  async getImageNodeEmbeddingResults(
-    nodes: ImageNode[],
-    options?: { logProgress?: boolean },
-  ): Promise<ImageNode[]> {
-    if (!this.imageEmbedModel) {
-      return [];
-    }
-
-    const nodesWithEmbeddings: ImageNode[] = [];
-
-    for (let i = 0; i < nodes.length; ++i) {
-      const node = nodes[i];
-      if (options?.logProgress) {
-        console.log(`Getting embedding for node ${i + 1}/${nodes.length}`);
-      }
-      node.embedding = await this.imageEmbedModel.getImageEmbedding(node.image);
-      nodesWithEmbeddings.push(node);
-    }
-
-    return nodesWithEmbeddings;
-  }
 }
 
 /**
-- 
GitLab