From 4aa2c226a9d1d938fcb06582d8fd591af042fb84 Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Mon, 20 Nov 2023 17:46:29 +0700
Subject: [PATCH] feat: add clip embedding to llamaindex

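Add a ClipEmbedding class (built on a new MultiModalEmbedding base class)
to packages/core/src/Embedding.ts. It lazily loads the Xenova CLIP
tokenizer, processor, and text/vision projection models from
@xenova/transformers and exposes getTextEmbedding, getImageEmbedding, and
getQueryEmbedding. The apps/clip example now uses this class through the
llamaindex package instead of calling transformers.js directly.

Minimal usage sketch, mirroring apps/clip/clip_test.js (assumes a Node ESM
context so top-level await is available; similarity and SimilarityType are
the existing llamaindex exports):

    import { ClipEmbedding, SimilarityType, similarity } from "llamaindex";

    // Embed a text query and an image, then compare them with cosine similarity.
    const clip = new ClipEmbedding();
    const textEmbedding = await clip.getTextEmbedding("a football match");
    const imageEmbedding = await clip.getImageEmbedding(
      "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg",
    );
    console.log(similarity(textEmbedding, imageEmbedding, SimilarityType.COSINE));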
---
 apps/clip/.gitignore           |   1 +
 apps/clip/1.js                 |  33 -----------
 apps/clip/README.md            |  19 ++++++
 apps/clip/clip_test.js         |  34 +++++++++++
 apps/clip/package.json         |   8 +--
 packages/core/package.json     |   8 +++
 packages/core/src/Embedding.ts | 102 +++++++++++++++++++++++++++++++++
 pnpm-lock.yaml                 |  16 ++----
 8 files changed, 172 insertions(+), 49 deletions(-)
 create mode 100644 apps/clip/.gitignore
 delete mode 100644 apps/clip/1.js
 create mode 100644 apps/clip/README.md
 create mode 100644 apps/clip/clip_test.js

diff --git a/apps/clip/.gitignore b/apps/clip/.gitignore
new file mode 100644
index 000000000..cff8f16ca
--- /dev/null
+++ b/apps/clip/.gitignore
@@ -0,0 +1 @@
+test/data/**
diff --git a/apps/clip/1.js b/apps/clip/1.js
deleted file mode 100644
index d738e9c33..000000000
--- a/apps/clip/1.js
+++ /dev/null
@@ -1,33 +0,0 @@
-import {
-  AutoProcessor,
-  AutoTokenizer,
-  CLIPModel,
-  RawImage,
-} from "@xenova/transformers";
-
-async function main() {
-  // Load tokenizer, processor, and model
-  let tokenizer = await AutoTokenizer.from_pretrained(
-    "Xenova/clip-vit-base-patch32",
-  );
-  let processor = await AutoProcessor.from_pretrained(
-    "Xenova/clip-vit-base-patch32",
-  );
-  let model = await CLIPModel.from_pretrained("Xenova/clip-vit-base-patch32");
-
-  // Run tokenization
-  let texts = ["a photo of a car", "a photo of a football match"];
-  let text_inputs = tokenizer(texts, { padding: true, truncation: true });
-
-  // Read image and run processor
-  let image = await RawImage.read(
-    "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg",
-  );
-  let image_inputs = await processor(image);
-
-  // Run model with both text and pixel inputs
-  let output = await model({ ...text_inputs, ...image_inputs });
-  console.log(output);
-}
-
-main();
diff --git a/apps/clip/README.md b/apps/clip/README.md
new file mode 100644
index 000000000..085860cb6
--- /dev/null
+++ b/apps/clip/README.md
@@ -0,0 +1,19 @@
+# CLIP Embedding Example
+
+Uses the CLIP model to embed images and text.
+
+## Get started
+
+Make sure you have installed the local dev version of `llamaindex`; see the [README.md](../../packages/core/README.md).
+
+Then, install dependencies:
+
+```
+pnpm install
+```
+
+Then run the example:
+
+```
+node clip_test.js
+```
diff --git a/apps/clip/clip_test.js b/apps/clip/clip_test.js
new file mode 100644
index 000000000..386eae08e
--- /dev/null
+++ b/apps/clip/clip_test.js
@@ -0,0 +1,34 @@
+/* eslint-disable turbo/no-undeclared-env-vars */
+import { ClipEmbedding, SimilarityType, similarity } from "llamaindex";
+
+async function main() {
+  const clip = new ClipEmbedding();
+
+  // Get text embeddings
+  const text1 = "a car";
+  const textEmbedding1 = await clip.getTextEmbedding(text1);
+  const text2 = "a football match";
+  const textEmbedding2 = await clip.getTextEmbedding(text2);
+
+  // Get image embedding
+  const image =
+    "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg";
+  const imageEmbedding = await clip.getImageEmbedding(image);
+
+  // Calculate cosine similarities between each text and the image
+  const sim1 = similarity(
+    textEmbedding1,
+    imageEmbedding,
+    SimilarityType.COSINE,
+  );
+  const sim2 = similarity(
+    textEmbedding2,
+    imageEmbedding,
+    SimilarityType.COSINE,
+  );
+
+  console.log(`Similarity between "${text1}" and the image is ${sim1}`);
+  console.log(`Similarity between "${text2}" and the image is ${sim2}`);
+}
+
+main();
diff --git a/apps/clip/package.json b/apps/clip/package.json
index b98d22e9a..a4530e12d 100644
--- a/apps/clip/package.json
+++ b/apps/clip/package.json
@@ -1,16 +1,12 @@
 {
   "version": "0.0.1",
   "private": true,
-  "name": "cliptest",
+  "name": "clip-test",
   "type": "module",
   "dependencies": {
-    "@xenova/transformers": "^2.8.0",
+    "dotenv": "^16.3.1",
     "llamaindex": "workspace:*"
   },
-  "devDependencies": {
-    "@types/node": "^18.18.6",
-    "ts-node": "^10.9.1"
-  },
   "scripts": {
     "lint": "eslint ."
   }
diff --git a/packages/core/package.json b/packages/core/package.json
index ef0dfbf8a..dcbec345b 100644
--- a/packages/core/package.json
+++ b/packages/core/package.json
@@ -5,6 +5,7 @@
   "dependencies": {
     "@anthropic-ai/sdk": "^0.9.0",
     "@notionhq/client": "^2.2.13",
+    "@xenova/transformers": "^2.8.0",
     "crypto-js": "^4.2.0",
     "js-tiktoken": "^1.0.7",
     "lodash": "^4.17.21",
@@ -45,5 +46,12 @@
     "test": "jest",
     "build": "tsup src/index.ts --format esm,cjs --dts",
     "dev": "tsup src/index.ts --format esm,cjs --dts --watch"
+  },
+  "exports": {
+    ".": {
+      "require": "./dist/index.js",
+      "import": "./dist/index.mjs",
+      "types": "./dist/index.d.ts"
+    }
   }
 }
\ No newline at end of file
diff --git a/packages/core/src/Embedding.ts b/packages/core/src/Embedding.ts
index 5723f888c..78db4b937 100644
--- a/packages/core/src/Embedding.ts
+++ b/packages/core/src/Embedding.ts
@@ -1,5 +1,13 @@
 import { ClientOptions as OpenAIClientOptions } from "openai";
 
+import {
+  AutoProcessor,
+  AutoTokenizer,
+  CLIPTextModelWithProjection,
+  CLIPVisionModelWithProjection,
+  RawImage,
+} from "@xenova/transformers";
+import _ from "lodash";
 import { DEFAULT_SIMILARITY_TOP_K } from "./constants";
 import {
   AzureOpenAIConfig,
@@ -296,3 +304,97 @@ export class OpenAIEmbedding extends BaseEmbedding {
     return this.getOpenAIEmbedding(query);
   }
 }
+
+export type ImageType = string | Blob | URL;
+
+async function readImage(input: ImageType) {
+  if (input instanceof Blob) {
+    return await RawImage.fromBlob(input);
+  } else if (_.isString(input) || input instanceof URL) {
+    return await RawImage.fromURL(input);
+  } else {
+    throw new Error(`Unsupported input type: ${typeof input}`);
+  }
+}
+
+/**
+ * Base class for multi-modal embeddings.
+ */
+abstract class MultiModalEmbedding extends BaseEmbedding {
+  abstract getImageEmbedding(images: ImageType): Promise<number[]>;
+
+  async getImageEmbeddings(images: ImageType[]): Promise<number[][]> {
+    // Embed the input sequence of images asynchronously.
+    return Promise.all(
+      images.map((imgFilePath) => this.getImageEmbedding(imgFilePath)),
+    );
+  }
+}
+
+enum ClipEmbeddingModelType {
+  XENOVA_CLIP_VIT_BASE_PATCH32 = "Xenova/clip-vit-base-patch32",
+  XENOVA_CLIP_VIT_BASE_PATCH16 = "Xenova/clip-vit-base-patch16",
+}
+
+export class ClipEmbedding extends MultiModalEmbedding {
+  modelType: ClipEmbeddingModelType =
+    ClipEmbeddingModelType.XENOVA_CLIP_VIT_BASE_PATCH16;
+
+  private tokenizer: any;
+  private processor: any;
+  private visionModel: any;
+  private textModel: any;
+
+  async getTokenizer() {
+    if (!this.tokenizer) {
+      this.tokenizer = await AutoTokenizer.from_pretrained(this.modelType);
+    }
+    return this.tokenizer;
+  }
+
+  async getProcessor() {
+    if (!this.processor) {
+      this.processor = await AutoProcessor.from_pretrained(this.modelType);
+    }
+    return this.processor;
+  }
+
+  async getVisionModel() {
+    if (!this.visionModel) {
+      this.visionModel = await CLIPVisionModelWithProjection.from_pretrained(
+        this.modelType,
+      );
+    }
+
+    return this.visionModel;
+  }
+
+  async getTextModel() {
+    if (!this.textModel) {
+      this.textModel = await CLIPTextModelWithProjection.from_pretrained(
+        this.modelType,
+      );
+    }
+
+    return this.textModel;
+  }
+
+  async getImageEmbedding(image: ImageType): Promise<number[]> {
+    const loadedImage = await readImage(image);
+    const imageInputs = await (await this.getProcessor())(loadedImage);
+    const { image_embeds } = await (await this.getVisionModel())(imageInputs);
+    return Array.from(image_embeds.data);
+  }
+
+  async getTextEmbedding(text: string): Promise<number[]> {
+    const textInputs = await (
+      await this.getTokenizer()
+    )([text], { padding: true, truncation: true });
+    const { text_embeds } = await (await this.getTextModel())(textInputs);
+    return Array.from(text_embeds.data);
+  }
+
+  async getQueryEmbedding(query: string): Promise<number[]> {
+    return this.getTextEmbedding(query);
+  }
+}
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index ff5488d39..95795125d 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -51,19 +51,12 @@ importers:
 
   apps/clip:
     dependencies:
-      '@xenova/transformers':
-        specifier: ^2.8.0
-        version: 2.8.0
+      dotenv:
+        specifier: ^16.3.1
+        version: 16.3.1
       llamaindex:
         specifier: workspace:*
         version: link:../../packages/core
-    devDependencies:
-      '@types/node':
-        specifier: ^18.18.6
-        version: 18.18.8
-      ts-node:
-        specifier: ^10.9.1
-        version: 10.9.1(@types/node@18.18.8)(typescript@5.2.2)
 
   apps/docs:
     dependencies:
@@ -169,6 +162,9 @@ importers:
       '@notionhq/client':
         specifier: ^2.2.13
         version: 2.2.13
+      '@xenova/transformers':
+        specifier: ^2.8.0
+        version: 2.8.0
       crypto-js:
         specifier: ^4.2.0
         version: 4.2.0
-- 
GitLab