diff --git a/apps/clip/clip_test.js b/apps/clip/clip_test.ts similarity index 94% rename from apps/clip/clip_test.js rename to apps/clip/clip_test.ts index 386eae08e8e2cbf343ee3080a722a2c2db6aaad6..5f2da74a1d2224701ea6748f69a3d31c733d98d1 100644 --- a/apps/clip/clip_test.js +++ b/apps/clip/clip_test.ts @@ -19,12 +19,12 @@ async function main() { const sim1 = similarity( textEmbedding1, imageEmbedding, - SimilarityType.COSINE, + SimilarityType.DEFAULT, ); const sim2 = similarity( textEmbedding2, imageEmbedding, - SimilarityType.COSINE, + SimilarityType.DEFAULT, ); console.log(`Similarity between "${text1}" and the image is ${sim1}`); diff --git a/apps/clip/package.json b/apps/clip/package.json index a4530e12d6be4d694364c1c6f01f4b5ca4648700..99cec39038e06604e2a3ecc86fcab4562177aa89 100644 --- a/apps/clip/package.json +++ b/apps/clip/package.json @@ -2,12 +2,16 @@ "version": "0.0.1", "private": true, "name": "clip-test", - "type": "module", "dependencies": { "dotenv": "^16.3.1", "llamaindex": "workspace:*" }, + "devDependencies": { + "@types/node": "^18", + "ts-node": "^10.9.1" + }, "scripts": { - "lint": "eslint ." + "lint": "eslint .", + "start": "ts-node ./clip_test.ts" } } \ No newline at end of file diff --git a/packages/core/src/Embedding.ts b/packages/core/src/Embedding.ts index 78db4b93753398c20a7cd205de25ce2b71e92bde..ad6650251c6f622ba0b190f24a57433a58c88814 100644 --- a/packages/core/src/Embedding.ts +++ b/packages/core/src/Embedding.ts @@ -1,13 +1,5 @@ -import { ClientOptions as OpenAIClientOptions } from "openai"; - -import { - AutoProcessor, - AutoTokenizer, - CLIPTextModelWithProjection, - CLIPVisionModelWithProjection, - RawImage, -} from "@xenova/transformers"; import _ from "lodash"; +import { ClientOptions as OpenAIClientOptions } from "openai"; import { DEFAULT_SIMILARITY_TOP_K } from "./constants"; import { AzureOpenAIConfig, @@ -308,6 +300,7 @@ export class OpenAIEmbedding extends BaseEmbedding { export type ImageType = string | Blob | URL; async function readImage(input: ImageType) { + const { RawImage } = await import("@xenova/transformers"); if (input instanceof Blob) { return await RawImage.fromBlob(input); } else if (_.isString(input) || input instanceof URL) { @@ -320,7 +313,7 @@ async function readImage(input: ImageType) { /* * Base class for Multi Modal embeddings. */ -abstract class MultiModalEmbedding extends BaseEmbedding { +export abstract class MultiModalEmbedding extends BaseEmbedding { abstract getImageEmbedding(images: ImageType): Promise<number[]>; async getImageEmbeddings(images: ImageType[]): Promise<number[][]> { @@ -347,6 +340,7 @@ export class ClipEmbedding extends MultiModalEmbedding { async getTokenizer() { if (!this.tokenizer) { + const { AutoTokenizer } = await import("@xenova/transformers"); this.tokenizer = await AutoTokenizer.from_pretrained(this.modelType); } return this.tokenizer; @@ -354,6 +348,7 @@ export class ClipEmbedding extends MultiModalEmbedding { async getProcessor() { if (!this.processor) { + const { AutoProcessor } = await import("@xenova/transformers"); this.processor = await AutoProcessor.from_pretrained(this.modelType); } return this.processor; @@ -361,6 +356,9 @@ export class ClipEmbedding extends MultiModalEmbedding { async getVisionModel() { if (!this.visionModel) { + const { CLIPVisionModelWithProjection } = await import( + "@xenova/transformers" + ); this.visionModel = await CLIPVisionModelWithProjection.from_pretrained( this.modelType, ); @@ -371,6 +369,9 @@ export class ClipEmbedding extends MultiModalEmbedding { async getTextModel() { if (!this.textModel) { + const { CLIPTextModelWithProjection } = await import( + "@xenova/transformers" + ); this.textModel = await CLIPTextModelWithProjection.from_pretrained( this.modelType, ); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 95795125d166b5218bf90992da88518c170fd8aa..d556d651f08a2c9274eec1c8a46c374b436f6f84 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -57,6 +57,13 @@ importers: llamaindex: specifier: workspace:* version: link:../../packages/core + devDependencies: + '@types/node': + specifier: ^18 + version: 18.18.8 + ts-node: + specifier: ^10.9.1 + version: 10.9.1(@types/node@18.18.8)(typescript@5.2.2) apps/docs: dependencies: