From 325aa51e514be3433c02ddecb532c030f58d4c48 Mon Sep 17 00:00:00 2001
From: Thuc Pham <51660321+thucpn@users.noreply.github.com>
Date: Wed, 17 Jul 2024 20:57:14 +0700
Subject: [PATCH] feat: implement Jina embedding through Jina API (#995)

Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
---
 .changeset/lovely-flies-press.md              |   5 +
 examples/multimodal/jina.ts                   |  83 +++++++++++
 .../src/embeddings/JinaAIEmbedding.ts         | 130 +++++++++++++++---
 packages/llamaindex/src/index.ts              |   1 +
 4 files changed, 203 insertions(+), 16 deletions(-)
 create mode 100644 .changeset/lovely-flies-press.md
 create mode 100644 examples/multimodal/jina.ts

diff --git a/.changeset/lovely-flies-press.md b/.changeset/lovely-flies-press.md
new file mode 100644
index 000000000..4becbde31
--- /dev/null
+++ b/.changeset/lovely-flies-press.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+Implement Jina embedding through the Jina API
diff --git a/examples/multimodal/jina.ts b/examples/multimodal/jina.ts
new file mode 100644
index 000000000..7ce2f3cf8
--- /dev/null
+++ b/examples/multimodal/jina.ts
@@ -0,0 +1,83 @@
+import {
+  ImageDocument,
+  JinaAIEmbedding,
+  similarity,
+  SimilarityType,
+  SimpleDirectoryReader,
+} from "llamaindex";
+import path from "path";
+
+async function main() {
+  const jina = new JinaAIEmbedding({
+    model: "jina-clip-v1",
+  });
+
+  // Get text embeddings
+  const text1 = "a car";
+  const textEmbedding1 = await jina.getTextEmbedding(text1);
+  const text2 = "a football match";
+  const textEmbedding2 = await jina.getTextEmbedding(text2);
+
+  // Get image embedding
+  const image =
+    "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg";
+  const imageEmbedding = await jina.getImageEmbedding(image);
+
+  // Calculate the similarity between each text embedding and the image embedding
+  const sim1 = similarity(
+    textEmbedding1,
+    imageEmbedding,
+    SimilarityType.DEFAULT,
+  );
+  const sim2 = similarity(
+    textEmbedding2,
+    imageEmbedding,
+    SimilarityType.DEFAULT,
+  );
+
+  console.log(`Similarity between "${text1}" and the image is ${sim1}`);
+  console.log(`Similarity between "${text2}" and the image is ${sim2}`);
+
+  // Get multiple text embeddings
+  const textEmbeddings = await jina.getTextEmbeddings([text1, text2]);
+  const sim3 = similarity(
+    textEmbeddings[0],
+    textEmbeddings[1],
+    SimilarityType.DEFAULT,
+  );
+  console.log(
+    `Similarity between the two texts "${text1}" and "${text2}" is ${sim3}`,
+  );
+
+  // Get multiple image embeddings
+  const catImg1 =
+    "https://i.pinimg.com/600x315/21/48/7e/21487e8e0970dd366dafaed6ab25d8d8.jpg";
+  const catImg2 =
+    "https://i.pinimg.com/736x/c9/f2/3e/c9f23e212529f13f19bad5602d84b78b.jpg";
+  const imageEmbeddings = await jina.getImageEmbeddings([catImg1, catImg2]);
+  const sim4 = similarity(
+    imageEmbeddings[0],
+    imageEmbeddings[1],
+    SimilarityType.DEFAULT,
+  );
+  console.log(`Similarity between the two online cat images is ${sim4}`);
+
+  // Get image embeddings from multiple local files
+  const documents = await new SimpleDirectoryReader().loadData({
+    directoryPath: path.join("multimodal", "data"),
+  });
+  const localImages = documents
+    .filter((doc) => doc instanceof ImageDocument)
+    .slice(0, 2); // Get only the first two images
+  const localImageEmbeddings = await jina.getImageEmbeddings(
+    localImages.map((doc) => (doc as ImageDocument).image),
+  );
+  const sim5 = similarity(
+    localImageEmbeddings[0],
+    localImageEmbeddings[1],
+    SimilarityType.DEFAULT,
+  );
+  console.log(`Similarity between the two local images is ${sim5}`);
+}
+
+void main();
diff --git a/packages/llamaindex/src/embeddings/JinaAIEmbedding.ts b/packages/llamaindex/src/embeddings/JinaAIEmbedding.ts
index 9e2b95f5e..e1bc39972 100644
--- a/packages/llamaindex/src/embeddings/JinaAIEmbedding.ts
+++ b/packages/llamaindex/src/embeddings/JinaAIEmbedding.ts
@@ -1,29 +1,127 @@
 import { getEnv } from "@llamaindex/env";
-import { OpenAIEmbedding } from "./OpenAIEmbedding.js";
+import { imageToDataUrl } from "../internal/utils.js";
+import type { ImageType } from "../Node.js";
+import { MultiModalEmbedding } from "./MultiModalEmbedding.js";
 
-export class JinaAIEmbedding extends OpenAIEmbedding {
-  constructor(init?: Partial<OpenAIEmbedding>) {
-    const {
-      apiKey = getEnv("JINAAI_API_KEY"),
-      additionalSessionOptions = {},
-      model = "jina-embeddings-v2-base-en",
-      ...rest
-    } = init ?? {};
+function isLocal(url: ImageType): boolean {
+  if (url instanceof Blob) return true;
+  return new URL(url).protocol === "file:";
+}
+
+export type JinaEmbeddingRequest = {
+  input: Array<{ text: string } | { url: string } | { bytes: string }>;
+  model?: string;
+  encoding_type?: "float" | "binary" | "ubinary";
+};
+
+export type JinaEmbeddingResponse = {
+  model: string;
+  object: string;
+  usage: {
+    total_tokens: number;
+    prompt_tokens: number;
+  };
+  data: Array<{
+    object: string;
+    index: number;
+    embedding: number[];
+  }>;
+};
+
+const JINA_MULTIMODAL_MODELS = ["jina-clip-v1"];
+
+export class JinaAIEmbedding extends MultiModalEmbedding {
+  apiKey: string;
+  model: string;
+  baseURL: string;
+
+  async getTextEmbedding(text: string): Promise<number[]> {
+    const result = await this.getJinaEmbedding({ input: [{ text }] });
+    return result.data[0].embedding;
+  }
+
+  async getImageEmbedding(image: ImageType): Promise<number[]> {
+    const img = await this.getImageInput(image);
+    const result = await this.getJinaEmbedding({ input: [img] });
+    return result.data[0].embedding;
+  }
+
+  // Retrieve multiple text embeddings in a single request
+  getTextEmbeddings = async (texts: string[]): Promise<Array<number[]>> => {
+    const input = texts.map((text) => ({ text }));
+    const result = await this.getJinaEmbedding({ input });
+    return result.data.map((d) => d.embedding);
+  };
+
+  // Retrieve multiple image embeddings in a single request
+  async getImageEmbeddings(images: ImageType[]): Promise<number[][]> {
+    const input = await Promise.all(
+      images.map((img) => this.getImageInput(img)),
+    );
+    const result = await this.getJinaEmbedding({ input });
+    return result.data.map((d) => d.embedding);
+  }
 
+  constructor(init?: Partial<JinaAIEmbedding>) {
+    super();
+    const apiKey = init?.apiKey ?? getEnv("JINAAI_API_KEY");
     if (!apiKey) {
       throw new Error(
         "Set Jina AI API Key in JINAAI_API_KEY env variable. Get one for free or top up your key at https://jina.ai/embeddings",
       );
     }
+    this.apiKey = apiKey;
+    this.model = init?.model ?? "jina-embeddings-v2-base-en";
+    this.baseURL = init?.baseURL ?? "https://api.jina.ai/v1/embeddings";
+    init?.embedBatchSize && (this.embedBatchSize = init?.embedBatchSize);
+  }
 
-    additionalSessionOptions.baseURL =
-      additionalSessionOptions.baseURL ?? "https://api.jina.ai/v1";
+  private async getImageInput(
+    image: ImageType,
+  ): Promise<{ bytes: string } | { url: string }> {
+    if (isLocal(image)) {
+      const base64 = await imageToDataUrl(image);
+      const bytes = base64.split(",")[1];
+      return { bytes };
+    } else {
+      return { url: image.toString() };
+    }
+  }
 
-    super({
-      apiKey,
-      additionalSessionOptions,
-      model,
-      ...rest,
+  private async getJinaEmbedding(
+    input: JinaEmbeddingRequest,
+  ): Promise<JinaEmbeddingResponse> {
+    // If the input includes an image, check that the model supports multimodal embeddings
+    if (
+      input.input.some((i) => "url" in i || "bytes" in i) &&
+      !JINA_MULTIMODAL_MODELS.includes(this.model)
+    ) {
+      throw new Error(
+        `Model ${this.model} does not support image embeddings. Use ${JINA_MULTIMODAL_MODELS.join(", ")}`,
+      );
+    }
+
+    const response = await fetch(this.baseURL, {
+      method: "POST",
+      headers: {
+        "Content-Type": "application/json",
+        Authorization: `Bearer ${this.apiKey}`,
+      },
+      body: JSON.stringify({
+        model: this.model,
+        encoding_type: "float",
+        ...input,
+      }),
     });
+    if (!response.ok) {
+      throw new Error(
+        `Request ${this.baseURL} failed with status ${response.status}`,
+      );
+    }
+    const result: JinaEmbeddingResponse = await response.json();
+    return {
+      ...result,
+      data: result.data.sort((a, b) => a.index - b.index), // Sort resulting embeddings by index
+    };
   }
 }
diff --git a/packages/llamaindex/src/index.ts b/packages/llamaindex/src/index.ts
index d6d7a1353..b64397e7a 100644
--- a/packages/llamaindex/src/index.ts
+++ b/packages/llamaindex/src/index.ts
@@ -15,4 +15,5 @@ export { type VertexGeminiSessionOptions } from "./llm/gemini/types.js";
 export { GeminiVertexSession } from "./llm/gemini/vertex.js";
 
 // Expose AzureDynamicSessionTool for node.js runtime only
+export { JinaAIEmbedding } from "./embeddings/JinaAIEmbedding.js";
 export { AzureDynamicSessionTool } from "./tools/AzureDynamicSessionTool.node.js";
-- 
GitLab