From e72fa8b370212394f9d52d7dc629b568de6538c9 Mon Sep 17 00:00:00 2001
From: Sean Hatfield <seanhatfield5@gmail.com>
Date: Fri, 21 Jun 2024 16:27:02 -0700
Subject: [PATCH] [FEAT] Generic OpenAI embedding provider (#1664)

* implement generic openai embedding provider

* linting

* comment & description update for generic openai embedding provider

* fix privacy for generic

---------

Co-authored-by: timothycarambat <rambat1010@gmail.com>
---
 docker/.env.example                           |  6 ++
 .../GenericOpenAiOptions/index.jsx            | 74 +++++++++++++++
 .../EmbeddingPreference/index.jsx             | 11 +++
 .../Steps/DataHandling/index.jsx              |  7 ++
 server/.env.example                           |  6 ++
 server/models/systemSettings.js               |  2 +
 .../EmbeddingEngines/genericOpenAi/index.js   | 95 +++++++++++++++++++
 server/utils/helpers/index.js                 |  5 +
 server/utils/helpers/updateENV.js             |  7 ++
 9 files changed, 213 insertions(+)
 create mode 100644 frontend/src/components/EmbeddingSelection/GenericOpenAiOptions/index.jsx
 create mode 100644 server/utils/EmbeddingEngines/genericOpenAi/index.js

diff --git a/docker/.env.example b/docker/.env.example
index f682f8bf1..38b980880 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -136,6 +136,12 @@ GID='1000'
 # LITE_LLM_BASE_PATH='http://127.0.0.1:4000'
 # LITE_LLM_API_KEY='sk-123abc'
 
+# EMBEDDING_ENGINE='generic-openai'
+# EMBEDDING_MODEL_PREF='text-embedding-ada-002'
+# EMBEDDING_MODEL_MAX_CHUNK_LENGTH=8192
+# EMBEDDING_BASE_PATH='http://127.0.0.1:4000'
+# GENERIC_OPEN_AI_EMBEDDING_API_KEY='sk-123abc'
+
 ###########################################
 ######## Vector Database Selection ########
 ###########################################
diff --git a/frontend/src/components/EmbeddingSelection/GenericOpenAiOptions/index.jsx b/frontend/src/components/EmbeddingSelection/GenericOpenAiOptions/index.jsx
new file mode 100644
index 000000000..8d4870f07
--- /dev/null
+++ b/frontend/src/components/EmbeddingSelection/GenericOpenAiOptions/index.jsx
@@ -0,0 +1,74 @@
+export default function GenericOpenAiEmbeddingOptions({ settings }) {
+  return (
+    <div className="w-full flex flex-col gap-y-4">
+      <div className="w-full flex items-center gap-4 flex-wrap">
+        <div className="flex flex-col w-60">
+          <label className="text-white text-sm font-semibold block mb-4">
+            Base URL
+          </label>
+          <input
+            type="url"
+            name="EmbeddingBasePath"
+            className="bg-zinc-900 text-white placeholder:text-white/20 text-sm rounded-lg focus:border-white block w-full p-2.5"
+            placeholder="https://api.openai.com/v1"
+            defaultValue={settings?.EmbeddingBasePath}
+            required={true}
+            autoComplete="off"
+            spellCheck={false}
+          />
+        </div>
+        <div className="flex flex-col w-60">
+          <label className="text-white text-sm font-semibold block mb-4">
+            Embedding Model
+          </label>
+          <input
+            type="text"
+            name="EmbeddingModelPref"
+            className="bg-zinc-900 text-white placeholder:text-white/20 text-sm rounded-lg focus:border-white block w-full p-2.5"
+            placeholder="text-embedding-ada-002"
+            defaultValue={settings?.EmbeddingModelPref}
+            required={true}
+            autoComplete="off"
+            spellCheck={false}
+          />
+        </div>
+        <div className="flex flex-col w-60">
+          <label className="text-white text-sm font-semibold block mb-4">
+            Max embedding chunk length
+          </label>
+          <input
+            type="number"
+            name="EmbeddingModelMaxChunkLength"
+            className="bg-zinc-900 text-white placeholder-white/20 text-sm rounded-lg focus:border-white block w-full p-2.5"
+            placeholder="8192"
+            min={1}
+            onScroll={(e) => e.target.blur()}
+            defaultValue={settings?.EmbeddingModelMaxChunkLength}
+            required={false}
+            autoComplete="off"
+          />
+        </div>
+      </div>
+      <div className="w-full flex items-center gap-4">
+        <div className="flex flex-col w-60">
+          <div className="flex flex-col gap-y-1 mb-4">
+            <label className="text-white text-sm font-semibold flex items-center gap-x-2">
+              API Key <p className="!text-xs !italic !font-thin">optional</p>
+            </label>
+          </div>
+          <input
+            type="password"
+            name="GenericOpenAiEmbeddingApiKey"
+            className="bg-zinc-900 text-white placeholder:text-white/20 text-sm rounded-lg focus:border-white block w-full p-2.5"
+            placeholder="sk-mysecretkey"
+            defaultValue={
+              settings?.GenericOpenAiEmbeddingApiKey ? "*".repeat(20) : ""
+            }
+            autoComplete="off"
+            spellCheck={false}
+          />
+        </div>
+      </div>
+    </div>
+  );
+}
diff --git a/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx b/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx
index 2563aaadb..ec8c2b8bd 100644
--- a/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx
+++ b/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx
@@ -12,6 +12,7 @@ import LMStudioLogo from "@/media/llmprovider/lmstudio.png";
 import CohereLogo from "@/media/llmprovider/cohere.png";
 import VoyageAiLogo from "@/media/embeddingprovider/voyageai.png";
 import LiteLLMLogo from "@/media/llmprovider/litellm.png";
+import GenericOpenAiLogo from "@/media/llmprovider/generic-openai.png";
 
 import PreLoader from "@/components/Preloader";
 import ChangeWarningModal from "@/components/ChangeWarning";
@@ -24,6 +25,7 @@ import LMStudioEmbeddingOptions from "@/components/EmbeddingSelection/LMStudioOp
 import CohereEmbeddingOptions from "@/components/EmbeddingSelection/CohereOptions";
 import VoyageAiOptions from "@/components/EmbeddingSelection/VoyageAiOptions";
 import LiteLLMOptions from "@/components/EmbeddingSelection/LiteLLMOptions";
+import GenericOpenAiEmbeddingOptions from "@/components/EmbeddingSelection/GenericOpenAiOptions";
 
 import EmbedderItem from "@/components/EmbeddingSelection/EmbedderItem";
 import { CaretUpDown, MagnifyingGlass, X } from "@phosphor-icons/react";
@@ -98,6 +100,15 @@ const EMBEDDERS = [
     options: (settings) => <LiteLLMOptions settings={settings} />,
     description: "Run powerful embedding models from LiteLLM.",
   },
+  {
+    name: "Generic OpenAI",
+    value: "generic-openai",
+    logo: GenericOpenAiLogo,
+    options: (settings) => (
+      <GenericOpenAiEmbeddingOptions settings={settings} />
+    ),
+    description: "Run embedding models from any OpenAI compatible API service.",
+  },
 ];
 
 export default function GeneralEmbeddingPreference() {
diff --git a/frontend/src/pages/OnboardingFlow/Steps/DataHandling/index.jsx b/frontend/src/pages/OnboardingFlow/Steps/DataHandling/index.jsx
index b4fa666ff..1b3bf360b 100644
--- a/frontend/src/pages/OnboardingFlow/Steps/DataHandling/index.jsx
+++ b/frontend/src/pages/OnboardingFlow/Steps/DataHandling/index.jsx
@@ -308,6 +308,13 @@ export const EMBEDDING_ENGINE_PRIVACY = {
     ],
     logo: LiteLLMLogo,
   },
+  "generic-openai": {
+    name: "Generic OpenAI compatible service",
+    description: [
+      "Data is shared according to the terms of service applicable with your generic endpoint provider.",
+    ],
+    logo: GenericOpenAiLogo,
+  },
 };
 
 export default function DataHandling({ setHeader, setForwardBtn, setBackBtn }) {
diff --git a/server/.env.example b/server/.env.example
index 145e00da1..22bd557ee 100644
--- a/server/.env.example
+++ b/server/.env.example
@@ -133,6 +133,12 @@ SIG_SALT='salt' # Please generate random string at least 32 chars long.
 # LITE_LLM_BASE_PATH='http://127.0.0.1:4000'
 # LITE_LLM_API_KEY='sk-123abc'
 
+# EMBEDDING_ENGINE='generic-openai'
+# EMBEDDING_MODEL_PREF='text-embedding-ada-002'
+# EMBEDDING_MODEL_MAX_CHUNK_LENGTH=8192
+# EMBEDDING_BASE_PATH='http://127.0.0.1:4000'
+# GENERIC_OPEN_AI_EMBEDDING_API_KEY='sk-123abc'
+
 ###########################################
 ######## Vector Database Selection ########
 ###########################################
diff --git a/server/models/systemSettings.js b/server/models/systemSettings.js
index eae75d9ca..3f44f7228 100644
--- a/server/models/systemSettings.js
+++ b/server/models/systemSettings.js
@@ -149,6 +149,8 @@ const SystemSettings = {
       EmbeddingModelPref: process.env.EMBEDDING_MODEL_PREF,
       EmbeddingModelMaxChunkLength:
         process.env.EMBEDDING_MODEL_MAX_CHUNK_LENGTH,
+      GenericOpenAiEmbeddingApiKey:
+        !!process.env.GENERIC_OPEN_AI_EMBEDDING_API_KEY,
 
       // --------------------------------------------------------
       // VectorDB Provider Selection Settings & Configs
diff --git a/server/utils/EmbeddingEngines/genericOpenAi/index.js b/server/utils/EmbeddingEngines/genericOpenAi/index.js
new file mode 100644
index 000000000..d3ec30721
--- /dev/null
+++ b/server/utils/EmbeddingEngines/genericOpenAi/index.js
@@ -0,0 +1,95 @@
+const { toChunks } = require("../../helpers");
+
+class GenericOpenAiEmbedder {
+  constructor() {
+    if (!process.env.EMBEDDING_BASE_PATH)
+      throw new Error(
+        "GenericOpenAI must have a valid base path to use for the api."
+      );
+    const { OpenAI: OpenAIApi } = require("openai");
+    this.basePath = process.env.EMBEDDING_BASE_PATH;
+    this.openai = new OpenAIApi({
+      baseURL: this.basePath,
+      apiKey: process.env.GENERIC_OPEN_AI_EMBEDDING_API_KEY ?? null,
+    });
+    this.model = process.env.EMBEDDING_MODEL_PREF ?? null;
+
+    // Limit of how many strings we can process in a single pass to stay with resource or network limits
+    this.maxConcurrentChunks = 500;
+
+    // Refer to your specific model and provider you use this class with to determine a valid maxChunkLength
+    this.embeddingMaxChunkLength = 8_191;
+  }
+
+  async embedTextInput(textInput) {
+    const result = await this.embedChunks(
+      Array.isArray(textInput) ? textInput : [textInput]
+    );
+    return result?.[0] || [];
+  }
+
+  async embedChunks(textChunks = []) {
+    // Because there is a hard POST limit on how many chunks can be sent at once to OpenAI (~8mb)
+    // we concurrently execute each max batch of text chunks possible.
+    // Refer to constructor maxConcurrentChunks for more info.
+    const embeddingRequests = [];
+    for (const chunk of toChunks(textChunks, this.maxConcurrentChunks)) {
+      embeddingRequests.push(
+        new Promise((resolve) => {
+          this.openai.embeddings
+            .create({
+              model: this.model,
+              input: chunk,
+            })
+            .then((result) => {
+              resolve({ data: result?.data, error: null });
+            })
+            .catch((e) => {
+              e.type =
+                e?.response?.data?.error?.code ||
+                e?.response?.status ||
+                "failed_to_embed";
+              e.message = e?.response?.data?.error?.message || e.message;
+              resolve({ data: [], error: e });
+            });
+        })
+      );
+    }
+
+    const { data = [], error = null } = await Promise.all(
+      embeddingRequests
+    ).then((results) => {
+      // If any errors were returned from OpenAI abort the entire sequence because the embeddings
+      // will be incomplete.
+      const errors = results
+        .filter((res) => !!res.error)
+        .map((res) => res.error)
+        .flat();
+      if (errors.length > 0) {
+        let uniqueErrors = new Set();
+        errors.map((error) =>
+          uniqueErrors.add(`[${error.type}]: ${error.message}`)
+        );
+
+        return {
+          data: [],
+          error: Array.from(uniqueErrors).join(", "),
+        };
+      }
+      return {
+        data: results.map((res) => res?.data || []).flat(),
+        error: null,
+      };
+    });
+
+    if (!!error) throw new Error(`GenericOpenAI Failed to embed: ${error}`);
+    return data.length > 0 &&
+      data.every((embd) => embd.hasOwnProperty("embedding"))
+      ? data.map((embd) => embd.embedding)
+      : null;
+  }
+}
+
+module.exports = {
+  GenericOpenAiEmbedder,
+};
diff --git a/server/utils/helpers/index.js b/server/utils/helpers/index.js
index 8f0df1264..302ec9589 100644
--- a/server/utils/helpers/index.js
+++ b/server/utils/helpers/index.js
@@ -131,6 +131,11 @@ function getEmbeddingEngineSelection() {
     case "litellm":
       const { LiteLLMEmbedder } = require("../EmbeddingEngines/liteLLM");
       return new LiteLLMEmbedder();
+    case "generic-openai":
+      const {
+        GenericOpenAiEmbedder,
+      } = require("../EmbeddingEngines/genericOpenAi");
+      return new GenericOpenAiEmbedder();
     default:
       return new NativeEmbedder();
   }
diff --git a/server/utils/helpers/updateENV.js b/server/utils/helpers/updateENV.js
index 6abd6408b..c2cfc1aa8 100644
--- a/server/utils/helpers/updateENV.js
+++ b/server/utils/helpers/updateENV.js
@@ -221,6 +221,12 @@ const KEY_MAPPING = {
     checks: [nonZero],
   },
 
+  // Generic OpenAI Embedding Settings
+  GenericOpenAiEmbeddingApiKey: {
+    envKey: "GENERIC_OPEN_AI_EMBEDDING_API_KEY",
+    checks: [],
+  },
+
   // Vector Database Selection Settings
   VectorDB: {
     envKey: "VECTOR_DB",
@@ -587,6 +593,7 @@ function supportedEmbeddingModel(input = "") {
     "cohere",
     "voyageai",
     "litellm",
+    "generic-openai",
   ];
   return supported.includes(input)
     ? null
-- 
GitLab