Skip to content
Snippets Groups Projects
Unverified Commit b6c15005 authored by Emanuel Ferreira's avatar Emanuel Ferreira Committed by GitHub
Browse files

feat(embedding): add batch embed size (#407)

parent d06a85bd
No related branches found
No related tags found
No related merge requests found
---
"llamaindex": patch
---
feat(embedBatchSize): add batching for embeddings
...@@ -59,7 +59,9 @@ export class OpenAIEmbedding extends BaseEmbedding { ...@@ -59,7 +59,9 @@ export class OpenAIEmbedding extends BaseEmbedding {
this.model = init?.model ?? "text-embedding-ada-002"; this.model = init?.model ?? "text-embedding-ada-002";
this.dimensions = init?.dimensions; // if no dimensions provided, will be undefined/not sent to OpenAI this.dimensions = init?.dimensions; // if no dimensions provided, will be undefined/not sent to OpenAI
this.embedBatchSize = init?.embedBatchSize ?? 10;
this.maxRetries = init?.maxRetries ?? 10; this.maxRetries = init?.maxRetries ?? 10;
this.timeout = init?.timeout ?? 60 * 1000; // Default is 60 seconds this.timeout = init?.timeout ?? 60 * 1000; // Default is 60 seconds
this.additionalSessionOptions = init?.additionalSessionOptions; this.additionalSessionOptions = init?.additionalSessionOptions;
...@@ -100,7 +102,9 @@ export class OpenAIEmbedding extends BaseEmbedding { ...@@ -100,7 +102,9 @@ export class OpenAIEmbedding extends BaseEmbedding {
} }
} }
private async getOpenAIEmbedding(input: string) { private async getOpenAIEmbedding(
input: string | string[],
): Promise<number[]> {
const { data } = await this.session.openai.embeddings.create({ const { data } = await this.session.openai.embeddings.create({
model: this.model, model: this.model,
dimensions: this.dimensions, // only sent to OpenAI if set by user dimensions: this.dimensions, // only sent to OpenAI if set by user
...@@ -110,6 +114,11 @@ export class OpenAIEmbedding extends BaseEmbedding { ...@@ -110,6 +114,11 @@ export class OpenAIEmbedding extends BaseEmbedding {
return data[0].embedding; return data[0].embedding;
} }
async getTextEmbeddings(texts: string[]): Promise<number[][]> {
const embeddings = await this.getOpenAIEmbedding(texts);
return Array(embeddings);
}
async getTextEmbedding(text: string): Promise<number[]> { async getTextEmbedding(text: string): Promise<number[]> {
return this.getOpenAIEmbedding(text); return this.getOpenAIEmbedding(text);
} }
......
...@@ -2,7 +2,11 @@ import { BaseNode, MetadataMode } from "../Node"; ...@@ -2,7 +2,11 @@ import { BaseNode, MetadataMode } from "../Node";
import { TransformComponent } from "../ingestion"; import { TransformComponent } from "../ingestion";
import { SimilarityType, similarity } from "./utils"; import { SimilarityType, similarity } from "./utils";
const DEFAULT_EMBED_BATCH_SIZE = 10;
export abstract class BaseEmbedding implements TransformComponent { export abstract class BaseEmbedding implements TransformComponent {
embedBatchSize = DEFAULT_EMBED_BATCH_SIZE;
similarity( similarity(
embedding1: number[], embedding1: number[],
embedding2: number[], embedding2: number[],
...@@ -14,12 +18,66 @@ export abstract class BaseEmbedding implements TransformComponent { ...@@ -14,12 +18,66 @@ export abstract class BaseEmbedding implements TransformComponent {
abstract getTextEmbedding(text: string): Promise<number[]>; abstract getTextEmbedding(text: string): Promise<number[]>;
abstract getQueryEmbedding(query: string): Promise<number[]>; abstract getQueryEmbedding(query: string): Promise<number[]>;
/**
* Get embeddings for a batch of texts
* @param texts
*/
async getTextEmbeddings(texts: string[]): Promise<Array<number[]>> {
const embeddings: number[][] = [];
for (const text of texts) {
const embedding = await this.getTextEmbedding(text);
embeddings.push(embedding);
}
return embeddings;
}
/**
* Get embeddings for a batch of texts
* @param texts
* @param options
*/
async getTextEmbeddingsBatch(
texts: string[],
options?: {
logProgress?: boolean;
},
): Promise<Array<number[]>> {
const resultEmbeddings: Array<number[]> = [];
const chunkSize = this.embedBatchSize;
const queue: string[] = texts;
const curBatch: string[] = [];
for (let i = 0; i < queue.length; i++) {
curBatch.push(queue[i]);
if (i == queue.length - 1 || curBatch.length == chunkSize) {
const embeddings = await this.getTextEmbeddings(curBatch);
resultEmbeddings.push(...embeddings);
if (options?.logProgress) {
console.log(`number[] progress: ${i} / ${queue.length}`);
}
curBatch.length = 0;
}
}
return resultEmbeddings;
}
async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> { async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> {
for (const node of nodes) { const texts = nodes.map((node) => node.getContent(MetadataMode.EMBED));
node.embedding = await this.getTextEmbedding(
node.getContent(MetadataMode.EMBED), const embeddings = await this.getTextEmbeddingsBatch(texts);
);
for (let i = 0; i < nodes.length; i++) {
nodes[i].embedding = embeddings[i];
} }
return nodes; return nodes;
} }
} }
...@@ -168,14 +168,14 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { ...@@ -168,14 +168,14 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
): Promise<BaseNode[]> { ): Promise<BaseNode[]> {
const nodesWithEmbeddings: BaseNode[] = []; const nodesWithEmbeddings: BaseNode[] = [];
const embeddingResults = await this.embedModel.getTextEmbeddingsBatch(
nodes.map((node) => node.getContent(MetadataMode.EMBED)),
options,
);
for (let i = 0; i < nodes.length; ++i) { for (let i = 0; i < nodes.length; ++i) {
const node = nodes[i]; const node = nodes[i];
if (options?.logProgress) { node.embedding = embeddingResults[i];
console.log(`Getting embedding for node ${i + 1}/${nodes.length}`);
}
node.embedding = await this.embedModel.getTextEmbedding(
node.getContent(MetadataMode.EMBED),
);
nodesWithEmbeddings.push(node); nodesWithEmbeddings.push(node);
} }
......
import { similarity, SimilarityType } from "../embeddings"; import { OpenAIEmbedding, similarity, SimilarityType } from "../embeddings";
import { mockEmbeddingModel } from "./utility/mockOpenAI";
// Mock the OpenAI getOpenAISession function during testing
jest.mock("../llm/open_ai", () => {
return {
getOpenAISession: jest.fn().mockImplementation(() => null),
};
});
describe("similarity", () => { describe("similarity", () => {
test("throws error on mismatched lengths", () => { test("throws error on mismatched lengths", () => {
...@@ -42,3 +50,32 @@ describe("similarity", () => { ...@@ -42,3 +50,32 @@ describe("similarity", () => {
); );
}); });
}); });
describe("[OpenAIEmbedding]", () => {
let embedModel: OpenAIEmbedding;
beforeAll(() => {
let openAIEmbedding = new OpenAIEmbedding();
mockEmbeddingModel(openAIEmbedding);
embedModel = openAIEmbedding;
});
test("getTextEmbedding", async () => {
const embedding = await embedModel.getTextEmbedding("hello");
expect(embedding.length).toEqual(6);
});
test("getTextEmbeddings", async () => {
const texts = ["hello", "world"];
const embeddings = await embedModel.getTextEmbeddings(texts);
expect(embeddings.length).toEqual(1);
});
test("getTextEmbeddingsBatch", async () => {
const texts = ["hello", "world"];
const embeddings = await embedModel.getTextEmbeddingsBatch(texts);
expect(embeddings.length).toEqual(1);
});
});
...@@ -90,6 +90,11 @@ export function mockEmbeddingModel(embedModel: OpenAIEmbedding) { ...@@ -90,6 +90,11 @@ export function mockEmbeddingModel(embedModel: OpenAIEmbedding) {
resolve([1, 0, 0, 0, 0, 0]); resolve([1, 0, 0, 0, 0, 0]);
}); });
}); });
jest.spyOn(embedModel, "getTextEmbeddings").mockImplementation(async (x) => {
return new Promise((resolve) => {
resolve([[1, 0, 0, 0, 0, 0]]);
});
});
jest.spyOn(embedModel, "getQueryEmbedding").mockImplementation(async (x) => { jest.spyOn(embedModel, "getQueryEmbedding").mockImplementation(async (x) => {
return new Promise((resolve) => { return new Promise((resolve) => {
resolve([0, 1, 0, 0, 0, 0]); resolve([0, 1, 0, 0, 0, 0]);
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment