From dca02f7277bfa65e052263e02faa094742376548 Mon Sep 17 00:00:00 2001 From: Marcus Schiesser <mail@marcusschiesser.de> Date: Tue, 16 Apr 2024 11:01:26 +0800 Subject: [PATCH] refactor: VectorStoreIndex: use TransformerComponent to calc embeddings (#721) --- README.md | 2 +- examples/multimodal/load.ts | 2 + .../fixtures/embeddings/OpenAIEmbedding.ts | 1 + .../src/embeddings/MultiModalEmbedding.ts | 39 +++++++++- packages/core/src/embeddings/types.ts | 65 ++++++++++------- .../core/src/indices/vectorStore/index.ts | 71 +++++-------------- 6 files changed, 99 insertions(+), 81 deletions(-) diff --git a/README.md b/README.md index 1d1035c5e..3c9601600 100644 --- a/README.md +++ b/README.md @@ -154,7 +154,7 @@ If you need any of those classes, you have to import them instead directly. Here import { PineconeVectorStore } from "@llamaindex/edge/storage/vectorStore/PineconeVectorStore"; ``` -As the `PDFReader` is not with the Edge runtime, here's how to use the `SimpleDirectoryReader` with the `LlamaParseReader` to load PDFs: +As the `PDFReader` is not working with the Edge runtime, here's how to use the `SimpleDirectoryReader` with the `LlamaParseReader` to load PDFs: ```typescript import { SimpleDirectoryReader } from "@llamaindex/edge/readers/SimpleDirectoryReader"; diff --git a/examples/multimodal/load.ts b/examples/multimodal/load.ts index 3ed94e30b..15c845b8f 100644 --- a/examples/multimodal/load.ts +++ b/examples/multimodal/load.ts @@ -4,6 +4,7 @@ import { VectorStoreIndex, storageContextFromDefaults, } from "llamaindex"; +import { DocStoreStrategy } from "llamaindex/ingestion/strategies/index"; import * as path from "path"; @@ -31,6 +32,7 @@ async function generateDatasource() { }); await VectorStoreIndex.fromDocuments(documents, { storageContext, + docStoreStrategy: DocStoreStrategy.NONE, }); }); console.log(`Storage successfully generated in ${ms / 1000}s.`); diff --git a/packages/core/e2e/fixtures/embeddings/OpenAIEmbedding.ts 
b/packages/core/e2e/fixtures/embeddings/OpenAIEmbedding.ts index bab896c70..eec0bdbed 100644 --- a/packages/core/e2e/fixtures/embeddings/OpenAIEmbedding.ts +++ b/packages/core/e2e/fixtures/embeddings/OpenAIEmbedding.ts @@ -28,6 +28,7 @@ export class OpenAIEmbedding implements BaseEmbedding { } async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> { + nodes.forEach((node) => (node.embedding = [0])); return nodes; } } diff --git a/packages/core/src/embeddings/MultiModalEmbedding.ts b/packages/core/src/embeddings/MultiModalEmbedding.ts index e2c7c1434..e220eede0 100644 --- a/packages/core/src/embeddings/MultiModalEmbedding.ts +++ b/packages/core/src/embeddings/MultiModalEmbedding.ts @@ -1,5 +1,10 @@ -import type { ImageType } from "../Node.js"; -import { BaseEmbedding } from "./types.js"; +import { + MetadataMode, + splitNodesByType, + type BaseNode, + type ImageType, +} from "../Node.js"; +import { BaseEmbedding, batchEmbeddings } from "./types.js"; /* * Base class for Multi Modal embeddings. 
@@ -8,9 +13,39 @@ import { BaseEmbedding } from "./types.js"; export abstract class MultiModalEmbedding extends BaseEmbedding { abstract getImageEmbedding(images: ImageType): Promise<number[]>; + /** + * Optionally override this method to retrieve multiple image embeddings in a single request + * @param images + */ async getImageEmbeddings(images: ImageType[]): Promise<number[][]> { return Promise.all( images.map((imgFilePath) => this.getImageEmbedding(imgFilePath)), ); } + + async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> { + const { imageNodes, textNodes } = splitNodesByType(nodes); + + const embeddings = await batchEmbeddings( + textNodes.map((node) => node.getContent(MetadataMode.EMBED)), + this.getTextEmbeddings.bind(this), + this.embedBatchSize, + _options, + ); + for (let i = 0; i < textNodes.length; i++) { + textNodes[i].embedding = embeddings[i]; + } + + const imageEmbeddings = await batchEmbeddings( + imageNodes.map((n) => n.image), + this.getImageEmbeddings.bind(this), + this.embedBatchSize, + _options, + ); + for (let i = 0; i < imageNodes.length; i++) { + imageNodes[i].embedding = imageEmbeddings[i]; + } + + return nodes; + } } diff --git a/packages/core/src/embeddings/types.ts b/packages/core/src/embeddings/types.ts index 9c5892bdb..67d06940a 100644 --- a/packages/core/src/embeddings/types.ts +++ b/packages/core/src/embeddings/types.ts @@ -5,6 +5,8 @@ import { SimilarityType, similarity } from "./utils.js"; const DEFAULT_EMBED_BATCH_SIZE = 10; +type EmbedFunc<T> = (values: T[]) => Promise<Array<number[]>>; + export abstract class BaseEmbedding implements TransformComponent { embedBatchSize = DEFAULT_EMBED_BATCH_SIZE; @@ -45,35 +47,18 @@ export abstract class BaseEmbedding implements TransformComponent { logProgress?: boolean; }, ): Promise<Array<number[]>> { - const resultEmbeddings: Array<number[]> = []; - const chunkSize = this.embedBatchSize; - - const queue: string[] = texts; - - const curBatch: string[] = []; - - for (let i 
= 0; i < queue.length; i++) { - curBatch.push(queue[i]); - if (i == queue.length - 1 || curBatch.length == chunkSize) { - const embeddings = await this.getTextEmbeddings(curBatch); - - resultEmbeddings.push(...embeddings); - - if (options?.logProgress) { - console.log(`getting embedding progress: ${i} / ${queue.length}`); - } - - curBatch.length = 0; - } - } - - return resultEmbeddings; + return await batchEmbeddings( + texts, + this.getTextEmbeddings.bind(this), + this.embedBatchSize, + options, + ); } async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> { const texts = nodes.map((node) => node.getContent(MetadataMode.EMBED)); - const embeddings = await this.getTextEmbeddingsBatch(texts); + const embeddings = await this.getTextEmbeddingsBatch(texts, _options); for (let i = 0; i < nodes.length; i++) { nodes[i].embedding = embeddings[i]; @@ -82,3 +67,35 @@ export abstract class BaseEmbedding implements TransformComponent { return nodes; } } + +export async function batchEmbeddings<T>( + values: T[], + embedFunc: EmbedFunc<T>, + chunkSize: number, + options?: { + logProgress?: boolean; + }, +): Promise<Array<number[]>> { + const resultEmbeddings: Array<number[]> = []; + + const queue: T[] = values; + + const curBatch: T[] = []; + + for (let i = 0; i < queue.length; i++) { + curBatch.push(queue[i]); + if (i == queue.length - 1 || curBatch.length == chunkSize) { + const embeddings = await embedFunc(curBatch); + + resultEmbeddings.push(...embeddings); + + if (options?.logProgress) { + console.log(`getting embedding progress: ${i} / ${queue.length}`); + } + + curBatch.length = 0; + } + } + + return resultEmbeddings; +} diff --git a/packages/core/src/indices/vectorStore/index.ts b/packages/core/src/indices/vectorStore/index.ts index 5619f7cb8..c2818d3bd 100644 --- a/packages/core/src/indices/vectorStore/index.ts +++ b/packages/core/src/indices/vectorStore/index.ts @@ -4,12 +4,7 @@ import type { Metadata, NodeWithScore, } from "../../Node.js"; -import { 
- ImageNode, - MetadataMode, - ObjectType, - splitNodesByType, -} from "../../Node.js"; +import { ImageNode, ObjectType, splitNodesByType } from "../../Node.js"; import type { BaseRetriever, RetrieveParams } from "../../Retriever.js"; import type { ServiceContext } from "../../ServiceContext.js"; import { @@ -179,14 +174,21 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { nodes: BaseNode[], options?: { logProgress?: boolean }, ): Promise<BaseNode[]> { - const texts = nodes.map((node) => node.getContent(MetadataMode.EMBED)); - const embeddings = await this.embedModel.getTextEmbeddingsBatch(texts, { + const { imageNodes, textNodes } = splitNodesByType(nodes); + if (imageNodes.length > 0) { + if (!this.imageEmbedModel) { + throw new Error( + "Cannot calculate image nodes embedding without 'imageEmbedModel' set", + ); + } + await this.imageEmbedModel.transform(imageNodes, { + logProgress: options?.logProgress, + }); + } + await this.embedModel.transform(textNodes, { logProgress: options?.logProgress, }); - return nodes.map((node, i) => { - node.embedding = embeddings[i]; - return node; - }); + return nodes; } /** @@ -324,25 +326,15 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { if (!nodes || nodes.length === 0) { return; } + nodes = await this.getNodeEmbeddingResults(nodes, options); const { imageNodes, textNodes } = splitNodesByType(nodes); if (imageNodes.length > 0) { if (!this.imageVectorStore) { throw new Error("Cannot insert image nodes without image vector store"); } - const imageNodesWithEmbedding = await this.getImageNodeEmbeddingResults( - imageNodes, - options, - ); - await this.insertNodesToStore( - this.imageVectorStore, - imageNodesWithEmbedding, - ); + await this.insertNodesToStore(this.imageVectorStore, imageNodes); } - const embeddingResults = await this.getNodeEmbeddingResults( - textNodes, - options, - ); - await this.insertNodesToStore(this.vectorStore, embeddingResults); + await 
this.insertNodesToStore(this.vectorStore, textNodes); await this.indexStore.addIndexStruct(this.indexStruct); } @@ -378,35 +370,6 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { await this.indexStore.addIndexStruct(this.indexStruct); } } - - /** - * Calculates the embeddings for the given image nodes. - * - * @param nodes - An array of ImageNode objects representing the nodes for which embeddings are to be calculated. - * @param {Object} [options] - An optional object containing additional parameters. - * @param {boolean} [options.logProgress] - A boolean indicating whether to log progress to the console (useful for debugging). - */ - async getImageNodeEmbeddingResults( - nodes: ImageNode[], - options?: { logProgress?: boolean }, - ): Promise<ImageNode[]> { - if (!this.imageEmbedModel) { - return []; - } - - const nodesWithEmbeddings: ImageNode[] = []; - - for (let i = 0; i < nodes.length; ++i) { - const node = nodes[i]; - if (options?.logProgress) { - console.log(`Getting embedding for node ${i + 1}/${nodes.length}`); - } - node.embedding = await this.imageEmbedModel.getImageEmbedding(node.image); - nodesWithEmbeddings.push(node); - } - - return nodesWithEmbeddings; - } } /** -- GitLab