From dd054137bf16a043d9581bc70432ff1129640516 Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Thu, 8 Feb 2024 18:59:56 +0700
Subject: [PATCH] feat: use batching in vector store index (#524)

Co-authored-by: Alex Yang <himself65@outlook.com>
Co-authored-by: Emanuel Ferreira <contatoferreirads@gmail.com>
---
 .changeset/tame-ways-applaud.md               |  5 +++
 .../src/embeddings/HuggingFaceEmbedding.ts    |  2 +-
 .../core/src/embeddings/OpenAIEmbedding.ts    | 31 ++++++++++++++-----
 packages/core/src/embeddings/types.ts         |  4 +--
 .../indices/vectorStore/VectorStoreIndex.ts   | 22 +++++--------
 5 files changed, 39 insertions(+), 25 deletions(-)
 create mode 100644 .changeset/tame-ways-applaud.md

diff --git a/.changeset/tame-ways-applaud.md b/.changeset/tame-ways-applaud.md
new file mode 100644
index 000000000..510b1b121
--- /dev/null
+++ b/.changeset/tame-ways-applaud.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+feat: use batching in vector store index
diff --git a/packages/core/src/embeddings/HuggingFaceEmbedding.ts b/packages/core/src/embeddings/HuggingFaceEmbedding.ts
index f27eb07ad..810bc14fb 100644
--- a/packages/core/src/embeddings/HuggingFaceEmbedding.ts
+++ b/packages/core/src/embeddings/HuggingFaceEmbedding.ts
@@ -36,7 +36,7 @@ export class HuggingFaceEmbedding extends BaseEmbedding {
     return this.extractor;
   }
 
-  async getTextEmbedding(text: string): Promise<number[]> {
+  override async getTextEmbedding(text: string): Promise<number[]> {
     const extractor = await this.getExtractor();
     const output = await extractor(text, { pooling: "mean", normalize: true });
     return Array.from(output.data);
diff --git a/packages/core/src/embeddings/OpenAIEmbedding.ts b/packages/core/src/embeddings/OpenAIEmbedding.ts
index 996a5b4c8..321c3da0c 100644
--- a/packages/core/src/embeddings/OpenAIEmbedding.ts
+++ b/packages/core/src/embeddings/OpenAIEmbedding.ts
@@ -102,28 +102,43 @@ export class OpenAIEmbedding extends BaseEmbedding {
     }
   }
 
-  private async getOpenAIEmbedding(
-    input: string | string[],
-  ): Promise<number[]> {
+  /**
+   * Get embeddings for a batch of texts
+   * @param texts
+   * @param options
+   */
+  private async getOpenAIEmbedding(input: string[]): Promise<number[][]> {
     const { data } = await this.session.openai.embeddings.create({
       model: this.model,
       dimensions: this.dimensions, // only sent to OpenAI if set by user
       input,
     });
 
-    return data[0].embedding;
+    return data.map((d) => d.embedding);
   }
 
+  /**
+   * Get embeddings for a batch of texts
+   * @param texts
+   */
   async getTextEmbeddings(texts: string[]): Promise<number[][]> {
-    const embeddings = await this.getOpenAIEmbedding(texts);
-    return Array(embeddings);
+    return await this.getOpenAIEmbedding(texts);
   }
 
+  /**
+   * Get embeddings for a single text
+   * @param texts
+   */
   async getTextEmbedding(text: string): Promise<number[]> {
-    return this.getOpenAIEmbedding(text);
+    return (await this.getOpenAIEmbedding([text]))[0];
   }
 
+  /**
+   * Get embeddings for a query
+   * @param texts
+   * @param options
+   */
   async getQueryEmbedding(query: string): Promise<number[]> {
-    return this.getOpenAIEmbedding(query);
+    return (await this.getOpenAIEmbedding([query]))[0];
   }
 }
diff --git a/packages/core/src/embeddings/types.ts b/packages/core/src/embeddings/types.ts
index f4a0a5675..1058e3ead 100644
--- a/packages/core/src/embeddings/types.ts
+++ b/packages/core/src/embeddings/types.ts
@@ -19,7 +19,7 @@ export abstract class BaseEmbedding implements TransformComponent {
   abstract getQueryEmbedding(query: string): Promise<number[]>;
 
   /**
-   * Get embeddings for a batch of texts
+   * Optionally override this method to retrieve multiple embeddings in a single request
    * @param texts
    */
   async getTextEmbeddings(texts: string[]): Promise<Array<number[]>> {
@@ -59,7 +59,7 @@ export abstract class BaseEmbedding implements TransformComponent {
         resultEmbeddings.push(...embeddings);
 
         if (options?.logProgress) {
-          console.log(`number[] progress: ${i} / ${queue.length}`);
+          console.log(`getting embedding progress: ${i} / ${queue.length}`);
         }
 
         curBatch.length = 0;
diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
index bc1b6c230..b416126e7 100644
--- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
+++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
@@ -166,20 +166,14 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
     nodes: BaseNode[],
     options?: { logProgress?: boolean },
   ): Promise<BaseNode[]> {
-    const nodesWithEmbeddings: BaseNode[] = [];
-
-    for (let i = 0; i < nodes.length; ++i) {
-      const node = nodes[i];
-      if (options?.logProgress) {
-        console.log(`Getting embedding for node ${i + 1}/${nodes.length}`);
-      }
-      node.embedding = await this.embedModel.getTextEmbedding(
-        node.getContent(MetadataMode.EMBED),
-      );
-      nodesWithEmbeddings.push(node);
-    }
-
-    return nodesWithEmbeddings;
+    const texts = nodes.map((node) => node.getContent(MetadataMode.EMBED));
+    const embeddings = await this.embedModel.getTextEmbeddingsBatch(texts, {
+      logProgress: options?.logProgress,
+    });
+    return nodes.map((node, i) => {
+      node.embedding = embeddings[i];
+      return node;
+    });
   }
 
   /**
-- 
GitLab