From dd054137bf16a043d9581bc70432ff1129640516 Mon Sep 17 00:00:00 2001 From: Marcus Schiesser <mail@marcusschiesser.de> Date: Thu, 8 Feb 2024 18:59:56 +0700 Subject: [PATCH] feat: use batching in vector store index (#524) Co-authored-by: Alex Yang <himself65@outlook.com> Co-authored-by: Emanuel Ferreira <contatoferreirads@gmail.com> --- .changeset/tame-ways-applaud.md | 5 +++ .../src/embeddings/HuggingFaceEmbedding.ts | 2 +- .../core/src/embeddings/OpenAIEmbedding.ts | 31 ++++++++++++++----- packages/core/src/embeddings/types.ts | 4 +-- .../indices/vectorStore/VectorStoreIndex.ts | 22 +++++-------- 5 files changed, 39 insertions(+), 25 deletions(-) create mode 100644 .changeset/tame-ways-applaud.md diff --git a/.changeset/tame-ways-applaud.md b/.changeset/tame-ways-applaud.md new file mode 100644 index 000000000..510b1b121 --- /dev/null +++ b/.changeset/tame-ways-applaud.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +feat: use batching in vector store index diff --git a/packages/core/src/embeddings/HuggingFaceEmbedding.ts b/packages/core/src/embeddings/HuggingFaceEmbedding.ts index f27eb07ad..810bc14fb 100644 --- a/packages/core/src/embeddings/HuggingFaceEmbedding.ts +++ b/packages/core/src/embeddings/HuggingFaceEmbedding.ts @@ -36,7 +36,7 @@ export class HuggingFaceEmbedding extends BaseEmbedding { return this.extractor; } - async getTextEmbedding(text: string): Promise<number[]> { + override async getTextEmbedding(text: string): Promise<number[]> { const extractor = await this.getExtractor(); const output = await extractor(text, { pooling: "mean", normalize: true }); return Array.from(output.data); diff --git a/packages/core/src/embeddings/OpenAIEmbedding.ts b/packages/core/src/embeddings/OpenAIEmbedding.ts index 996a5b4c8..321c3da0c 100644 --- a/packages/core/src/embeddings/OpenAIEmbedding.ts +++ b/packages/core/src/embeddings/OpenAIEmbedding.ts @@ -102,28 +102,43 @@ export class OpenAIEmbedding extends BaseEmbedding { } } - private async getOpenAIEmbedding( - input: string | string[], - ): Promise<number[]> { + /** + * Get embeddings for a batch of texts + * @param texts + * @param options + */ + private async getOpenAIEmbedding(input: string[]): Promise<number[][]> { const { data } = await this.session.openai.embeddings.create({ model: this.model, dimensions: this.dimensions, // only sent to OpenAI if set by user input, }); - return data[0].embedding; + return data.map((d) => d.embedding); } + /** + * Get embeddings for a batch of texts + * @param texts + */ async getTextEmbeddings(texts: string[]): Promise<number[][]> { - const embeddings = await this.getOpenAIEmbedding(texts); - return Array(embeddings); + return await this.getOpenAIEmbedding(texts); } + /** + * Get embeddings for a single text + * @param texts + */ async getTextEmbedding(text: string): Promise<number[]> { - return this.getOpenAIEmbedding(text); + return (await this.getOpenAIEmbedding([text]))[0]; } + /** + * Get embeddings for a query + * @param texts + * @param options + */ async getQueryEmbedding(query: string): Promise<number[]> { - return this.getOpenAIEmbedding(query); + return (await this.getOpenAIEmbedding([query]))[0]; } } diff --git a/packages/core/src/embeddings/types.ts b/packages/core/src/embeddings/types.ts index f4a0a5675..1058e3ead 100644 --- a/packages/core/src/embeddings/types.ts +++ b/packages/core/src/embeddings/types.ts @@ -19,7 +19,7 @@ export abstract class BaseEmbedding implements TransformComponent { abstract getQueryEmbedding(query: string): Promise<number[]>; /** - * Get embeddings for a batch of texts + * Optionally override this method to retrieve multiple embeddings in a single request * @param texts */ async getTextEmbeddings(texts: string[]): Promise<Array<number[]>> { @@ -59,7 +59,7 @@ export abstract class BaseEmbedding implements TransformComponent { resultEmbeddings.push(...embeddings); if (options?.logProgress) { - console.log(`number[] progress: ${i} / ${queue.length}`); + console.log(`getting embedding progress: ${i} / ${queue.length}`); } curBatch.length = 0; diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts index bc1b6c230..b416126e7 100644 --- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts +++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts @@ -166,20 +166,14 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { nodes: BaseNode[], options?: { logProgress?: boolean }, ): Promise<BaseNode[]> { - const nodesWithEmbeddings: BaseNode[] = []; - - for (let i = 0; i < nodes.length; ++i) { - const node = nodes[i]; - if (options?.logProgress) { - console.log(`Getting embedding for node ${i + 1}/${nodes.length}`); - } - node.embedding = await this.embedModel.getTextEmbedding( - node.getContent(MetadataMode.EMBED), - ); - nodesWithEmbeddings.push(node); - } - - return nodesWithEmbeddings; + const texts = nodes.map((node) => node.getContent(MetadataMode.EMBED)); + const embeddings = await this.embedModel.getTextEmbeddingsBatch(texts, { + logProgress: options?.logProgress, + }); + return nodes.map((node, i) => { + node.embedding = embeddings[i]; + return node; + }); } /** -- GitLab