From bb917f9818e514d70e1923972d48477987c2d298 Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Tue, 21 Nov 2023 14:20:10 +0700
Subject: [PATCH] refactor: move embeddings to embeddings folder

---
 .vscode/settings.json                         |   5 +-
 packages/core/src/ServiceContext.ts           |   2 +-
 packages/core/src/embeddings/ClipEmbedding.ts |  78 +++++++
 .../src/embeddings/MultiModalEmbedding.ts     |  17 ++
 .../core/src/embeddings/OpenAIEmbedding.ts    |  92 ++++++++
 packages/core/src/embeddings/index.ts         |   5 +
 packages/core/src/embeddings/types.ts         |  24 ++
 .../src/{Embedding.ts => embeddings/utils.ts} | 219 +-----------------
 packages/core/src/index.ts                    |   2 +-
 .../storage/vectorStore/SimpleVectorStore.ts  |   4 +-
 .../core/src/tests/CallbackManager.test.ts    |   2 +-
 packages/core/src/tests/Embedding.test.ts     |   2 +-
 packages/core/src/tests/utility/mockOpenAI.ts |   2 +-
 13 files changed, 233 insertions(+), 221 deletions(-)
 create mode 100644 packages/core/src/embeddings/ClipEmbedding.ts
 create mode 100644 packages/core/src/embeddings/MultiModalEmbedding.ts
 create mode 100644 packages/core/src/embeddings/OpenAIEmbedding.ts
 create mode 100644 packages/core/src/embeddings/index.ts
 create mode 100644 packages/core/src/embeddings/types.ts
 rename packages/core/src/{Embedding.ts => embeddings/utils.ts} (50%)

diff --git a/.vscode/settings.json b/.vscode/settings.json
index d3a0c1169..9f6017380 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -4,5 +4,6 @@
   "editor.defaultFormatter": "esbenp.prettier-vscode",
   "[xml]": {
     "editor.defaultFormatter": "redhat.vscode-xml"
-  }
-}
+  },
+  "jest.rootPath": "./packages/core"
+}
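
Note: "jest.rootPath" is the vscode-jest extension setting that points test
discovery at the package owning the Jest config; the value here assumes the
tests stay under packages/core.
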
diff --git a/packages/core/src/ServiceContext.ts b/packages/core/src/ServiceContext.ts
index efefd319b..da33be209 100644
--- a/packages/core/src/ServiceContext.ts
+++ b/packages/core/src/ServiceContext.ts
@@ -1,7 +1,7 @@
-import { BaseEmbedding, OpenAIEmbedding } from "./Embedding";
 import { NodeParser, SimpleNodeParser } from "./NodeParser";
 import { PromptHelper } from "./PromptHelper";
 import { CallbackManager } from "./callbacks/CallbackManager";
+import { BaseEmbedding, OpenAIEmbedding } from "./embeddings";
 import { LLM, OpenAI } from "./llm/LLM";
 
 /**
diff --git a/packages/core/src/embeddings/ClipEmbedding.ts b/packages/core/src/embeddings/ClipEmbedding.ts
new file mode 100644
index 000000000..b75b4b879
--- /dev/null
+++ b/packages/core/src/embeddings/ClipEmbedding.ts
@@ -0,0 +1,78 @@
+import { MultiModalEmbedding } from "./MultiModalEmbedding";
+import { ImageType, readImage } from "./utils";
+
+export enum ClipEmbeddingModelType {
+  XENOVA_CLIP_VIT_BASE_PATCH32 = "Xenova/clip-vit-base-patch32",
+  XENOVA_CLIP_VIT_BASE_PATCH16 = "Xenova/clip-vit-base-patch16",
+}
+
+export class ClipEmbedding extends MultiModalEmbedding {
+  modelType: ClipEmbeddingModelType =
+    ClipEmbeddingModelType.XENOVA_CLIP_VIT_BASE_PATCH16;
+
+  private tokenizer: any;
+  private processor: any;
+  private visionModel: any;
+  private textModel: any;
+
+  async getTokenizer() {
+    if (!this.tokenizer) {
+      const { AutoTokenizer } = await import("@xenova/transformers");
+      this.tokenizer = await AutoTokenizer.from_pretrained(this.modelType);
+    }
+    return this.tokenizer;
+  }
+
+  async getProcessor() {
+    if (!this.processor) {
+      const { AutoProcessor } = await import("@xenova/transformers");
+      this.processor = await AutoProcessor.from_pretrained(this.modelType);
+    }
+    return this.processor;
+  }
+
+  async getVisionModel() {
+    if (!this.visionModel) {
+      const { CLIPVisionModelWithProjection } = await import(
+        "@xenova/transformers"
+      );
+      this.visionModel = await CLIPVisionModelWithProjection.from_pretrained(
+        this.modelType,
+      );
+    }
+
+    return this.visionModel;
+  }
+
+  async getTextModel() {
+    if (!this.textModel) {
+      const { CLIPTextModelWithProjection } = await import(
+        "@xenova/transformers"
+      );
+      this.textModel = await CLIPTextModelWithProjection.from_pretrained(
+        this.modelType,
+      );
+    }
+
+    return this.textModel;
+  }
+
+  async getImageEmbedding(image: ImageType): Promise<number[]> {
+    const loadedImage = await readImage(image);
+    const imageInputs = await (await this.getProcessor())(loadedImage);
+    const { image_embeds } = await (await this.getVisionModel())(imageInputs);
+    return image_embeds.data;
+  }
+
+  async getTextEmbedding(text: string): Promise<number[]> {
+    const textInputs = await (
+      await this.getTokenizer()
+    )([text], { padding: true, truncation: true });
+    const { text_embeds } = await (await this.getTextModel())(textInputs);
+    return text_embeds.data;
+  }
+
+  async getQueryEmbedding(query: string): Promise<number[]> {
+    return this.getTextEmbedding(query);
+  }
+}
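
A minimal usage sketch of the relocated class (run inside an async function;
assumes @xenova/transformers is installed and the checkpoint can be fetched;
the image path is hypothetical):

    import { ClipEmbedding, ClipEmbeddingModelType } from "./embeddings";

    const clip = new ClipEmbedding();
    // Optional: switch from the default patch-16 checkpoint to patch-32.
    clip.modelType = ClipEmbeddingModelType.XENOVA_CLIP_VIT_BASE_PATCH32;

    // CLIP embeds text and images into the same vector space, so the two
    // vectors can be compared with similarity() inherited from BaseEmbedding.
    const textVec = await clip.getTextEmbedding("a photo of a cat");
    const imageVec = await clip.getImageEmbedding("./cat.png");
    console.log(clip.similarity(textVec, imageVec));

The tokenizer, processor, and both CLIP towers load lazily on first use, so
constructing a ClipEmbedding stays cheap until an embedding is requested.
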
diff --git a/packages/core/src/embeddings/MultiModalEmbedding.ts b/packages/core/src/embeddings/MultiModalEmbedding.ts
new file mode 100644
index 000000000..c86ba0721
--- /dev/null
+++ b/packages/core/src/embeddings/MultiModalEmbedding.ts
@@ -0,0 +1,17 @@
+import { BaseEmbedding } from "./types";
+import { ImageType } from "./utils";
+
+/**
+ * Base class for multi-modal embeddings.
+ */
+
+export abstract class MultiModalEmbedding extends BaseEmbedding {
+  abstract getImageEmbedding(images: ImageType): Promise<number[]>;
+
+  async getImageEmbeddings(images: ImageType[]): Promise<number[][]> {
+    // Embed the input sequence of images asynchronously.
+    return Promise.all(
+      images.map((imgFilePath) => this.getImageEmbedding(imgFilePath)),
+    );
+  }
+}
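
For reference, the full contract a concrete subclass must supply; the
constant vectors below are placeholders, not a real model:

    import { MultiModalEmbedding } from "./MultiModalEmbedding";
    import { ImageType } from "./utils";

    class StubEmbedding extends MultiModalEmbedding {
      async getImageEmbedding(image: ImageType): Promise<number[]> {
        return [0, 1]; // placeholder vector
      }
      async getTextEmbedding(text: string): Promise<number[]> {
        return [1, 0]; // placeholder vector
      }
      async getQueryEmbedding(query: string): Promise<number[]> {
        return this.getTextEmbedding(query);
      }
    }

getImageEmbeddings() is inherited and embeds a whole batch concurrently via
Promise.all, so subclasses only implement the single-item methods.
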
diff --git a/packages/core/src/embeddings/OpenAIEmbedding.ts b/packages/core/src/embeddings/OpenAIEmbedding.ts
new file mode 100644
index 000000000..106c6cbff
--- /dev/null
+++ b/packages/core/src/embeddings/OpenAIEmbedding.ts
@@ -0,0 +1,92 @@
+import { ClientOptions as OpenAIClientOptions } from "openai";
+import {
+  AzureOpenAIConfig,
+  getAzureBaseUrl,
+  getAzureConfigFromEnv,
+  getAzureModel,
+  shouldUseAzure,
+} from "../llm/azure";
+import { OpenAISession, getOpenAISession } from "../llm/openai";
+import { BaseEmbedding } from "./types";
+
+export enum OpenAIEmbeddingModelType {
+  TEXT_EMBED_ADA_002 = "text-embedding-ada-002",
+}
+
+export class OpenAIEmbedding extends BaseEmbedding {
+  model: OpenAIEmbeddingModelType;
+
+  // OpenAI session params
+  apiKey?: string = undefined;
+  maxRetries: number;
+  timeout?: number;
+  additionalSessionOptions?: Omit<
+    Partial<OpenAIClientOptions>,
+    "apiKey" | "maxRetries" | "timeout"
+  >;
+
+  session: OpenAISession;
+
+  constructor(init?: Partial<OpenAIEmbedding> & { azure?: AzureOpenAIConfig }) {
+    super();
+
+    this.model = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002;
+
+    this.maxRetries = init?.maxRetries ?? 10;
+    this.timeout = init?.timeout ?? 60 * 1000; // Default is 60 seconds
+    this.additionalSessionOptions = init?.additionalSessionOptions;
+
+    if (init?.azure || shouldUseAzure()) {
+      const azureConfig = getAzureConfigFromEnv({
+        ...init?.azure,
+        model: getAzureModel(this.model),
+      });
+
+      if (!azureConfig.apiKey) {
+        throw new Error(
+          "Azure API key is required for OpenAI Azure models. Please set the AZURE_OPENAI_KEY environment variable.",
+        );
+      }
+
+      this.apiKey = azureConfig.apiKey;
+      this.session =
+        init?.session ??
+        getOpenAISession({
+          azure: true,
+          apiKey: this.apiKey,
+          baseURL: getAzureBaseUrl(azureConfig),
+          maxRetries: this.maxRetries,
+          timeout: this.timeout,
+          defaultQuery: { "api-version": azureConfig.apiVersion },
+          ...this.additionalSessionOptions,
+        });
+    } else {
+      this.apiKey = init?.apiKey ?? undefined;
+      this.session =
+        init?.session ??
+        getOpenAISession({
+          apiKey: this.apiKey,
+          maxRetries: this.maxRetries,
+          timeout: this.timeout,
+          ...this.additionalSessionOptions,
+        });
+    }
+  }
+
+  private async getOpenAIEmbedding(input: string) {
+    const { data } = await this.session.openai.embeddings.create({
+      model: this.model,
+      input,
+    });
+
+    return data[0].embedding;
+  }
+
+  async getTextEmbedding(text: string): Promise<number[]> {
+    return this.getOpenAIEmbedding(text);
+  }
+
+  async getQueryEmbedding(query: string): Promise<number[]> {
+    return this.getOpenAIEmbedding(query);
+  }
+}
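
A usage sketch (run inside an async function); when apiKey is omitted, the
underlying OpenAI client falls back to the OPENAI_API_KEY environment
variable:

    import { OpenAIEmbedding } from "./embeddings";

    const embedModel = new OpenAIEmbedding({ maxRetries: 3, timeout: 10_000 });
    const vector = await embedModel.getTextEmbedding("hello world");
    console.log(vector.length); // 1536 for text-embedding-ada-002

Passing { azure: { ... } } or setting the AZURE_OPENAI_KEY family of
environment variables routes the same class through an Azure endpoint instead.
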
diff --git a/packages/core/src/embeddings/index.ts b/packages/core/src/embeddings/index.ts
new file mode 100644
index 000000000..1a6a4df04
--- /dev/null
+++ b/packages/core/src/embeddings/index.ts
@@ -0,0 +1,5 @@
+export * from "./ClipEmbedding";
+export * from "./MultiModalEmbedding";
+export * from "./OpenAIEmbedding";
+export * from "./types";
+export * from "./utils";
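
The barrel re-exports everything the old Embedding.ts exposed, so call sites
only change their import path, as the remaining hunks in this patch do:

    // before: import { OpenAIEmbedding, similarity } from "./Embedding";
    import { OpenAIEmbedding, similarity } from "./embeddings";
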
diff --git a/packages/core/src/embeddings/types.ts b/packages/core/src/embeddings/types.ts
new file mode 100644
index 000000000..e500f9452
--- /dev/null
+++ b/packages/core/src/embeddings/types.ts
@@ -0,0 +1,24 @@
+import { similarity } from "./utils";
+
+/**
+ * Similarity type
+ * Default is cosine similarity. Dot product and negative Euclidean distance are also supported.
+ */
+export enum SimilarityType {
+  DEFAULT = "cosine",
+  DOT_PRODUCT = "dot_product",
+  EUCLIDEAN = "euclidean",
+}
+
+export abstract class BaseEmbedding {
+  similarity(
+    embedding1: number[],
+    embedding2: number[],
+    mode: SimilarityType = SimilarityType.DEFAULT,
+  ): number {
+    return similarity(embedding1, embedding2, mode);
+  }
+
+  abstract getTextEmbedding(text: string): Promise<number[]>;
+  abstract getQueryEmbedding(query: string): Promise<number[]>;
+}
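
BaseEmbedding.similarity() delegates to the free similarity() function in
utils.ts, so the modes behave identically everywhere; a small worked example:

    import { similarity, SimilarityType } from "./embeddings";

    const a = [1, 0];
    const b = [0.6, 0.8]; // also unit length
    similarity(a, b);                             // cosine: 0.6
    similarity(a, b, SimilarityType.DOT_PRODUCT); // 0.6
    similarity(a, b, SimilarityType.EUCLIDEAN);   // ~ -0.894 (negated distance)
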
diff --git a/packages/core/src/Embedding.ts b/packages/core/src/embeddings/utils.ts
similarity index 50%
rename from packages/core/src/Embedding.ts
rename to packages/core/src/embeddings/utils.ts
index ad6650251..cd192c3d4 100644
--- a/packages/core/src/Embedding.ts
+++ b/packages/core/src/embeddings/utils.ts
@@ -1,33 +1,16 @@
 import _ from "lodash";
-import { ClientOptions as OpenAIClientOptions } from "openai";
-import { DEFAULT_SIMILARITY_TOP_K } from "./constants";
-import {
-  AzureOpenAIConfig,
-  getAzureBaseUrl,
-  getAzureConfigFromEnv,
-  getAzureModel,
-  shouldUseAzure,
-} from "./llm/azure";
-import { OpenAISession, getOpenAISession } from "./llm/openai";
-import { VectorStoreQueryMode } from "./storage/vectorStore/types";
-
-/**
- * Similarity type
- * Default is cosine similarity. Dot product and negative Euclidean distance are also supported.
- */
-export enum SimilarityType {
-  DEFAULT = "cosine",
-  DOT_PRODUCT = "dot_product",
-  EUCLIDEAN = "euclidean",
-}
+import { DEFAULT_SIMILARITY_TOP_K } from "../constants";
+import { VectorStoreQueryMode } from "../storage";
+import { SimilarityType } from "./types";
 
 /**
  * The similarity between two embeddings.
  * @param embedding1
  * @param embedding2
  * @param mode
- * @returns similartiy score with higher numbers meaning the two embeddings are more similar
+ * @returns similarity score with higher numbers meaning the two embeddings are more similar
  */
+
 export function similarity(
   embedding1: number[],
   embedding2: number[],
@@ -42,7 +25,6 @@ export function similarity(
   // will probably cause some avoidable loss of floating point precision
   // ml-distance is worth watching although they currently also use the naive
   // formulas
-
   function norm(x: number[]): number {
     let result = 0;
     for (let i = 0; i < x.length; i++) {
@@ -201,105 +183,7 @@ export function getTopKMMREmbeddings(
 
   return [resultSimilarities, resultIds];
 }
-
-export abstract class BaseEmbedding {
-  similarity(
-    embedding1: number[],
-    embedding2: number[],
-    mode: SimilarityType = SimilarityType.DEFAULT,
-  ): number {
-    return similarity(embedding1, embedding2, mode);
-  }
-
-  abstract getTextEmbedding(text: string): Promise<number[]>;
-  abstract getQueryEmbedding(query: string): Promise<number[]>;
-}
-
-enum OpenAIEmbeddingModelType {
-  TEXT_EMBED_ADA_002 = "text-embedding-ada-002",
-}
-
-export class OpenAIEmbedding extends BaseEmbedding {
-  model: OpenAIEmbeddingModelType;
-
-  // OpenAI session params
-  apiKey?: string = undefined;
-  maxRetries: number;
-  timeout?: number;
-  additionalSessionOptions?: Omit<
-    Partial<OpenAIClientOptions>,
-    "apiKey" | "maxRetries" | "timeout"
-  >;
-
-  session: OpenAISession;
-
-  constructor(init?: Partial<OpenAIEmbedding> & { azure?: AzureOpenAIConfig }) {
-    super();
-
-    this.model = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002;
-
-    this.maxRetries = init?.maxRetries ?? 10;
-    this.timeout = init?.timeout ?? 60 * 1000; // Default is 60 seconds
-    this.additionalSessionOptions = init?.additionalSessionOptions;
-
-    if (init?.azure || shouldUseAzure()) {
-      const azureConfig = getAzureConfigFromEnv({
-        ...init?.azure,
-        model: getAzureModel(this.model),
-      });
-
-      if (!azureConfig.apiKey) {
-        throw new Error(
-          "Azure API key is required for OpenAI Azure models. Please set the AZURE_OPENAI_KEY environment variable.",
-        );
-      }
-
-      this.apiKey = azureConfig.apiKey;
-      this.session =
-        init?.session ??
-        getOpenAISession({
-          azure: true,
-          apiKey: this.apiKey,
-          baseURL: getAzureBaseUrl(azureConfig),
-          maxRetries: this.maxRetries,
-          timeout: this.timeout,
-          defaultQuery: { "api-version": azureConfig.apiVersion },
-          ...this.additionalSessionOptions,
-        });
-    } else {
-      this.apiKey = init?.apiKey ?? undefined;
-      this.session =
-        init?.session ??
-        getOpenAISession({
-          apiKey: this.apiKey,
-          maxRetries: this.maxRetries,
-          timeout: this.timeout,
-          ...this.additionalSessionOptions,
-        });
-    }
-  }
-
-  private async getOpenAIEmbedding(input: string) {
-    const { data } = await this.session.openai.embeddings.create({
-      model: this.model,
-      input,
-    });
-
-    return data[0].embedding;
-  }
-
-  async getTextEmbedding(text: string): Promise<number[]> {
-    return this.getOpenAIEmbedding(text);
-  }
-
-  async getQueryEmbedding(query: string): Promise<number[]> {
-    return this.getOpenAIEmbedding(query);
-  }
-}
-
-export type ImageType = string | Blob | URL;
-
-async function readImage(input: ImageType) {
+export async function readImage(input: ImageType) {
   const { RawImage } = await import("@xenova/transformers");
   if (input instanceof Blob) {
     return await RawImage.fromBlob(input);
@@ -309,93 +193,4 @@ async function readImage(input: ImageType) {
     throw new Error(`Unsupported input type: ${typeof input}`);
   }
 }
-
-/*
- * Base class for Multi Modal embeddings.
- */
-export abstract class MultiModalEmbedding extends BaseEmbedding {
-  abstract getImageEmbedding(images: ImageType): Promise<number[]>;
-
-  async getImageEmbeddings(images: ImageType[]): Promise<number[][]> {
-    // Embed the input sequence of images asynchronously.
-    return Promise.all(
-      images.map((imgFilePath) => this.getImageEmbedding(imgFilePath)),
-    );
-  }
-}
-
-enum ClipEmbeddingModelType {
-  XENOVA_CLIP_VIT_BASE_PATCH32 = "Xenova/clip-vit-base-patch32",
-  XENOVA_CLIP_VIT_BASE_PATCH16 = "Xenova/clip-vit-base-patch16",
-}
-
-export class ClipEmbedding extends MultiModalEmbedding {
-  modelType: ClipEmbeddingModelType =
-    ClipEmbeddingModelType.XENOVA_CLIP_VIT_BASE_PATCH16;
-
-  private tokenizer: any;
-  private processor: any;
-  private visionModel: any;
-  private textModel: any;
-
-  async getTokenizer() {
-    if (!this.tokenizer) {
-      const { AutoTokenizer } = await import("@xenova/transformers");
-      this.tokenizer = await AutoTokenizer.from_pretrained(this.modelType);
-    }
-    return this.tokenizer;
-  }
-
-  async getProcessor() {
-    if (!this.processor) {
-      const { AutoProcessor } = await import("@xenova/transformers");
-      this.processor = await AutoProcessor.from_pretrained(this.modelType);
-    }
-    return this.processor;
-  }
-
-  async getVisionModel() {
-    if (!this.visionModel) {
-      const { CLIPVisionModelWithProjection } = await import(
-        "@xenova/transformers"
-      );
-      this.visionModel = await CLIPVisionModelWithProjection.from_pretrained(
-        this.modelType,
-      );
-    }
-
-    return this.visionModel;
-  }
-
-  async getTextModel() {
-    if (!this.textModel) {
-      const { CLIPTextModelWithProjection } = await import(
-        "@xenova/transformers"
-      );
-      this.textModel = await CLIPTextModelWithProjection.from_pretrained(
-        this.modelType,
-      );
-    }
-
-    return this.textModel;
-  }
-
-  async getImageEmbedding(image: ImageType): Promise<number[]> {
-    const loadedImage = await readImage(image);
-    const imageInputs = await (await this.getProcessor())(loadedImage);
-    const { image_embeds } = await (await this.getVisionModel())(imageInputs);
-    return image_embeds.data;
-  }
-
-  async getTextEmbedding(text: string): Promise<number[]> {
-    const textInputs = await (
-      await this.getTokenizer()
-    )([text], { padding: true, truncation: true });
-    const { text_embeds } = await (await this.getTextModel())(textInputs);
-    return text_embeds.data;
-  }
-
-  async getQueryEmbedding(query: string): Promise<number[]> {
-    return this.getTextEmbedding(query);
-  }
-}
+export type ImageType = string | Blob | URL;
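
readImage was previously module-private; exporting it lets ClipEmbedding (now
in its own file) reuse it. A sketch of the Blob branch visible in this hunk,
run inside an async function; string and URL inputs are handled by the
unchanged middle of the function:

    import { readImage } from "./embeddings";

    const blob = new Blob([/* raw image bytes */], { type: "image/png" });
    const raw = await readImage(blob); // a RawImage from @xenova/transformers
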
diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts
index 20ab46297..dde8fff26 100644
--- a/packages/core/src/index.ts
+++ b/packages/core/src/index.ts
@@ -1,6 +1,5 @@
 export * from "./ChatEngine";
 export * from "./ChatHistory";
-export * from "./Embedding";
 export * from "./GlobalsHelper";
 export * from "./Node";
 export * from "./NodeParser";
@@ -17,6 +16,7 @@ export * from "./TextSplitter";
 export * from "./Tool";
 export * from "./callbacks/CallbackManager";
 export * from "./constants";
+export * from "./embeddings";
 export * from "./indices";
 export * from "./llm/LLM";
 export * from "./readers/CSVReader";
diff --git a/packages/core/src/storage/vectorStore/SimpleVectorStore.ts b/packages/core/src/storage/vectorStore/SimpleVectorStore.ts
index 1bccf12ff..929ebe2c2 100644
--- a/packages/core/src/storage/vectorStore/SimpleVectorStore.ts
+++ b/packages/core/src/storage/vectorStore/SimpleVectorStore.ts
@@ -1,11 +1,11 @@
 import _ from "lodash";
 import * as path from "path";
+import { BaseNode } from "../../Node";
 import {
   getTopKEmbeddings,
   getTopKEmbeddingsLearner,
   getTopKMMREmbeddings,
-} from "../../Embedding";
-import { BaseNode } from "../../Node";
+} from "../../embeddings";
 import { GenericFileSystem, exists } from "../FileSystem";
 import { DEFAULT_FS, DEFAULT_PERSIST_DIR } from "../constants";
 import {
diff --git a/packages/core/src/tests/CallbackManager.test.ts b/packages/core/src/tests/CallbackManager.test.ts
index c3d9a98d4..9374c50b6 100644
--- a/packages/core/src/tests/CallbackManager.test.ts
+++ b/packages/core/src/tests/CallbackManager.test.ts
@@ -1,4 +1,3 @@
-import { OpenAIEmbedding } from "../Embedding";
 import { Document } from "../Node";
 import {
   ResponseSynthesizer,
@@ -10,6 +9,7 @@ import {
   RetrievalCallbackResponse,
   StreamCallbackResponse,
 } from "../callbacks/CallbackManager";
+import { OpenAIEmbedding } from "../embeddings";
 import { SummaryIndex } from "../indices/summary";
 import { VectorStoreIndex } from "../indices/vectorStore/VectorStoreIndex";
 import { OpenAI } from "../llm/LLM";
diff --git a/packages/core/src/tests/Embedding.test.ts b/packages/core/src/tests/Embedding.test.ts
index 492a48be1..adc70810f 100644
--- a/packages/core/src/tests/Embedding.test.ts
+++ b/packages/core/src/tests/Embedding.test.ts
@@ -1,4 +1,4 @@
-import { SimilarityType, similarity } from "../Embedding";
+import { SimilarityType, similarity } from "../embeddings";
 
 describe("similarity", () => {
   test("throws error on mismatched lengths", () => {
diff --git a/packages/core/src/tests/utility/mockOpenAI.ts b/packages/core/src/tests/utility/mockOpenAI.ts
index 04d6cdd40..5f8a0ad4c 100644
--- a/packages/core/src/tests/utility/mockOpenAI.ts
+++ b/packages/core/src/tests/utility/mockOpenAI.ts
@@ -1,6 +1,6 @@
-import { OpenAIEmbedding } from "../../Embedding";
 import { globalsHelper } from "../../GlobalsHelper";
 import { CallbackManager, Event } from "../../callbacks/CallbackManager";
+import { OpenAIEmbedding } from "../../embeddings";
 import { ChatMessage, OpenAI } from "../../llm/LLM";
 
 export function mockLlmGeneration({
-- 
GitLab