From 4aa2c226a9d1d938fcb06582d8fd591af042fb84 Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Mon, 20 Nov 2023 17:46:29 +0700
Subject: [PATCH] feat: add clip embedding to llamaindex

---
 apps/clip/.gitignore           |   1 +
 apps/clip/1.js                 |  33 -----------
 apps/clip/README.md            |  19 ++++++
 apps/clip/clip_test.js         |  34 +++++++++++
 apps/clip/package.json         |   8 +--
 packages/core/package.json     |   8 +++
 packages/core/src/Embedding.ts | 102 +++++++++++++++++++++++++++++++++
 pnpm-lock.yaml                 |  16 ++----
 8 files changed, 172 insertions(+), 49 deletions(-)
 create mode 100644 apps/clip/.gitignore
 delete mode 100644 apps/clip/1.js
 create mode 100644 apps/clip/README.md
 create mode 100644 apps/clip/clip_test.js

diff --git a/apps/clip/.gitignore b/apps/clip/.gitignore
new file mode 100644
index 000000000..cff8f16ca
--- /dev/null
+++ b/apps/clip/.gitignore
@@ -0,0 +1 @@
+test/data/**
diff --git a/apps/clip/1.js b/apps/clip/1.js
deleted file mode 100644
index d738e9c33..000000000
--- a/apps/clip/1.js
+++ /dev/null
@@ -1,33 +0,0 @@
-import {
-  AutoProcessor,
-  AutoTokenizer,
-  CLIPModel,
-  RawImage,
-} from "@xenova/transformers";
-
-async function main() {
-  // Load tokenizer, processor, and model
-  let tokenizer = await AutoTokenizer.from_pretrained(
-    "Xenova/clip-vit-base-patch32",
-  );
-  let processor = await AutoProcessor.from_pretrained(
-    "Xenova/clip-vit-base-patch32",
-  );
-  let model = await CLIPModel.from_pretrained("Xenova/clip-vit-base-patch32");
-
-  // Run tokenization
-  let texts = ["a photo of a car", "a photo of a football match"];
-  let text_inputs = tokenizer(texts, { padding: true, truncation: true });
-
-  // Read image and run processor
-  let image = await RawImage.read(
-    "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg",
-  );
-  let image_inputs = await processor(image);
-
-  // Run model with both text and pixel inputs
-  let output = await model({ ...text_inputs, ...image_inputs });
-  console.log(output);
-}
-
-main();
diff --git a/apps/clip/README.md b/apps/clip/README.md
new file mode 100644
index 000000000..085860cb6
--- /dev/null
+++ b/apps/clip/README.md
@@ -0,0 +1,19 @@
+# CLIP Embedding Example
+
+Uses the CLIP model to embed images and text.
+
+## Get started
+
+Make sure you have installed the local dev version of `llamaindex`; see [README.md](../../packages/core/README.md).
+
+Then install the dependencies:
+
+```
+pnpm install
+```
+
+Then run:
+
+```
+node clip_test.js
+```
diff --git a/apps/clip/clip_test.js b/apps/clip/clip_test.js
new file mode 100644
index 000000000..386eae08e
--- /dev/null
+++ b/apps/clip/clip_test.js
@@ -0,0 +1,34 @@
+/* eslint-disable turbo/no-undeclared-env-vars */
+import { ClipEmbedding, SimilarityType, similarity } from "llamaindex";
+
+async function main() {
+  const clip = new ClipEmbedding();
+
+  // Get text embeddings
+  const text1 = "a car";
+  const textEmbedding1 = await clip.getTextEmbedding(text1);
+  const text2 = "a football match";
+  const textEmbedding2 = await clip.getTextEmbedding(text2);
+
+  // Get image embedding
+  const image =
+    "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg";
+  const imageEmbedding = await clip.getImageEmbedding(image);
+
+  // Calculate similarities
+  const sim1 = similarity(
+    textEmbedding1,
+    imageEmbedding,
+    SimilarityType.COSINE,
+  );
+  const sim2 = similarity(
+    textEmbedding2,
+    imageEmbedding,
+    SimilarityType.COSINE,
+  );
+
+  console.log(`Similarity between "${text1}" and the image is ${sim1}`);
+  console.log(`Similarity between "${text2}" and the image is ${sim2}`);
+}
+
+main();
diff --git a/apps/clip/package.json b/apps/clip/package.json
index b98d22e9a..a4530e12d 100644
--- a/apps/clip/package.json
+++ b/apps/clip/package.json
@@ -1,16 +1,12 @@
 {
   "version": "0.0.1",
   "private": true,
-  "name": "cliptest",
+  "name": "clip-test",
   "type": "module",
   "dependencies": {
-    "@xenova/transformers": "^2.8.0",
+    "dotenv": "^16.3.1",
     "llamaindex": "workspace:*"
   },
-  "devDependencies": {
-    "@types/node": "^18.18.6",
-    "ts-node": "^10.9.1"
-  },
   "scripts": {
     "lint": "eslint ."
   }
 }
diff --git a/packages/core/package.json b/packages/core/package.json
index ef0dfbf8a..dcbec345b 100644
--- a/packages/core/package.json
+++ b/packages/core/package.json
@@ -5,6 +5,7 @@
   "dependencies": {
     "@anthropic-ai/sdk": "^0.9.0",
     "@notionhq/client": "^2.2.13",
+    "@xenova/transformers": "^2.8.0",
     "crypto-js": "^4.2.0",
     "js-tiktoken": "^1.0.7",
     "lodash": "^4.17.21",
@@ -45,5 +46,12 @@
     "test": "jest",
     "build": "tsup src/index.ts --format esm,cjs --dts",
     "dev": "tsup src/index.ts --format esm,cjs --dts --watch"
+  },
+  "exports": {
+    ".": {
+      "require": "./dist/index.js",
+      "import": "./dist/index.mjs",
+      "types": "./dist/index.d.ts"
+    }
   }
 }
\ No newline at end of file
diff --git a/packages/core/src/Embedding.ts b/packages/core/src/Embedding.ts
index 5723f888c..78db4b937 100644
--- a/packages/core/src/Embedding.ts
+++ b/packages/core/src/Embedding.ts
@@ -1,5 +1,13 @@
 import { ClientOptions as OpenAIClientOptions } from "openai";
 
+import {
+  AutoProcessor,
+  AutoTokenizer,
+  CLIPTextModelWithProjection,
+  CLIPVisionModelWithProjection,
+  RawImage,
+} from "@xenova/transformers";
+import _ from "lodash";
 import { DEFAULT_SIMILARITY_TOP_K } from "./constants";
 import {
   AzureOpenAIConfig,
@@ -296,3 +304,97 @@
     return this.getOpenAIEmbedding(query);
   }
 }
+
+export type ImageType = string | Blob | URL;
+
+async function readImage(input: ImageType) {
+  if (input instanceof Blob) {
+    return await RawImage.fromBlob(input);
+  } else if (_.isString(input) || input instanceof URL) {
+    return await RawImage.fromURL(input);
+  } else {
+    throw new Error(`Unsupported input type: ${typeof input}`);
+  }
+}
+
+/*
+ * Base class for multi-modal embeddings.
+ */
+abstract class MultiModalEmbedding extends BaseEmbedding {
+  abstract getImageEmbedding(images: ImageType): Promise<number[]>;
+
+  async getImageEmbeddings(images: ImageType[]): Promise<number[][]> {
+    // Embed the input sequence of images asynchronously.
+    return Promise.all(
+      images.map((imgFilePath) => this.getImageEmbedding(imgFilePath)),
+    );
+  }
+}
+
+enum ClipEmbeddingModelType {
+  XENOVA_CLIP_VIT_BASE_PATCH32 = "Xenova/clip-vit-base-patch32",
+  XENOVA_CLIP_VIT_BASE_PATCH16 = "Xenova/clip-vit-base-patch16",
+}
+
+export class ClipEmbedding extends MultiModalEmbedding {
+  modelType: ClipEmbeddingModelType =
+    ClipEmbeddingModelType.XENOVA_CLIP_VIT_BASE_PATCH16;
+
+  private tokenizer: any;
+  private processor: any;
+  private visionModel: any;
+  private textModel: any;
+
+  async getTokenizer() {
+    if (!this.tokenizer) {
+      this.tokenizer = await AutoTokenizer.from_pretrained(this.modelType);
+    }
+    return this.tokenizer;
+  }
+
+  async getProcessor() {
+    if (!this.processor) {
+      this.processor = await AutoProcessor.from_pretrained(this.modelType);
+    }
+    return this.processor;
+  }
+
+  async getVisionModel() {
+    if (!this.visionModel) {
+      this.visionModel = await CLIPVisionModelWithProjection.from_pretrained(
+        this.modelType,
+      );
+    }
+
+    return this.visionModel;
+  }
+
+  async getTextModel() {
+    if (!this.textModel) {
+      this.textModel = await CLIPTextModelWithProjection.from_pretrained(
+        this.modelType,
+      );
+    }
+
+    return this.textModel;
+  }
+
+  async getImageEmbedding(image: ImageType): Promise<number[]> {
+    const loadedImage = await readImage(image);
+    const imageInputs = await (await this.getProcessor())(loadedImage);
+    const { image_embeds } = await (await this.getVisionModel())(imageInputs);
+    return image_embeds.data;
+  }
+
+  async getTextEmbedding(text: string): Promise<number[]> {
+    const textInputs = await (
+      await this.getTokenizer()
+    )([text], { padding: true, truncation: true });
+    const { text_embeds } = await (await this.getTextModel())(textInputs);
+    return text_embeds.data;
+  }
+
+  async getQueryEmbedding(query: string): Promise<number[]> {
+    return this.getTextEmbedding(query);
+  }
+}
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index ff5488d39..95795125d 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -51,19 +51,12 @@ importers:
 
   apps/clip:
     dependencies:
-      '@xenova/transformers':
-        specifier: ^2.8.0
-        version: 2.8.0
+      dotenv:
+        specifier: ^16.3.1
+        version: 16.3.1
       llamaindex:
         specifier: workspace:*
         version: link:../../packages/core
-    devDependencies:
-      '@types/node':
-        specifier: ^18.18.6
-        version: 18.18.8
-      ts-node:
-        specifier: ^10.9.1
-        version: 10.9.1(@types/node@18.18.8)(typescript@5.2.2)
 
   apps/docs:
     dependencies:
@@ -169,6 +162,9 @@ importers:
       '@notionhq/client':
         specifier: ^2.2.13
         version: 2.2.13
+      '@xenova/transformers':
+        specifier: ^2.8.0
+        version: 2.8.0
       crypto-js:
         specifier: ^4.2.0
         version: 4.2.0
-- 
GitLab