Skip to content
Snippets Groups Projects
Unverified Commit dd054137 authored by Marcus Schiesser's avatar Marcus Schiesser Committed by GitHub
Browse files

feat: use batching in vector store index (#524)

parent cf3b7571
Branches
Tags
No related merge requests found
---
"llamaindex": patch
---
feat: use batching in vector store index
......@@ -36,7 +36,7 @@ export class HuggingFaceEmbedding extends BaseEmbedding {
return this.extractor;
}
async getTextEmbedding(text: string): Promise<number[]> {
override async getTextEmbedding(text: string): Promise<number[]> {
const extractor = await this.getExtractor();
const output = await extractor(text, { pooling: "mean", normalize: true });
return Array.from(output.data);
......
......@@ -102,28 +102,43 @@ export class OpenAIEmbedding extends BaseEmbedding {
}
}
private async getOpenAIEmbedding(
input: string | string[],
): Promise<number[]> {
/**
* Get embeddings for a batch of texts in a single request
* @param input the texts to embed
*/
private async getOpenAIEmbedding(input: string[]): Promise<number[][]> {
const { data } = await this.session.openai.embeddings.create({
model: this.model,
dimensions: this.dimensions, // only sent to OpenAI if set by user
input,
});
return data[0].embedding;
return data.map((d) => d.embedding);
}
/**
* Get embeddings for a batch of texts
* @param texts
*/
async getTextEmbeddings(texts: string[]): Promise<number[][]> {
const embeddings = await this.getOpenAIEmbedding(texts);
return Array(embeddings);
return await this.getOpenAIEmbedding(texts);
}
/**
* Get the embedding for a single text
* @param text
*/
async getTextEmbedding(text: string): Promise<number[]> {
return this.getOpenAIEmbedding(text);
return (await this.getOpenAIEmbedding([text]))[0];
}
/**
* Get the embedding for a query
* @param query
*/
async getQueryEmbedding(query: string): Promise<number[]> {
return this.getOpenAIEmbedding(query);
return (await this.getOpenAIEmbedding([query]))[0];
}
}
......@@ -19,7 +19,7 @@ export abstract class BaseEmbedding implements TransformComponent {
abstract getQueryEmbedding(query: string): Promise<number[]>;
/**
* Get embeddings for a batch of texts
* Optionally override this method to retrieve multiple embeddings in a single request
* @param texts
*/
async getTextEmbeddings(texts: string[]): Promise<Array<number[]>> {
......@@ -59,7 +59,7 @@ export abstract class BaseEmbedding implements TransformComponent {
resultEmbeddings.push(...embeddings);
if (options?.logProgress) {
console.log(`number[] progress: ${i} / ${queue.length}`);
console.log(`getting embedding progress: ${i} / ${queue.length}`);
}
curBatch.length = 0;
......
......@@ -166,20 +166,14 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
nodes: BaseNode[],
options?: { logProgress?: boolean },
): Promise<BaseNode[]> {
const nodesWithEmbeddings: BaseNode[] = [];
for (let i = 0; i < nodes.length; ++i) {
const node = nodes[i];
if (options?.logProgress) {
console.log(`Getting embedding for node ${i + 1}/${nodes.length}`);
}
node.embedding = await this.embedModel.getTextEmbedding(
node.getContent(MetadataMode.EMBED),
);
nodesWithEmbeddings.push(node);
}
return nodesWithEmbeddings;
const texts = nodes.map((node) => node.getContent(MetadataMode.EMBED));
const embeddings = await this.embedModel.getTextEmbeddingsBatch(texts, {
logProgress: options?.logProgress,
});
return nodes.map((node, i) => {
node.embedding = embeddings[i];
return node;
});
}
/**
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment