diff --git a/.changeset/sour-snakes-fry.md b/.changeset/sour-snakes-fry.md new file mode 100644 index 0000000000000000000000000000000000000000..98808fd734c4bd3470960dfe82dd4c1b4ac70384 --- /dev/null +++ b/.changeset/sour-snakes-fry.md @@ -0,0 +1,6 @@ +--- +"llamaindex": patch +"@llamaindex/core": patch +--- + +fix: clip embedding transform function diff --git a/packages/core/src/embeddings/base.ts b/packages/core/src/embeddings/base.ts index d8f66d215ca98991b200a9cb18939dfb662faf62..1d1e972a504a556fad0b8bac2ebce3fe1f292e41 100644 --- a/packages/core/src/embeddings/base.ts +++ b/packages/core/src/embeddings/base.ts @@ -23,23 +23,34 @@ export abstract class BaseEmbedding extends TransformComponent { embedBatchSize = DEFAULT_EMBED_BATCH_SIZE; embedInfo?: EmbeddingInfo; - constructor() { - super( - async ( - nodes: BaseNode[], - options?: BaseEmbeddingOptions, - ): Promise<BaseNode[]> => { - const texts = nodes.map((node) => node.getContent(MetadataMode.EMBED)); - - const embeddings = await this.getTextEmbeddingsBatch(texts, options); - - for (let i = 0; i < nodes.length; i++) { - nodes[i]!.embedding = embeddings[i]; - } - - return nodes; - }, - ); + protected constructor( + transformFn?: ( + nodes: BaseNode[], + options?: BaseEmbeddingOptions, + ) => Promise<BaseNode[]>, + ) { + if (transformFn) { + super(transformFn); + } else { + super( + async ( + nodes: BaseNode[], + options?: BaseEmbeddingOptions, + ): Promise<BaseNode[]> => { + const texts = nodes.map((node) => + node.getContent(MetadataMode.EMBED), + ); + + const embeddings = await this.getTextEmbeddingsBatch(texts, options); + + for (let i = 0; i < nodes.length; i++) { + nodes[i]!.embedding = embeddings[i]; + } + + return nodes; + }, + ); + } } similarity( diff --git a/packages/core/src/embeddings/index.ts b/packages/core/src/embeddings/index.ts index 7d6c4aac116ab281c3f0e96a14625cd545203e1b..5e96ef5b11fdda5aceb133d75288069d06f0a73b 100644 --- a/packages/core/src/embeddings/index.ts +++ b/packages/core/src/embeddings/index.ts @@ -1,4 +1,5 @@ export { BaseEmbedding, batchEmbeddings } from "./base"; export type { BaseEmbeddingOptions, EmbeddingInfo } from "./base"; +export { MultiModalEmbedding } from "./muti-model"; export { truncateMaxTokens } from "./tokenizer"; export { DEFAULT_SIMILARITY_TOP_K, SimilarityType, similarity } from "./utils"; diff --git a/packages/core/src/embeddings/muti-model.ts b/packages/core/src/embeddings/muti-model.ts new file mode 100644 index 0000000000000000000000000000000000000000..bf150c95aaf4436aefcb0a0cfb382dcd2a6491c5 --- /dev/null +++ b/packages/core/src/embeddings/muti-model.ts @@ -0,0 +1,81 @@ +import type { MessageContentDetail } from "../llms"; +import { + ImageNode, + MetadataMode, + ModalityType, + splitNodesByType, + type BaseNode, + type ImageType, +} from "../schema"; +import { extractImage, extractSingleText } from "../utils"; +import { + BaseEmbedding, + batchEmbeddings, + type BaseEmbeddingOptions, +} from "./base"; + +/* + * Base class for Multi Modal embeddings. + */ +export abstract class MultiModalEmbedding extends BaseEmbedding { + abstract getImageEmbedding(images: ImageType): Promise<number[]>; + + protected constructor() { + super( + async ( + nodes: BaseNode[], + options?: BaseEmbeddingOptions, + ): Promise<BaseNode[]> => { + const nodeMap = splitNodesByType(nodes); + const imageNodes = nodeMap[ModalityType.IMAGE] ?? []; + const textNodes = nodeMap[ModalityType.TEXT] ?? []; + + const embeddings = await batchEmbeddings( + textNodes.map((node) => node.getContent(MetadataMode.EMBED)), + this.getTextEmbeddings.bind(this), + this.embedBatchSize, + options, + ); + for (let i = 0; i < textNodes.length; i++) { + textNodes[i]!.embedding = embeddings[i]; + } + + const imageEmbeddings = await batchEmbeddings( + imageNodes.map((n) => (n as ImageNode).image), + this.getImageEmbeddings.bind(this), + this.embedBatchSize, + options, + ); + for (let i = 0; i < imageNodes.length; i++) { + imageNodes[i]!.embedding = imageEmbeddings[i]; + } + + return nodes; + }, + ); + } + + /** + * Optionally override this method to retrieve multiple image embeddings in a single request + * @param images + */ + async getImageEmbeddings(images: ImageType[]): Promise<number[][]> { + return Promise.all( + images.map((imgFilePath) => this.getImageEmbedding(imgFilePath)), + ); + } + + async getQueryEmbedding( + query: MessageContentDetail, + ): Promise<number[] | null> { + const image = extractImage(query); + if (image) { + return await this.getImageEmbedding(image); + } + const text = extractSingleText(query); + if (text) { + return await this.getTextEmbedding(text); + } + return null; + } +} diff --git a/packages/llamaindex/e2e/fixtures/img/llamaindex-white.png b/packages/llamaindex/e2e/fixtures/img/llamaindex-white.png new file mode 100644 index 0000000000000000000000000000000000000000..7f72a5e04dc13a5076d80d71584e195f818e2132 Binary files /dev/null and b/packages/llamaindex/e2e/fixtures/img/llamaindex-white.png differ diff --git a/packages/llamaindex/e2e/node/embedding/clip.e2e.ts b/packages/llamaindex/e2e/node/embedding/clip.e2e.ts new file mode 100644 index 0000000000000000000000000000000000000000..110afe5f7fb2bbddb4e51b5fbcd3ae6ff7e03006 --- /dev/null +++ b/packages/llamaindex/e2e/node/embedding/clip.e2e.ts @@ -0,0 +1,30 @@ +import { ClipEmbedding, ImageNode } from "llamaindex"; +import assert from "node:assert"; +import { test } from "node:test"; + +await test("clip embedding", async (t) => { + await t.test("init & get image embedding", async () => { + const clipEmbedding = new ClipEmbedding(); + const imgUrl = new URL( + "../../fixtures/img/llamaindex-white.png", + import.meta.url, + ); + const vec = await clipEmbedding.getImageEmbedding(imgUrl); + assert.ok(vec); + }); + + await t.test("load image document", async () => { + const nodes = [ + new ImageNode({ + image: new URL( + "../../fixtures/img/llamaindex-white.png", + import.meta.url, + ), + }), + ]; + const clipEmbedding = new ClipEmbedding(); + const result = await clipEmbedding(nodes); + assert.strictEqual(result.length, 1); + assert.ok(result[0]!.embedding); + }); +}); diff --git a/packages/llamaindex/src/embeddings/ClipEmbedding.ts b/packages/llamaindex/src/embeddings/ClipEmbedding.ts index c613d0d158e2bf85652f5eda10404633bd564f95..d52d47cd908f3f8870342bbd56d49b589598265e 100644 --- a/packages/llamaindex/src/embeddings/ClipEmbedding.ts +++ b/packages/llamaindex/src/embeddings/ClipEmbedding.ts @@ -1,7 +1,7 @@ +import { MultiModalEmbedding } from "@llamaindex/core/embeddings"; import type { ImageType } from "@llamaindex/core/schema"; import _ from "lodash"; import { lazyLoadTransformers } from "../internal/deps/transformers.js"; -import { MultiModalEmbedding } from "./MultiModalEmbedding.js"; // only import type, to avoid bundling error import type { CLIPTextModelWithProjection, @@ -35,6 +35,10 @@ export class ClipEmbedding extends MultiModalEmbedding { private visionModel: CLIPVisionModelWithProjection | null = null; private textModel: CLIPTextModelWithProjection | null = null; + constructor() { + super(); + } + async getTokenizer() { const { AutoTokenizer } = await lazyLoadTransformers(); if (!this.tokenizer) { diff --git a/packages/llamaindex/src/embeddings/CloudflareWorkerEmbedding.ts b/packages/llamaindex/src/embeddings/CloudflareWorkerEmbedding.ts index 2034cc41e3726e791df571eb33a6870cede1114d..76c0fec6f5e00f8183735570f771e2bb6f90ee31 100644 --- a/packages/llamaindex/src/embeddings/CloudflareWorkerEmbedding.ts +++ b/packages/llamaindex/src/embeddings/CloudflareWorkerEmbedding.ts @@ -1,10 +1,13 @@ +import { MultiModalEmbedding } from "@llamaindex/core/embeddings"; import type { ImageType } from "@llamaindex/core/schema"; -import { MultiModalEmbedding } from "./MultiModalEmbedding.js"; /** * Cloudflare worker doesn't support image embeddings for now */ export class CloudflareWorkerMultiModalEmbedding extends MultiModalEmbedding { + constructor() { + super(); + } getImageEmbedding(images: ImageType): Promise<number[]> { throw new Error("Method not implemented."); } diff --git a/packages/llamaindex/src/embeddings/JinaAIEmbedding.ts b/packages/llamaindex/src/embeddings/JinaAIEmbedding.ts index b0514b998a050391f563186fb3e4a755490ab476..00188dc00156de04c51f103113745441c7d15d5c 100644 --- a/packages/llamaindex/src/embeddings/JinaAIEmbedding.ts +++ b/packages/llamaindex/src/embeddings/JinaAIEmbedding.ts @@ -1,7 +1,7 @@ +import { MultiModalEmbedding } from "@llamaindex/core/embeddings"; import { getEnv } from "@llamaindex/env"; import { imageToDataUrl } from "../internal/utils.js"; import type { ImageType } from "../Node.js"; -import { MultiModalEmbedding } from "./MultiModalEmbedding.js"; function isLocal(url: ImageType): boolean { if (url instanceof Blob) return true; diff --git a/packages/llamaindex/src/embeddings/MultiModalEmbedding.ts b/packages/llamaindex/src/embeddings/MultiModalEmbedding.ts deleted file mode 100644 index 6d2ed5bb27746c4ace851bc5e5877dc95896e2dc..0000000000000000000000000000000000000000 --- a/packages/llamaindex/src/embeddings/MultiModalEmbedding.ts +++ /dev/null @@ -1,71 +0,0 @@ -import { BaseEmbedding, batchEmbeddings } from "@llamaindex/core/embeddings"; -import type { MessageContentDetail } from "@llamaindex/core/llms"; -import { - ImageNode, - MetadataMode, - ModalityType, - splitNodesByType, - type BaseNode, - type ImageType, -} from "@llamaindex/core/schema"; -import { extractImage, extractSingleText } from "@llamaindex/core/utils"; - -/* - * Base class for Multi Modal embeddings. - */ - -export abstract class MultiModalEmbedding extends BaseEmbedding { - abstract getImageEmbedding(images: ImageType): Promise<number[]>; - - /** - * Optionally override this method to retrieve multiple image embeddings in a single request - * @param images - */ - async getImageEmbeddings(images: ImageType[]): Promise<number[][]> { - return Promise.all( - images.map((imgFilePath) => this.getImageEmbedding(imgFilePath)), - ); - } - - async transform(nodes: BaseNode[], _options?: any): Promise<BaseNode[]> { - const nodeMap = splitNodesByType(nodes); - const imageNodes = nodeMap[ModalityType.IMAGE] ?? []; - const textNodes = nodeMap[ModalityType.TEXT] ?? []; - - const embeddings = await batchEmbeddings( - textNodes.map((node) => node.getContent(MetadataMode.EMBED)), - this.getTextEmbeddings.bind(this), - this.embedBatchSize, - _options, - ); - for (let i = 0; i < textNodes.length; i++) { - textNodes[i]!.embedding = embeddings[i]; - } - - const imageEmbeddings = await batchEmbeddings( - imageNodes.map((n) => (n as ImageNode).image), - this.getImageEmbeddings.bind(this), - this.embedBatchSize, - _options, - ); - for (let i = 0; i < imageNodes.length; i++) { - imageNodes[i]!.embedding = imageEmbeddings[i]; - } - - return nodes; - } - - async getQueryEmbedding( - query: MessageContentDetail, - ): Promise<number[] | null> { - const image = extractImage(query); - if (image) { - return await this.getImageEmbedding(image); - } - const text = extractSingleText(query); - if (text) { - return await this.getTextEmbedding(text); - } - return null; - } -} diff --git a/packages/llamaindex/src/embeddings/index.ts b/packages/llamaindex/src/embeddings/index.ts index 7d23905f13ee4217fdea9adb630d21a5ff890da2..a93a1d378a44dd06d7bbf06123bd3c6c605bef63 100644 --- a/packages/llamaindex/src/embeddings/index.ts +++ b/packages/llamaindex/src/embeddings/index.ts @@ -6,7 +6,6 @@ export { HuggingFaceInferenceAPIEmbedding } from "./HuggingFaceEmbedding.js"; export * from "./JinaAIEmbedding.js"; export * from "./MistralAIEmbedding.js"; export * from "./MixedbreadAIEmbeddings.js"; -export * from "./MultiModalEmbedding.js"; export { OllamaEmbedding } from "./OllamaEmbedding.js"; export * from "./OpenAIEmbedding.js"; export { TogetherEmbedding } from "./together.js";