From 7c6b3117bf1c711d49951bc5aa1f05f129698554 Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Mon, 8 Jan 2024 14:17:35 +0700
Subject: [PATCH] feat: use local embedding for sentence window example and log
 progress

---
 examples/sentenceWindow.ts                    | 10 ++-
 .../indices/vectorStore/VectorStoreIndex.ts   | 75 ++++++++++++-------
 2 files changed, 57 insertions(+), 28 deletions(-)

diff --git a/examples/sentenceWindow.ts b/examples/sentenceWindow.ts
index 66b4fd9ad..d4c9cf182 100644
--- a/examples/sentenceWindow.ts
+++ b/examples/sentenceWindow.ts
@@ -1,5 +1,6 @@
 import {
   Document,
+  HuggingFaceEmbedding,
   MetadataReplacementPostProcessor,
   SentenceWindowNodeParser,
   VectorStoreIndex,
@@ -11,23 +12,28 @@ async function main() {
   const document = new Document({ text: essay, id_: "essay" });
 
   // create service context with sentence window parser
+  // and local embedding from HuggingFace
   const nodeParser = new SentenceWindowNodeParser({
     windowSize: 3,
     windowMetadataKey: "window",
     originalTextMetadataKey: "original_text",
   });
-  const serviceContext = serviceContextFromDefaults({ nodeParser });
+  const embedModel = new HuggingFaceEmbedding();
+  const serviceContext = serviceContextFromDefaults({ nodeParser, embedModel });
 
   // Split text and create embeddings. Store them in a VectorStoreIndex
   const index = await VectorStoreIndex.fromDocuments([document], {
     serviceContext,
+    logProgress: true,
   });
 
   // Query the index
   const queryEngine = index.asQueryEngine({
     nodePostprocessors: [new MetadataReplacementPostProcessor("window")],
   });
-  const response = await queryEngine.query("What mistakes did they make?");
+  const response = await queryEngine.query(
+    "What did the author do in college?",
+  );
 
   // Output response
   console.log(response.toString());
diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
index 59d177b34..ea52780bb 100644
--- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
+++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
@@ -43,6 +43,7 @@ export interface VectorIndexOptions extends IndexStructOptions {
   storageContext?: StorageContext;
   imageVectorStore?: VectorStore;
   vectorStore?: VectorStore;
+  logProgress?: boolean;
 }
 
 export interface VectorIndexConstructorProps extends BaseIndexInit<IndexDict> {
@@ -110,7 +111,9 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
 
     if (options.nodes) {
       // If nodes are passed in, then we need to update the index
-      await index.buildIndexFromNodes(options.nodes);
+      await index.buildIndexFromNodes(options.nodes, {
+        logProgress: options.logProgress,
+      });
     }
     return index;
   }
@@ -147,23 +150,26 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
   }
 
   /**
-   * Get the embeddings for nodes.
-   * @param nodes
-   * @param logProgress log progress to console (useful for debugging)
-   * @returns
+   * Calculates the embeddings for the given nodes.
+   *
+   * @param nodes - An array of BaseNode objects representing the nodes for which embeddings are to be calculated.
+   * @param {Object} [options] - An optional object containing additional parameters.
+   *   @param {boolean} [options.logProgress] - A boolean indicating whether to log progress to the console (useful for debugging).
    */
-  async getNodeEmbeddingResults(nodes: BaseNode[], logProgress = false) {
+  async getNodeEmbeddingResults(
+    nodes: BaseNode[],
+    options?: { logProgress?: boolean },
+  ): Promise<BaseNode[]> {
     const nodesWithEmbeddings: BaseNode[] = [];
 
     for (let i = 0; i < nodes.length; ++i) {
       const node = nodes[i];
-      if (logProgress) {
-        console.log(`getting embedding for node ${i}/${nodes.length}`);
+      if (options?.logProgress) {
+        console.log(`Getting embedding for node ${i + 1}/${nodes.length}`);
       }
-      const embedding = await this.embedModel.getTextEmbedding(
+      node.embedding = await this.embedModel.getTextEmbedding(
         node.getContent(MetadataMode.EMBED),
       );
-      node.embedding = embedding;
       nodesWithEmbeddings.push(node);
     }
 
@@ -175,7 +181,10 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
    * @param nodes
    * @returns
    */
-  async buildIndexFromNodes(nodes: BaseNode[]) {
+  async buildIndexFromNodes(
+    nodes: BaseNode[],
+    options?: { logProgress?: boolean },
+  ) {
     // Check if the index already has nodes with the same hash
     const newNodes = nodes.filter((node) =>
       Object.entries(this.indexStruct!.nodesDict).reduce(
@@ -189,7 +198,7 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
       ),
     );
 
-    await this.insertNodes(newNodes);
+    await this.insertNodes(newNodes, options);
   }
 
   /**
@@ -211,8 +220,14 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
       docStore.setDocumentHash(doc.id_, doc.hash);
     }
 
+    if (args.logProgress) {
+      console.log("Using node parser on documents...");
+    }
     args.nodes =
       args.serviceContext.nodeParser.getNodesFromDocuments(documents);
+    if (args.logProgress) {
+      console.log("Finished parsing documents.");
+    }
     return await this.init(args);
   }
 
@@ -280,12 +295,15 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
         const nodeWithoutEmbedding = nodes[i].clone();
         nodeWithoutEmbedding.embedding = undefined;
         this.indexStruct.addNode(nodeWithoutEmbedding, newIds[i]);
-        this.docStore.addDocuments([nodeWithoutEmbedding], true);
+        await this.docStore.addDocuments([nodeWithoutEmbedding], true);
       }
     }
   }
 
-  async insertNodes(nodes: BaseNode[]): Promise<void> {
+  async insertNodes(
+    nodes: BaseNode[],
+    options?: { logProgress?: boolean },
+  ): Promise<void> {
     if (!nodes || nodes.length === 0) {
       return;
     }
@@ -294,14 +312,19 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
       if (!this.imageVectorStore) {
         throw new Error("Cannot insert image nodes without image vector store");
       }
-      const imageNodesWithEmbedding =
-        await this.getImageNodeEmbeddingResults(imageNodes);
+      const imageNodesWithEmbedding = await this.getImageNodeEmbeddingResults(
+        imageNodes,
+        options,
+      );
       await this.insertNodesToStore(
         this.imageVectorStore,
         imageNodesWithEmbedding,
       );
     }
-    const embeddingResults = await this.getNodeEmbeddingResults(textNodes);
+    const embeddingResults = await this.getNodeEmbeddingResults(
+      textNodes,
+      options,
+    );
     await this.insertNodesToStore(this.vectorStore, embeddingResults);
     await this.indexStore.addIndexStruct(this.indexStruct);
   }
@@ -340,16 +363,16 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
   }
 
   /**
-   * Get the embeddings for image nodes.
-   * @param nodes
-   * @param serviceContext
-   * @param logProgress log progress to console (useful for debugging)
-   * @returns
+   * Calculates the embeddings for the given image nodes.
+   *
+   * @param nodes - An array of ImageNode objects representing the nodes for which embeddings are to be calculated.
+   * @param {Object} [options] - An optional object containing additional parameters.
+   *   @param {boolean} [options.logProgress] - A boolean indicating whether to log progress to the console (useful for debugging).
    */
   async getImageNodeEmbeddingResults(
     nodes: ImageNode[],
-    logProgress: boolean = false,
-  ): Promise<BaseNode[]> {
+    options?: { logProgress?: boolean },
+  ): Promise<ImageNode[]> {
     if (!this.imageEmbedModel) {
       return [];
     }
@@ -358,8 +381,8 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
 
     for (let i = 0; i < nodes.length; ++i) {
       const node = nodes[i];
-      if (logProgress) {
-        console.log(`getting embedding for node ${i}/${nodes.length}`);
+      if (options?.logProgress) {
+        console.log(`Getting embedding for node ${i + 1}/${nodes.length}`);
       }
       node.embedding = await this.imageEmbedModel.getImageEmbedding(node.image);
       nodesWithEmbeddings.push(node);
-- 
GitLab