Unverified commit b6c15005, authored by Emanuel Ferreira, committed by GitHub

feat(embedding): add batch embed size (#407)

parent d06a85bd
---
"llamaindex": patch
---
feat(embedBatchSize): add batching for embeddings
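
For context, a minimal usage sketch of the new option (hypothetical values; assumes the top-level OpenAIEmbedding export, with the embedBatchSize init field added in this commit):

    import { OpenAIEmbedding } from "llamaindex";

    // embedBatchSize caps how many texts are embedded per request batch;
    // it falls back to 10 when not set (see the constructor change below).
    const embedModel = new OpenAIEmbedding({ embedBatchSize: 50 });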
@@ -59,7 +59,9 @@ export class OpenAIEmbedding extends BaseEmbedding {
     this.model = init?.model ?? "text-embedding-ada-002";
     this.dimensions = init?.dimensions; // if no dimensions provided, will be undefined/not sent to OpenAI
+    this.embedBatchSize = init?.embedBatchSize ?? 10;
     this.maxRetries = init?.maxRetries ?? 10;
     this.timeout = init?.timeout ?? 60 * 1000; // Default is 60 seconds
     this.additionalSessionOptions = init?.additionalSessionOptions;
@@ -100,7 +102,9 @@ export class OpenAIEmbedding extends BaseEmbedding {
     }
   }

-  private async getOpenAIEmbedding(input: string) {
+  private async getOpenAIEmbedding(
+    input: string | string[],
+  ): Promise<number[]> {
     const { data } = await this.session.openai.embeddings.create({
       model: this.model,
       dimensions: this.dimensions, // only sent to OpenAI if set by user
@@ -110,6 +114,11 @@ export class OpenAIEmbedding extends BaseEmbedding {
     return data[0].embedding;
   }

+  async getTextEmbeddings(texts: string[]): Promise<number[][]> {
+    const embeddings = await this.getOpenAIEmbedding(texts);
+    return Array(embeddings);
+  }
+
   async getTextEmbedding(text: string): Promise<number[]> {
     return this.getOpenAIEmbedding(text);
   }
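
The widened input: string | string[] parameter above leans on the fact that the OpenAI embeddings endpoint itself accepts either a single string or an array of strings, so a whole batch can go out in one request. A standalone sketch using the openai SDK directly (illustrative only, not part of this commit):

    import OpenAI from "openai";

    const openai = new OpenAI(); // reads OPENAI_API_KEY from the environment
    const { data } = await openai.embeddings.create({
      model: "text-embedding-ada-002",
      input: ["hello", "world"], // one request; data holds one embedding object per input
    });
    console.log(data.length); // 2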
@@ -2,7 +2,11 @@ import { BaseNode, MetadataMode } from "../Node";
 import { TransformComponent } from "../ingestion";
 import { SimilarityType, similarity } from "./utils";

+const DEFAULT_EMBED_BATCH_SIZE = 10;
+
 export abstract class BaseEmbedding implements TransformComponent {
+  embedBatchSize = DEFAULT_EMBED_BATCH_SIZE;
+
   similarity(
     embedding1: number[],
     embedding2: number[],
@@ -14,12 +18,66 @@ export abstract class BaseEmbedding implements TransformComponent {
   abstract getTextEmbedding(text: string): Promise<number[]>;
   abstract getQueryEmbedding(query: string): Promise<number[]>;

+  /**
+   * Get embeddings for a batch of texts
+   * @param texts
+   */
+  async getTextEmbeddings(texts: string[]): Promise<Array<number[]>> {
+    const embeddings: number[][] = [];
+
+    for (const text of texts) {
+      const embedding = await this.getTextEmbedding(text);
+      embeddings.push(embedding);
+    }
+
+    return embeddings;
+  }
+
+  /**
+   * Get embeddings for a batch of texts
+   * @param texts
+   * @param options
+   */
+  async getTextEmbeddingsBatch(
+    texts: string[],
+    options?: {
+      logProgress?: boolean;
+    },
+  ): Promise<Array<number[]>> {
+    const resultEmbeddings: Array<number[]> = [];
+
+    const chunkSize = this.embedBatchSize;
+    const queue: string[] = texts;
+
+    const curBatch: string[] = [];
+
+    for (let i = 0; i < queue.length; i++) {
+      curBatch.push(queue[i]);
+      if (i == queue.length - 1 || curBatch.length == chunkSize) {
+        const embeddings = await this.getTextEmbeddings(curBatch);
+
+        resultEmbeddings.push(...embeddings);
+
+        if (options?.logProgress) {
+          console.log(`number[] progress: ${i} / ${queue.length}`);
+        }
+
+        curBatch.length = 0;
+      }
+    }
+
+    return resultEmbeddings;
+  }
+
   async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> {
-    for (const node of nodes) {
-      node.embedding = await this.getTextEmbedding(
-        node.getContent(MetadataMode.EMBED),
-      );
-    }
+    const texts = nodes.map((node) => node.getContent(MetadataMode.EMBED));
+    const embeddings = await this.getTextEmbeddingsBatch(texts);
+
+    for (let i = 0; i < nodes.length; i++) {
+      nodes[i].embedding = embeddings[i];
+    }
+
     return nodes;
   }
 }
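
A sketch of calling the new batch helper directly on a concrete embedding model (hypothetical texts and batch size; the option name matches the signature added above):

    const embedModel = new OpenAIEmbedding({ embedBatchSize: 2 });

    // Texts are flushed to getTextEmbeddings() in chunks of embedBatchSize;
    // logProgress prints one line each time a batch is flushed.
    const vectors = await embedModel.getTextEmbeddingsBatch(
      ["first text", "second text", "third text"],
      { logProgress: true },
    );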
@@ -168,14 +168,14 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
   ): Promise<BaseNode[]> {
     const nodesWithEmbeddings: BaseNode[] = [];
+    const embeddingResults = await this.embedModel.getTextEmbeddingsBatch(
+      nodes.map((node) => node.getContent(MetadataMode.EMBED)),
+      options,
+    );

     for (let i = 0; i < nodes.length; ++i) {
       const node = nodes[i];
-      if (options?.logProgress) {
-        console.log(`Getting embedding for node ${i + 1}/${nodes.length}`);
-      }
-      node.embedding = await this.embedModel.getTextEmbedding(
-        node.getContent(MetadataMode.EMBED),
-      );
+      node.embedding = embeddingResults[i];
       nodesWithEmbeddings.push(node);
-import { similarity, SimilarityType } from "../embeddings";
+import { OpenAIEmbedding, similarity, SimilarityType } from "../embeddings";
+import { mockEmbeddingModel } from "./utility/mockOpenAI";
+
+// Mock the OpenAI getOpenAISession function during testing
+jest.mock("../llm/open_ai", () => {
+  return {
+    getOpenAISession: jest.fn().mockImplementation(() => null),
+  };
+});

 describe("similarity", () => {
   test("throws error on mismatched lengths", () => {
@@ -42,3 +50,32 @@ describe("similarity", () => {
     );
   });
 });
+
+describe("[OpenAIEmbedding]", () => {
+  let embedModel: OpenAIEmbedding;
+
+  beforeAll(() => {
+    let openAIEmbedding = new OpenAIEmbedding();
+
+    mockEmbeddingModel(openAIEmbedding);
+
+    embedModel = openAIEmbedding;
+  });
+
+  test("getTextEmbedding", async () => {
+    const embedding = await embedModel.getTextEmbedding("hello");
+    expect(embedding.length).toEqual(6);
+  });
+
+  test("getTextEmbeddings", async () => {
+    const texts = ["hello", "world"];
+    const embeddings = await embedModel.getTextEmbeddings(texts);
+    expect(embeddings.length).toEqual(1);
+  });
+
+  test("getTextEmbeddingsBatch", async () => {
+    const texts = ["hello", "world"];
+    const embeddings = await embedModel.getTextEmbeddingsBatch(texts);
+    expect(embeddings.length).toEqual(1);
+  });
+});
@@ -90,6 +90,11 @@ export function mockEmbeddingModel(embedModel: OpenAIEmbedding) {
       resolve([1, 0, 0, 0, 0, 0]);
     });
   });
+  jest.spyOn(embedModel, "getTextEmbeddings").mockImplementation(async (x) => {
+    return new Promise((resolve) => {
+      resolve([[1, 0, 0, 0, 0, 0]]);
+    });
+  });
   jest.spyOn(embedModel, "getQueryEmbedding").mockImplementation(async (x) => {
     return new Promise((resolve) => {
       resolve([0, 1, 0, 0, 0, 0]);