From 325aa51e514be3433c02ddecb532c030f58d4c48 Mon Sep 17 00:00:00 2001
From: Thuc Pham <51660321+thucpn@users.noreply.github.com>
Date: Wed, 17 Jul 2024 20:57:14 +0700
Subject: [PATCH] feat: implement Jina embedding through Jina api (#995)

Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
---
 .changeset/lovely-flies-press.md          |   5 +
 examples/multimodal/jina.ts               |  83 +++++++++++
 .../src/embeddings/JinaAIEmbedding.ts     | 130 +++++++++++++++---
 packages/llamaindex/src/index.ts          |   1 +
 4 files changed, 203 insertions(+), 16 deletions(-)
 create mode 100644 .changeset/lovely-flies-press.md
 create mode 100644 examples/multimodal/jina.ts

diff --git a/.changeset/lovely-flies-press.md b/.changeset/lovely-flies-press.md
new file mode 100644
index 000000000..4becbde31
--- /dev/null
+++ b/.changeset/lovely-flies-press.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+Implement Jina embedding through Jina api
diff --git a/examples/multimodal/jina.ts b/examples/multimodal/jina.ts
new file mode 100644
index 000000000..7ce2f3cf8
--- /dev/null
+++ b/examples/multimodal/jina.ts
@@ -0,0 +1,83 @@
+import {
+  ImageDocument,
+  JinaAIEmbedding,
+  similarity,
+  SimilarityType,
+  SimpleDirectoryReader,
+} from "llamaindex";
+import path from "path";
+
+async function main() {
+  const jina = new JinaAIEmbedding({
+    model: "jina-clip-v1",
+  });
+
+  // Get text embeddings
+  const text1 = "a car";
+  const textEmbedding1 = await jina.getTextEmbedding(text1);
+  const text2 = "a football match";
+  const textEmbedding2 = await jina.getTextEmbedding(text2);
+
+  // Get image embedding
+  const image =
+    "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg";
+  const imageEmbedding = await jina.getImageEmbedding(image);
+
+  // Calc similarity between text and image
+  const sim1 = similarity(
+    textEmbedding1,
+    imageEmbedding,
+    SimilarityType.DEFAULT,
+  );
+  const sim2 = similarity(
+    textEmbedding2,
+    imageEmbedding,
+    SimilarityType.DEFAULT,
+  );
+
+  console.log(`Similarity between "${text1}" and the image is ${sim1}`);
+  console.log(`Similarity between "${text2}" and the image is ${sim2}`);
+
+  // Get multiple text embeddings
+  const textEmbeddings = await jina.getTextEmbeddings([text1, text2]);
+  const sim3 = similarity(
+    textEmbeddings[0],
+    textEmbeddings[1],
+    SimilarityType.DEFAULT,
+  );
+  console.log(
+    `Similarity between the two texts "${text1}" and "${text2}" is ${sim3}`,
+  );
+
+  // Get multiple image embeddings
+  const catImg1 =
+    "https://i.pinimg.com/600x315/21/48/7e/21487e8e0970dd366dafaed6ab25d8d8.jpg";
+  const catImg2 =
+    "https://i.pinimg.com/736x/c9/f2/3e/c9f23e212529f13f19bad5602d84b78b.jpg";
+  const imageEmbeddings = await jina.getImageEmbeddings([catImg1, catImg2]);
+  const sim4 = similarity(
+    imageEmbeddings[0],
+    imageEmbeddings[1],
+    SimilarityType.DEFAULT,
+  );
+  console.log(`Similarity between the two online cat images is ${sim4}`);
+
+  // Get image embeddings from multiple local files
+  const documents = await new SimpleDirectoryReader().loadData({
+    directoryPath: path.join("multimodal", "data"),
+  });
+  const localImages = documents
+    .filter((doc) => doc instanceof ImageDocument)
+    .slice(0, 2); // Get only the first two images
+  const localImageEmbeddings = await jina.getImageEmbeddings(
+    localImages.map((doc) => (doc as ImageDocument).image),
+  );
+  const sim5 = similarity(
+    localImageEmbeddings[0],
+    localImageEmbeddings[1],
+    SimilarityType.DEFAULT,
+  );
+  console.log(`Similarity between the two local images is ${sim5}`);
+}
+
+void main();
diff --git a/packages/llamaindex/src/embeddings/JinaAIEmbedding.ts b/packages/llamaindex/src/embeddings/JinaAIEmbedding.ts
index 9e2b95f5e..e1bc39972 100644
--- a/packages/llamaindex/src/embeddings/JinaAIEmbedding.ts
+++ b/packages/llamaindex/src/embeddings/JinaAIEmbedding.ts
@@ -1,29 +1,127 @@
 import { getEnv } from "@llamaindex/env";
-import { OpenAIEmbedding } from "./OpenAIEmbedding.js";
+import { imageToDataUrl } from "../internal/utils.js";
+import type { ImageType } from "../Node.js";
+import { MultiModalEmbedding } from "./MultiModalEmbedding.js";
 
-export class JinaAIEmbedding extends OpenAIEmbedding {
-  constructor(init?: Partial<OpenAIEmbedding>) {
-    const {
-      apiKey = getEnv("JINAAI_API_KEY"),
-      additionalSessionOptions = {},
-      model = "jina-embeddings-v2-base-en",
-      ...rest
-    } = init ?? {};
+function isLocal(url: ImageType): boolean {
+  if (url instanceof Blob) return true;
+  return new URL(url).protocol === "file:";
+}
+
+export type JinaEmbeddingRequest = {
+  input: Array<{ text: string } | { url: string } | { bytes: string }>;
+  model?: string;
+  encoding_type?: "float" | "binary" | "ubinary";
+};
+
+export type JinaEmbeddingResponse = {
+  model: string;
+  object: string;
+  usage: {
+    total_tokens: number;
+    prompt_tokens: number;
+  };
+  data: Array<{
+    object: string;
+    index: number;
+    embedding: number[];
+  }>;
+};
+
+const JINA_MULTIMODAL_MODELS = ["jina-clip-v1"];
+
+export class JinaAIEmbedding extends MultiModalEmbedding {
+  apiKey: string;
+  model: string;
+  baseURL: string;
+
+  async getTextEmbedding(text: string): Promise<number[]> {
+    const result = await this.getJinaEmbedding({ input: [{ text }] });
+    return result.data[0].embedding;
+  }
+
+  async getImageEmbedding(image: ImageType): Promise<number[]> {
+    const img = await this.getImageInput(image);
+    const result = await this.getJinaEmbedding({ input: [img] });
+    return result.data[0].embedding;
+  }
+
+  // Retrieve multiple text embeddings in a single request
+  getTextEmbeddings = async (texts: string[]): Promise<Array<number[]>> => {
+    const input = texts.map((text) => ({ text }));
+    const result = await this.getJinaEmbedding({ input });
+    return result.data.map((d) => d.embedding);
+  };
+
+  // Retrieve multiple image embeddings in a single request
+  async getImageEmbeddings(images: ImageType[]): Promise<number[][]> {
+    const input = await Promise.all(
+      images.map((img) => this.getImageInput(img)),
+    );
+    const result = await this.getJinaEmbedding({ input });
+    return result.data.map((d) => d.embedding);
+  }
 
+  constructor(init?: Partial<JinaAIEmbedding>) {
+    super();
+    const apiKey = init?.apiKey ?? getEnv("JINAAI_API_KEY");
     if (!apiKey) {
       throw new Error(
         "Set Jina AI API Key in JINAAI_API_KEY env variable. Get one for free or top up your key at https://jina.ai/embeddings",
       );
     }
+    this.apiKey = apiKey;
+    this.model = init?.model ?? "jina-embeddings-v2-base-en";
+    this.baseURL = init?.baseURL ?? "https://api.jina.ai/v1/embeddings";
+    init?.embedBatchSize && (this.embedBatchSize = init?.embedBatchSize);
+  }
-
-    additionalSessionOptions.baseURL =
-      additionalSessionOptions.baseURL ?? "https://api.jina.ai/v1";
 
+  private async getImageInput(
+    image: ImageType,
+  ): Promise<{ bytes: string } | { url: string }> {
+    if (isLocal(image)) {
+      const base64 = await imageToDataUrl(image);
+      const bytes = base64.split(",")[1];
+      return { bytes };
+    } else {
+      return { url: image.toString() };
+    }
+  }
-    super({
-      apiKey,
-      additionalSessionOptions,
-      model,
-      ...rest,
+
+  private async getJinaEmbedding(
+    input: JinaEmbeddingRequest,
+  ): Promise<JinaEmbeddingResponse> {
+    // if input includes image, check if model supports multimodal embeddings
+    if (
+      input.input.some((i) => "url" in i || "bytes" in i) &&
+      !JINA_MULTIMODAL_MODELS.includes(this.model)
+    ) {
+      throw new Error(
+        `Model ${this.model} does not support image embeddings. Use ${JINA_MULTIMODAL_MODELS.join(", ")}`,
+      );
+    }
+
+    const response = await fetch(this.baseURL, {
+      method: "POST",
+      headers: {
+        "Content-Type": "application/json",
+        Authorization: `Bearer ${this.apiKey}`,
+      },
+      body: JSON.stringify({
+        model: this.model,
+        encoding_type: "float",
+        ...input,
+      }),
     });
+    if (!response.ok) {
+      throw new Error(
+        `Request ${this.baseURL} failed with status ${response.status}`,
+      );
+    }
+    const result: JinaEmbeddingResponse = await response.json();
+    return {
+      ...result,
+      data: result.data.sort((a, b) => a.index - b.index), // Sort resulting embeddings by index
+    };
   }
 }
diff --git a/packages/llamaindex/src/index.ts b/packages/llamaindex/src/index.ts
index d6d7a1353..b64397e7a 100644
--- a/packages/llamaindex/src/index.ts
+++ b/packages/llamaindex/src/index.ts
@@ -15,4 +15,5 @@ export { type VertexGeminiSessionOptions } from "./llm/gemini/types.js";
 export { GeminiVertexSession } from "./llm/gemini/vertex.js";
 
 // Expose AzureDynamicSessionTool for node.js runtime only
+export { JinaAIEmbedding } from "./embeddings/JinaAIEmbedding.js";
 export { AzureDynamicSessionTool } from "./tools/AzureDynamicSessionTool.node.js";
-- 
GitLab