diff --git a/.changeset/tame-ways-applaud.md b/.changeset/tame-ways-applaud.md new file mode 100644 index 0000000000000000000000000000000000000000..510b1b121e3fcf0b0b78bea4c151f3d10276e69d --- /dev/null +++ b/.changeset/tame-ways-applaud.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +feat: use batching in vector store index diff --git a/packages/core/src/embeddings/HuggingFaceEmbedding.ts b/packages/core/src/embeddings/HuggingFaceEmbedding.ts index f27eb07ad95c8473036c2120a20ab59c3b500c54..810bc14fbcedb66b0b039704534f417bfa531884 100644 --- a/packages/core/src/embeddings/HuggingFaceEmbedding.ts +++ b/packages/core/src/embeddings/HuggingFaceEmbedding.ts @@ -36,7 +36,7 @@ export class HuggingFaceEmbedding extends BaseEmbedding { return this.extractor; } - async getTextEmbedding(text: string): Promise<number[]> { + override async getTextEmbedding(text: string): Promise<number[]> { const extractor = await this.getExtractor(); const output = await extractor(text, { pooling: "mean", normalize: true }); return Array.from(output.data); diff --git a/packages/core/src/embeddings/OpenAIEmbedding.ts b/packages/core/src/embeddings/OpenAIEmbedding.ts index 996a5b4c8e3cc2799f64aa3eaf10fe450dbfce70..321c3da0c93aeeeb08226c52b5794660b4fa545d 100644 --- a/packages/core/src/embeddings/OpenAIEmbedding.ts +++ b/packages/core/src/embeddings/OpenAIEmbedding.ts @@ -102,28 +102,43 @@ export class OpenAIEmbedding extends BaseEmbedding { } } - private async getOpenAIEmbedding( - input: string | string[], - ): Promise<number[]> { + /** + * Get embeddings for a batch of texts + * @param texts + * @param options + */ + private async getOpenAIEmbedding(input: string[]): Promise<number[][]> { const { data } = await this.session.openai.embeddings.create({ model: this.model, dimensions: this.dimensions, // only sent to OpenAI if set by user input, }); - return data[0].embedding; + return data.map((d) => d.embedding); } + /** + * Get embeddings for a batch of texts + * @param texts + */ async getTextEmbeddings(texts: string[]): Promise<number[][]> { - const embeddings = await this.getOpenAIEmbedding(texts); - return Array(embeddings); + return await this.getOpenAIEmbedding(texts); } + /** + * Get embeddings for a single text + * @param texts + */ async getTextEmbedding(text: string): Promise<number[]> { - return this.getOpenAIEmbedding(text); + return (await this.getOpenAIEmbedding([text]))[0]; } + /** + * Get embeddings for a query + * @param texts + * @param options + */ async getQueryEmbedding(query: string): Promise<number[]> { - return this.getOpenAIEmbedding(query); + return (await this.getOpenAIEmbedding([query]))[0]; } } diff --git a/packages/core/src/embeddings/types.ts b/packages/core/src/embeddings/types.ts index f4a0a5675f8ff580cb0d15f891f04b192c1e1b46..1058e3ead02758c2289820349b52f1c35a921d7f 100644 --- a/packages/core/src/embeddings/types.ts +++ b/packages/core/src/embeddings/types.ts @@ -19,7 +19,7 @@ export abstract class BaseEmbedding implements TransformComponent { abstract getQueryEmbedding(query: string): Promise<number[]>; /** - * Get embeddings for a batch of texts + * Optionally override this method to retrieve multiple embeddings in a single request * @param texts */ async getTextEmbeddings(texts: string[]): Promise<Array<number[]>> { @@ -59,7 +59,7 @@ export abstract class BaseEmbedding implements TransformComponent { resultEmbeddings.push(...embeddings); if (options?.logProgress) { - console.log(`number[] progress: ${i} / ${queue.length}`); + console.log(`getting embedding progress: ${i} / ${queue.length}`); } curBatch.length = 0; diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts index bc1b6c230d800cfc586781eee4172b783474be5b..b416126e73ab04e0dbcff65f4d629391da8a27bf 100644 --- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts +++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts @@ -166,20 +166,14 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { nodes: BaseNode[], options?: { logProgress?: boolean }, ): Promise<BaseNode[]> { - const nodesWithEmbeddings: BaseNode[] = []; - - for (let i = 0; i < nodes.length; ++i) { - const node = nodes[i]; - if (options?.logProgress) { - console.log(`Getting embedding for node ${i + 1}/${nodes.length}`); - } - node.embedding = await this.embedModel.getTextEmbedding( - node.getContent(MetadataMode.EMBED), - ); - nodesWithEmbeddings.push(node); - } - - return nodesWithEmbeddings; + const texts = nodes.map((node) => node.getContent(MetadataMode.EMBED)); + const embeddings = await this.embedModel.getTextEmbeddingsBatch(texts, { + logProgress: options?.logProgress, + }); + return nodes.map((node, i) => { + node.embedding = embeddings[i]; + return node; + }); } /**