From bb917f9818e514d70e1923972d48477987c2d298 Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Tue, 21 Nov 2023 14:20:10 +0700
Subject: [PATCH] refactor: moved embeddings to embeddings folder

---
 .vscode/settings.json                         |   5 +-
 packages/core/src/ServiceContext.ts           |   2 +-
 packages/core/src/embeddings/ClipEmbedding.ts |  78 +++++++
 .../src/embeddings/MultiModalEmbedding.ts     |  17 ++
 .../core/src/embeddings/OpenAIEmbedding.ts    |  92 ++++++++
 packages/core/src/embeddings/index.ts         |   5 +
 packages/core/src/embeddings/types.ts         |  24 ++
 .../src/{Embedding.ts => embeddings/utils.ts} | 219 +-----------------
 packages/core/src/index.ts                    |   2 +-
 .../storage/vectorStore/SimpleVectorStore.ts  |   4 +-
 .../core/src/tests/CallbackManager.test.ts    |   2 +-
 packages/core/src/tests/Embedding.test.ts     |   2 +-
 packages/core/src/tests/utility/mockOpenAI.ts |   2 +-
 13 files changed, 233 insertions(+), 221 deletions(-)
 create mode 100644 packages/core/src/embeddings/ClipEmbedding.ts
 create mode 100644 packages/core/src/embeddings/MultiModalEmbedding.ts
 create mode 100644 packages/core/src/embeddings/OpenAIEmbedding.ts
 create mode 100644 packages/core/src/embeddings/index.ts
 create mode 100644 packages/core/src/embeddings/types.ts
 rename packages/core/src/{Embedding.ts => embeddings/utils.ts} (50%)

diff --git a/.vscode/settings.json b/.vscode/settings.json
index d3a0c1169..9f6017380 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -4,5 +4,6 @@
   "editor.defaultFormatter": "esbenp.prettier-vscode",
   "[xml]": {
     "editor.defaultFormatter": "redhat.vscode-xml"
-  }
-}
+  },
+  "jest.rootPath": "./packages/core"
+}
\ No newline at end of file
diff --git a/packages/core/src/ServiceContext.ts b/packages/core/src/ServiceContext.ts
index efefd319b..da33be209 100644
--- a/packages/core/src/ServiceContext.ts
+++ b/packages/core/src/ServiceContext.ts
@@ -1,7 +1,7 @@
-import { BaseEmbedding, OpenAIEmbedding } from "./Embedding";
 import { NodeParser, SimpleNodeParser } from "./NodeParser";
 import { PromptHelper } from "./PromptHelper";
 import { CallbackManager } from "./callbacks/CallbackManager";
+import { BaseEmbedding, OpenAIEmbedding } from "./embeddings";
 import { LLM, OpenAI } from "./llm/LLM";
 
 /**
diff --git a/packages/core/src/embeddings/ClipEmbedding.ts b/packages/core/src/embeddings/ClipEmbedding.ts
new file mode 100644
index 000000000..b75b4b879
--- /dev/null
+++ b/packages/core/src/embeddings/ClipEmbedding.ts
@@ -0,0 +1,78 @@
+import { MultiModalEmbedding } from "./MultiModalEmbedding";
+import { ImageType, readImage } from "./utils";
+
+export enum ClipEmbeddingModelType {
+  XENOVA_CLIP_VIT_BASE_PATCH32 = "Xenova/clip-vit-base-patch32",
+  XENOVA_CLIP_VIT_BASE_PATCH16 = "Xenova/clip-vit-base-patch16",
+}
+
+export class ClipEmbedding extends MultiModalEmbedding {
+  modelType: ClipEmbeddingModelType =
+    ClipEmbeddingModelType.XENOVA_CLIP_VIT_BASE_PATCH16;
+
+  private tokenizer: any;
+  private processor: any;
+  private visionModel: any;
+  private textModel: any;
+
+  async getTokenizer() {
+    if (!this.tokenizer) {
+      const { AutoTokenizer } = await import("@xenova/transformers");
+      this.tokenizer = await AutoTokenizer.from_pretrained(this.modelType);
+    }
+    return this.tokenizer;
+  }
+
+  async getProcessor() {
+    if (!this.processor) {
+      const { AutoProcessor } = await import("@xenova/transformers");
+      this.processor = await AutoProcessor.from_pretrained(this.modelType);
+    }
+    return this.processor;
+  }
+
+  async getVisionModel() {
+    if (!this.visionModel) {
+      const { CLIPVisionModelWithProjection } = await import(
+        "@xenova/transformers"
+      );
+      this.visionModel = await CLIPVisionModelWithProjection.from_pretrained(
+        this.modelType,
+      );
+    }
+
+    return this.visionModel;
+  }
+
+  async getTextModel() {
+    if (!this.textModel) {
+      const { CLIPTextModelWithProjection } = await import(
+        "@xenova/transformers"
+      );
+      this.textModel = await CLIPTextModelWithProjection.from_pretrained(
+        this.modelType,
+      );
+    }
+
+    return this.textModel;
+  }
+
+  async getImageEmbedding(image: ImageType): Promise<number[]> {
+    const loadedImage = await readImage(image);
+    const imageInputs = await (await this.getProcessor())(loadedImage);
+    const { image_embeds } = await (await this.getVisionModel())(imageInputs);
+    return image_embeds.data;
+  }
+
+  async getTextEmbedding(text: string): Promise<number[]> {
+    const textInputs = await (
+      await this.getTokenizer()
+    )([text], { padding: true, truncation: true });
+    const { text_embeds } = await (await this.getTextModel())(textInputs);
+    return text_embeds.data;
+  }
+
+  async getQueryEmbedding(query: string): Promise<number[]> {
+    return this.getTextEmbedding(query);
+  }
+}
diff --git a/packages/core/src/embeddings/MultiModalEmbedding.ts b/packages/core/src/embeddings/MultiModalEmbedding.ts
new file mode 100644
index 000000000..c86ba0721
--- /dev/null
+++ b/packages/core/src/embeddings/MultiModalEmbedding.ts
@@ -0,0 +1,17 @@
+import { BaseEmbedding } from "./types";
+import { ImageType } from "./utils";
+
+/*
+ * Base class for Multi Modal embeddings.
+ */
+
+export abstract class MultiModalEmbedding extends BaseEmbedding {
+  abstract getImageEmbedding(images: ImageType): Promise<number[]>;
+
+  async getImageEmbeddings(images: ImageType[]): Promise<number[][]> {
+    // Embed the input sequence of images asynchronously.
+    return Promise.all(
+      images.map((imgFilePath) => this.getImageEmbedding(imgFilePath)),
+    );
+  }
+}
diff --git a/packages/core/src/embeddings/OpenAIEmbedding.ts b/packages/core/src/embeddings/OpenAIEmbedding.ts
new file mode 100644
index 000000000..106c6cbff
--- /dev/null
+++ b/packages/core/src/embeddings/OpenAIEmbedding.ts
@@ -0,0 +1,92 @@
+import { ClientOptions as OpenAIClientOptions } from "openai";
+import {
+  AzureOpenAIConfig,
+  getAzureBaseUrl,
+  getAzureConfigFromEnv,
+  getAzureModel,
+  shouldUseAzure,
+} from "../llm/azure";
+import { OpenAISession, getOpenAISession } from "../llm/openai";
+import { BaseEmbedding } from "./types";
+
+export enum OpenAIEmbeddingModelType {
+  TEXT_EMBED_ADA_002 = "text-embedding-ada-002",
+}
+
+export class OpenAIEmbedding extends BaseEmbedding {
+  model: OpenAIEmbeddingModelType;
+
+  // OpenAI session params
+  apiKey?: string = undefined;
+  maxRetries: number;
+  timeout?: number;
+  additionalSessionOptions?: Omit<
+    Partial<OpenAIClientOptions>,
+    "apiKey" | "maxRetries" | "timeout"
+  >;
+
+  session: OpenAISession;
+
+  constructor(init?: Partial<OpenAIEmbedding> & { azure?: AzureOpenAIConfig }) {
+    super();
+
+    this.model = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002;
+
+    this.maxRetries = init?.maxRetries ?? 10;
+    this.timeout = init?.timeout ?? 60 * 1000; // Default is 60 seconds
+    this.additionalSessionOptions = init?.additionalSessionOptions;
+
+    if (init?.azure || shouldUseAzure()) {
+      const azureConfig = getAzureConfigFromEnv({
+        ...init?.azure,
+        model: getAzureModel(this.model),
+      });
+
+      if (!azureConfig.apiKey) {
+        throw new Error(
+          "Azure API key is required for OpenAI Azure models. Please set the AZURE_OPENAI_KEY environment variable.",
+        );
+      }
+
+      this.apiKey = azureConfig.apiKey;
+      this.session =
+        init?.session ??
+        getOpenAISession({
+          azure: true,
+          apiKey: this.apiKey,
+          baseURL: getAzureBaseUrl(azureConfig),
+          maxRetries: this.maxRetries,
+          timeout: this.timeout,
+          defaultQuery: { "api-version": azureConfig.apiVersion },
+          ...this.additionalSessionOptions,
+        });
+    } else {
+      this.apiKey = init?.apiKey ?? undefined;
+      this.session =
+        init?.session ??
+        getOpenAISession({
+          apiKey: this.apiKey,
+          maxRetries: this.maxRetries,
+          timeout: this.timeout,
+          ...this.additionalSessionOptions,
+        });
+    }
+  }
+
+  private async getOpenAIEmbedding(input: string) {
+    const { data } = await this.session.openai.embeddings.create({
+      model: this.model,
+      input,
+    });
+
+    return data[0].embedding;
+  }
+
+  async getTextEmbedding(text: string): Promise<number[]> {
+    return this.getOpenAIEmbedding(text);
+  }
+
+  async getQueryEmbedding(query: string): Promise<number[]> {
+    return this.getOpenAIEmbedding(query);
+  }
+}
diff --git a/packages/core/src/embeddings/index.ts b/packages/core/src/embeddings/index.ts
new file mode 100644
index 000000000..1a6a4df04
--- /dev/null
+++ b/packages/core/src/embeddings/index.ts
@@ -0,0 +1,5 @@
+export * from "./ClipEmbedding";
+export * from "./MultiModalEmbedding";
+export * from "./OpenAIEmbedding";
+export * from "./types";
+export * from "./utils";
diff --git a/packages/core/src/embeddings/types.ts b/packages/core/src/embeddings/types.ts
new file mode 100644
index 000000000..e500f9452
--- /dev/null
+++ b/packages/core/src/embeddings/types.ts
@@ -0,0 +1,24 @@
+import { similarity } from "./utils";
+
+/**
+ * Similarity type
+ * Default is cosine similarity. Dot product and negative Euclidean distance are also supported.
+ */
+export enum SimilarityType {
+  DEFAULT = "cosine",
+  DOT_PRODUCT = "dot_product",
+  EUCLIDEAN = "euclidean",
+}
+
+export abstract class BaseEmbedding {
+  similarity(
+    embedding1: number[],
+    embedding2: number[],
+    mode: SimilarityType = SimilarityType.DEFAULT,
+  ): number {
+    return similarity(embedding1, embedding2, mode);
+  }
+
+  abstract getTextEmbedding(text: string): Promise<number[]>;
+  abstract getQueryEmbedding(query: string): Promise<number[]>;
+}
diff --git a/packages/core/src/Embedding.ts b/packages/core/src/embeddings/utils.ts
similarity index 50%
rename from packages/core/src/Embedding.ts
rename to packages/core/src/embeddings/utils.ts
index ad6650251..cd192c3d4 100644
--- a/packages/core/src/Embedding.ts
+++ b/packages/core/src/embeddings/utils.ts
@@ -1,33 +1,16 @@
 import _ from "lodash";
-import { ClientOptions as OpenAIClientOptions } from "openai";
-import { DEFAULT_SIMILARITY_TOP_K } from "./constants";
-import {
-  AzureOpenAIConfig,
-  getAzureBaseUrl,
-  getAzureConfigFromEnv,
-  getAzureModel,
-  shouldUseAzure,
-} from "./llm/azure";
-import { OpenAISession, getOpenAISession } from "./llm/openai";
-import { VectorStoreQueryMode } from "./storage/vectorStore/types";
-
-/**
- * Similarity type
- * Default is cosine similarity. Dot product and negative Euclidean distance are also supported.
- */
-export enum SimilarityType {
-  DEFAULT = "cosine",
-  DOT_PRODUCT = "dot_product",
-  EUCLIDEAN = "euclidean",
-}
+import { DEFAULT_SIMILARITY_TOP_K } from "../constants";
+import { VectorStoreQueryMode } from "../storage";
+import { SimilarityType } from "./types";
 
 /**
  * The similarity between two embeddings.
* @param embedding1 * @param embedding2 * @param mode - * @returns similartiy score with higher numbers meaning the two embeddings are more similar + * @returns similarity score with higher numbers meaning the two embeddings are more similar */ + export function similarity( embedding1: number[], embedding2: number[], @@ -42,7 +25,6 @@ export function similarity( // will probably cause some avoidable loss of floating point precision // ml-distance is worth watching although they currently also use the naive // formulas - function norm(x: number[]): number { let result = 0; for (let i = 0; i < x.length; i++) { @@ -201,105 +183,7 @@ export function getTopKMMREmbeddings( return [resultSimilarities, resultIds]; } - -export abstract class BaseEmbedding { - similarity( - embedding1: number[], - embedding2: number[], - mode: SimilarityType = SimilarityType.DEFAULT, - ): number { - return similarity(embedding1, embedding2, mode); - } - - abstract getTextEmbedding(text: string): Promise<number[]>; - abstract getQueryEmbedding(query: string): Promise<number[]>; -} - -enum OpenAIEmbeddingModelType { - TEXT_EMBED_ADA_002 = "text-embedding-ada-002", -} - -export class OpenAIEmbedding extends BaseEmbedding { - model: OpenAIEmbeddingModelType; - - // OpenAI session params - apiKey?: string = undefined; - maxRetries: number; - timeout?: number; - additionalSessionOptions?: Omit< - Partial<OpenAIClientOptions>, - "apiKey" | "maxRetries" | "timeout" - >; - - session: OpenAISession; - - constructor(init?: Partial<OpenAIEmbedding> & { azure?: AzureOpenAIConfig }) { - super(); - - this.model = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002; - - this.maxRetries = init?.maxRetries ?? 10; - this.timeout = init?.timeout ?? 60 * 1000; // Default is 60 seconds - this.additionalSessionOptions = init?.additionalSessionOptions; - - if (init?.azure || shouldUseAzure()) { - const azureConfig = getAzureConfigFromEnv({ - ...init?.azure, - model: getAzureModel(this.model), - }); - - if (!azureConfig.apiKey) { - throw new Error( - "Azure API key is required for OpenAI Azure models. Please set the AZURE_OPENAI_KEY environment variable.", - ); - } - - this.apiKey = azureConfig.apiKey; - this.session = - init?.session ?? - getOpenAISession({ - azure: true, - apiKey: this.apiKey, - baseURL: getAzureBaseUrl(azureConfig), - maxRetries: this.maxRetries, - timeout: this.timeout, - defaultQuery: { "api-version": azureConfig.apiVersion }, - ...this.additionalSessionOptions, - }); - } else { - this.apiKey = init?.apiKey ?? undefined; - this.session = - init?.session ?? 
-        getOpenAISession({
-          apiKey: this.apiKey,
-          maxRetries: this.maxRetries,
-          timeout: this.timeout,
-          ...this.additionalSessionOptions,
-        });
-    }
-  }
-
-  private async getOpenAIEmbedding(input: string) {
-    const { data } = await this.session.openai.embeddings.create({
-      model: this.model,
-      input,
-    });
-
-    return data[0].embedding;
-  }
-
-  async getTextEmbedding(text: string): Promise<number[]> {
-    return this.getOpenAIEmbedding(text);
-  }
-
-  async getQueryEmbedding(query: string): Promise<number[]> {
-    return this.getOpenAIEmbedding(query);
-  }
-}
-
-export type ImageType = string | Blob | URL;
-
-async function readImage(input: ImageType) {
+export async function readImage(input: ImageType) {
   const { RawImage } = await import("@xenova/transformers");
   if (input instanceof Blob) {
     return await RawImage.fromBlob(input);
@@ -309,93 +193,4 @@ async function readImage(input: ImageType) {
     throw new Error(`Unsupported input type: ${typeof input}`);
   }
 }
-
-/*
- * Base class for Multi Modal embeddings.
- */
-export abstract class MultiModalEmbedding extends BaseEmbedding {
-  abstract getImageEmbedding(images: ImageType): Promise<number[]>;
-
-  async getImageEmbeddings(images: ImageType[]): Promise<number[][]> {
-    // Embed the input sequence of images asynchronously.
-    return Promise.all(
-      images.map((imgFilePath) => this.getImageEmbedding(imgFilePath)),
-    );
-  }
-}
-
-enum ClipEmbeddingModelType {
-  XENOVA_CLIP_VIT_BASE_PATCH32 = "Xenova/clip-vit-base-patch32",
-  XENOVA_CLIP_VIT_BASE_PATCH16 = "Xenova/clip-vit-base-patch16",
-}
-
-export class ClipEmbedding extends MultiModalEmbedding {
-  modelType: ClipEmbeddingModelType =
-    ClipEmbeddingModelType.XENOVA_CLIP_VIT_BASE_PATCH16;
-
-  private tokenizer: any;
-  private processor: any;
-  private visionModel: any;
-  private textModel: any;
-
-  async getTokenizer() {
-    if (!this.tokenizer) {
-      const { AutoTokenizer } = await import("@xenova/transformers");
-      this.tokenizer = await AutoTokenizer.from_pretrained(this.modelType);
-    }
-    return this.tokenizer;
-  }
-
-  async getProcessor() {
-    if (!this.processor) {
-      const { AutoProcessor } = await import("@xenova/transformers");
-      this.processor = await AutoProcessor.from_pretrained(this.modelType);
-    }
-    return this.processor;
-  }
-
-  async getVisionModel() {
-    if (!this.visionModel) {
-      const { CLIPVisionModelWithProjection } = await import(
-        "@xenova/transformers"
-      );
-      this.visionModel = await CLIPVisionModelWithProjection.from_pretrained(
-        this.modelType,
-      );
-    }
-
-    return this.visionModel;
-  }
-
-  async getTextModel() {
-    if (!this.textModel) {
-      const { CLIPTextModelWithProjection } = await import(
-        "@xenova/transformers"
-      );
-      this.textModel = await CLIPTextModelWithProjection.from_pretrained(
-        this.modelType,
-      );
-    }
-
-    return this.textModel;
-  }
-
-  async getImageEmbedding(image: ImageType): Promise<number[]> {
-    const loadedImage = await readImage(image);
-    const imageInputs = await (await this.getProcessor())(loadedImage);
-    const { image_embeds } = await (await this.getVisionModel())(imageInputs);
-    return image_embeds.data;
-  }
-
-  async getTextEmbedding(text: string): Promise<number[]> {
-    const textInputs = await (
-      await this.getTokenizer()
-    )([text], { padding: true, truncation: true });
-    const { text_embeds } = await (await this.getTextModel())(textInputs);
-    return text_embeds.data;
-  }
-
-  async getQueryEmbedding(query: string): Promise<number[]> {
-    return this.getTextEmbedding(query);
-  }
-}
+export type ImageType = string | Blob | URL;
diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts
index 20ab46297..dde8fff26 100644
--- a/packages/core/src/index.ts
+++ b/packages/core/src/index.ts
@@ -1,6 +1,5 @@
 export * from "./ChatEngine";
 export * from "./ChatHistory";
-export * from "./Embedding";
 export * from "./GlobalsHelper";
 export * from "./Node";
 export * from "./NodeParser";
@@ -17,6 +16,7 @@ export * from "./TextSplitter";
 export * from "./Tool";
 export * from "./callbacks/CallbackManager";
 export * from "./constants";
+export * from "./embeddings";
 export * from "./indices";
 export * from "./llm/LLM";
 export * from "./readers/CSVReader";
diff --git a/packages/core/src/storage/vectorStore/SimpleVectorStore.ts b/packages/core/src/storage/vectorStore/SimpleVectorStore.ts
index 1bccf12ff..929ebe2c2 100644
--- a/packages/core/src/storage/vectorStore/SimpleVectorStore.ts
+++ b/packages/core/src/storage/vectorStore/SimpleVectorStore.ts
@@ -1,11 +1,11 @@
 import _ from "lodash";
 import * as path from "path";
+import { BaseNode } from "../../Node";
 import {
   getTopKEmbeddings,
   getTopKEmbeddingsLearner,
   getTopKMMREmbeddings,
-} from "../../Embedding";
-import { BaseNode } from "../../Node";
+} from "../../embeddings";
 import { GenericFileSystem, exists } from "../FileSystem";
 import { DEFAULT_FS, DEFAULT_PERSIST_DIR } from "../constants";
 import {
diff --git a/packages/core/src/tests/CallbackManager.test.ts b/packages/core/src/tests/CallbackManager.test.ts
index c3d9a98d4..9374c50b6 100644
--- a/packages/core/src/tests/CallbackManager.test.ts
+++ b/packages/core/src/tests/CallbackManager.test.ts
@@ -1,4 +1,3 @@
-import { OpenAIEmbedding } from "../Embedding";
 import { Document } from "../Node";
 import {
   ResponseSynthesizer,
@@ -10,6 +9,7 @@ import {
   RetrievalCallbackResponse,
   StreamCallbackResponse,
 } from "../callbacks/CallbackManager";
+import { OpenAIEmbedding } from "../embeddings";
 import { SummaryIndex } from "../indices/summary";
 import { VectorStoreIndex } from "../indices/vectorStore/VectorStoreIndex";
 import { OpenAI } from "../llm/LLM";
diff --git a/packages/core/src/tests/Embedding.test.ts b/packages/core/src/tests/Embedding.test.ts
index 492a48be1..adc70810f 100644
--- a/packages/core/src/tests/Embedding.test.ts
+++ b/packages/core/src/tests/Embedding.test.ts
@@ -1,4 +1,4 @@
-import { SimilarityType, similarity } from "../Embedding";
+import { SimilarityType, similarity } from "../embeddings";
 
 describe("similarity", () => {
   test("throws error on mismatched lengths", () => {
diff --git a/packages/core/src/tests/utility/mockOpenAI.ts b/packages/core/src/tests/utility/mockOpenAI.ts
index 04d6cdd40..5f8a0ad4c 100644
--- a/packages/core/src/tests/utility/mockOpenAI.ts
+++ b/packages/core/src/tests/utility/mockOpenAI.ts
@@ -1,6 +1,6 @@
-import { OpenAIEmbedding } from "../../Embedding";
 import { globalsHelper } from "../../GlobalsHelper";
 import { CallbackManager, Event } from "../../callbacks/CallbackManager";
+import { OpenAIEmbedding } from "../../embeddings";
 import { ChatMessage, OpenAI } from "../../llm/LLM";
 
 export function mockLlmGeneration({
-- 
GitLab
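
Note for reviewers: because everything under packages/core/src/embeddings is re-exported through packages/core/src/index.ts, the public import surface should be unchanged by this move. Below is a minimal usage sketch of the relocated exports, not a definitive example: it assumes the published package name "llamaindex", an OPENAI_API_KEY set in the environment, and a placeholder image URL.

import {
  ClipEmbedding,
  OpenAIEmbedding,
  SimilarityType,
  similarity,
} from "llamaindex";

async function main() {
  // Text embeddings: OpenAIEmbedding now lives in
  // packages/core/src/embeddings/OpenAIEmbedding.ts (formerly Embedding.ts).
  const embedModel = new OpenAIEmbedding();
  const docVec = await embedModel.getTextEmbedding("a photo of a cat");
  const queryVec = await embedModel.getQueryEmbedding("cat picture");

  // similarity() moved to embeddings/utils.ts and SimilarityType to
  // embeddings/types.ts; both are re-exported via embeddings/index.ts.
  console.log(similarity(docVec, queryVec, SimilarityType.DEFAULT));

  // Image embeddings: the CLIP adapter loads @xenova/transformers lazily
  // on first use. The URL below is a placeholder, not a real asset.
  const clip = new ClipEmbedding();
  const imageVec = await clip.getImageEmbedding(
    new URL("https://example.com/cat.jpg"),
  );
  console.log(imageVec.length);
}

main().catch(console.error);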