From 7def68fb37bd5d1ed60490da7dd408bb4bdc3deb Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Mon, 8 Jan 2024 11:46:58 +0700
Subject: [PATCH] feat: added local embedding

---
 examples/huggingface.ts                       | 43 +++++++++++
 .../src/embeddings/HuggingFaceEmbedding.ts    | 73 +++++++++++++++++++
 packages/core/src/embeddings/index.ts         |  1 +
 3 files changed, 117 insertions(+)
 create mode 100644 examples/huggingface.ts
 create mode 100644 packages/core/src/embeddings/HuggingFaceEmbedding.ts

diff --git a/examples/huggingface.ts b/examples/huggingface.ts
new file mode 100644
index 000000000..1d02b43ab
--- /dev/null
+++ b/examples/huggingface.ts
@@ -0,0 +1,43 @@
+import fs from "node:fs/promises";
+
+import {
+  Document,
+  HuggingFaceEmbedding,
+  HuggingFaceEmbeddingModelType,
+  VectorStoreIndex,
+  serviceContextFromDefaults,
+} from "llamaindex";
+
+async function main() {
+  // Load essay from abramov.txt in Node
+  const path = "node_modules/llamaindex/examples/abramov.txt";
+
+  const essay = await fs.readFile(path, "utf-8");
+
+  // Create Document object with essay
+  const document = new Document({ text: essay, id_: path });
+
+  // Use Local embedding from HuggingFace
+  const embedModel = new HuggingFaceEmbedding({
+    modelType: HuggingFaceEmbeddingModelType.XENOVA_ALL_MPNET_BASE_V2,
+  });
+  const serviceContext = serviceContextFromDefaults({
+    embedModel,
+  });
+
+  // Split text and create embeddings. Store them in a VectorStoreIndex
+  const index = await VectorStoreIndex.fromDocuments([document], {
+    serviceContext,
+  });
+
+  // Query the index
+  const queryEngine = index.asQueryEngine();
+  const response = await queryEngine.query(
+    "What did the author do in college?",
+  );
+
+  // Output response
+  console.log(response.toString());
+}
+
+main().catch(console.error);
diff --git a/packages/core/src/embeddings/HuggingFaceEmbedding.ts b/packages/core/src/embeddings/HuggingFaceEmbedding.ts
new file mode 100644
index 000000000..13ee9139e
--- /dev/null
+++ b/packages/core/src/embeddings/HuggingFaceEmbedding.ts
@@ -0,0 +1,73 @@
+import { BaseEmbedding } from "./types";
+
+/** Predefined model identifiers for {@link HuggingFaceEmbedding}. */
+export enum HuggingFaceEmbeddingModelType {
+  XENOVA_ALL_MINILM_L6_V2 = "Xenova/all-MiniLM-L6-v2",
+  XENOVA_ALL_MPNET_BASE_V2 = "Xenova/all-mpnet-base-v2",
+}
+
+/**
+ * Uses feature extraction from '@xenova/transformers' to generate embeddings.
+ * Per default the model [XENOVA_ALL_MINILM_L6_V2](https://huggingface.co/Xenova/all-MiniLM-L6-v2) is used.
+ *
+ * Can be changed by setting the `modelType` parameter in the constructor, e.g.:
+ * ```
+ * new HuggingFaceEmbedding({
+ *   modelType: HuggingFaceEmbeddingModelType.XENOVA_ALL_MPNET_BASE_V2,
+ * });
+ * ```
+ *
+ * @extends BaseEmbedding
+ */
+export class HuggingFaceEmbedding extends BaseEmbedding {
+  /** Name of the model passed to the feature-extraction pipeline. */
+  modelType: string = HuggingFaceEmbeddingModelType.XENOVA_ALL_MINILM_L6_V2;
+
+  /** Pipeline instance; created on the first call to `getExtractor()`. */
+  private extractor: any;
+
+  /**
+   * @param init Optional property values (e.g. `modelType`) copied onto the instance.
+   */
+  constructor(init?: Partial<HuggingFaceEmbedding>) {
+    super();
+    Object.assign(this, init);
+  }
+
+  /**
+   * Returns the feature-extraction pipeline for `modelType`, importing
+   * '@xenova/transformers' and creating the pipeline on first use.
+   */
+  async getExtractor() {
+    if (!this.extractor) {
+      const { pipeline } = await import("@xenova/transformers");
+      this.extractor = await pipeline("feature-extraction", this.modelType);
+    }
+    return this.extractor;
+  }
+
+  /**
+   * Embeds `text` with mean pooling and normalization enabled.
+   *
+   * @param text The text to embed.
+   * @returns The embedding as a plain `number[]`.
+   */
+  async getTextEmbedding(text: string): Promise<number[]> {
+    const extractor = await this.getExtractor();
+    const output = await extractor(text, { pooling: "mean", normalize: true });
+    // The tensor's `data` is a typed array (Float32Array), not a `number[]`.
+    // Copy it so the declared return type holds and JSON serialization
+    // produces an array instead of an index-keyed object.
+    return Array.from(output.data);
+  }
+
+  /**
+   * Embeds a query string; delegates to `getTextEmbedding`.
+   *
+   * @param query The query text to embed.
+   * @returns The embedding as a plain `number[]`.
+   */
+  async getQueryEmbedding(query: string): Promise<number[]> {
+    return this.getTextEmbedding(query);
+  }
+}
diff --git a/packages/core/src/embeddings/index.ts b/packages/core/src/embeddings/index.ts
index 092e5fb86..32d6535bd 100644
--- a/packages/core/src/embeddings/index.ts
+++ b/packages/core/src/embeddings/index.ts
@@ -1,4 +1,5 @@
 export * from "./ClipEmbedding";
+export * from "./HuggingFaceEmbedding";
 export * from "./MistralAIEmbedding";
 export * from "./MultiModalEmbedding";
 export * from "./OpenAIEmbedding";
-- 
GitLab