From 7def68fb37bd5d1ed60490da7dd408bb4bdc3deb Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Mon, 8 Jan 2024 11:46:58 +0700
Subject: [PATCH] feat: add local embedding

---
 examples/huggingface.ts                       | 43 +++++++++++++++++
 .../src/embeddings/HuggingFaceEmbedding.ts    | 48 +++++++++++++++++++
 packages/core/src/embeddings/index.ts         |  1 +
 3 files changed, 92 insertions(+)
 create mode 100644 examples/huggingface.ts
 create mode 100644 packages/core/src/embeddings/HuggingFaceEmbedding.ts

diff --git a/examples/huggingface.ts b/examples/huggingface.ts
new file mode 100644
index 000000000..1d02b43ab
--- /dev/null
+++ b/examples/huggingface.ts
@@ -0,0 +1,43 @@
+import fs from "node:fs/promises";
+
+import {
+  Document,
+  HuggingFaceEmbedding,
+  HuggingFaceEmbeddingModelType,
+  VectorStoreIndex,
+  serviceContextFromDefaults,
+} from "llamaindex";
+
+async function main() {
+  // Load essay from abramov.txt in Node
+  const path = "node_modules/llamaindex/examples/abramov.txt";
+
+  const essay = await fs.readFile(path, "utf-8");
+
+  // Create Document object with essay
+  const document = new Document({ text: essay, id_: path });
+
+  // Use Local embedding from HuggingFace
+  const embedModel = new HuggingFaceEmbedding({
+    modelType: HuggingFaceEmbeddingModelType.XENOVA_ALL_MPNET_BASE_V2,
+  });
+  const serviceContext = serviceContextFromDefaults({
+    embedModel,
+  });
+
+  // Split text and create embeddings. Store them in a VectorStoreIndex
+  const index = await VectorStoreIndex.fromDocuments([document], {
+    serviceContext,
+  });
+
+  // Query the index
+  const queryEngine = index.asQueryEngine();
+  const response = await queryEngine.query(
+    "What did the author do in college?",
+  );
+
+  // Output response
+  console.log(response.toString());
+}
+
+main().catch(console.error);
diff --git a/packages/core/src/embeddings/HuggingFaceEmbedding.ts b/packages/core/src/embeddings/HuggingFaceEmbedding.ts
new file mode 100644
index 000000000..13ee9139e
--- /dev/null
+++ b/packages/core/src/embeddings/HuggingFaceEmbedding.ts
@@ -0,0 +1,48 @@
+import { BaseEmbedding } from "./types";
+
+export enum HuggingFaceEmbeddingModelType {
+  XENOVA_ALL_MINILM_L6_V2 = "Xenova/all-MiniLM-L6-v2",
+  XENOVA_ALL_MPNET_BASE_V2 = "Xenova/all-mpnet-base-v2",
+}
+
+/**
+ * Uses feature extraction from '@xenova/transformers' to generate embeddings.
+ * Per default the model [XENOVA_ALL_MINILM_L6_V2](https://huggingface.co/Xenova/all-MiniLM-L6-v2) is used.
+ *
+ * Can be changed by setting the `modelType` parameter in the constructor, e.g.:
+ * ```
+ * new HuggingFaceEmbedding({
+ *     modelType: HuggingFaceEmbeddingModelType.XENOVA_ALL_MPNET_BASE_V2,
+ * });
+ * ```
+ *
+ * @extends BaseEmbedding
+ */
+export class HuggingFaceEmbedding extends BaseEmbedding {
+  modelType: string = HuggingFaceEmbeddingModelType.XENOVA_ALL_MINILM_L6_V2;
+
+  private extractor: any;
+
+  constructor(init?: Partial<HuggingFaceEmbedding>) {
+    super();
+    Object.assign(this, init);
+  }
+
+  async getExtractor() {
+    if (!this.extractor) {
+      const { pipeline } = await import("@xenova/transformers");
+      this.extractor = await pipeline("feature-extraction", this.modelType);
+    }
+    return this.extractor;
+  }
+
+  async getTextEmbedding(text: string): Promise<number[]> {
+    const extractor = await this.getExtractor();
+    const output = await extractor(text, { pooling: "mean", normalize: true });
+    return Array.from(output.data); // Tensor.data is a Float32Array; convert to a plain number[]
+  }
+
+  async getQueryEmbedding(query: string): Promise<number[]> {
+    return this.getTextEmbedding(query);
+  }
+}
diff --git a/packages/core/src/embeddings/index.ts b/packages/core/src/embeddings/index.ts
index 092e5fb86..32d6535bd 100644
--- a/packages/core/src/embeddings/index.ts
+++ b/packages/core/src/embeddings/index.ts
@@ -1,4 +1,5 @@
 export * from "./ClipEmbedding";
+export * from "./HuggingFaceEmbedding";
 export * from "./MistralAIEmbedding";
 export * from "./MultiModalEmbedding";
 export * from "./OpenAIEmbedding";
-- 
GitLab