From 1a5aacb001259fa98fbb97548d4ce0c8410cd470 Mon Sep 17 00:00:00 2001
From: Timothy Carambat <rambat1010@gmail.com>
Date: Fri, 17 May 2024 21:31:29 -0700
Subject: [PATCH] Support multi-model whispers (#1444)

---
 .../utils/WhisperProviders/localWhisper.js    | 16 ++--
 .../NativeTranscriptionOptions/index.jsx      | 93 ++++++++++++++-----
 .../TranscriptionPreference/index.jsx         | 37 ++++----
 server/models/systemSettings.js               |  2 +
 server/utils/collectorApi/index.js            |  1 +
 server/utils/helpers/updateENV.js             | 15 +++
 6 files changed, 118 insertions(+), 46 deletions(-)

diff --git a/collector/utils/WhisperProviders/localWhisper.js b/collector/utils/WhisperProviders/localWhisper.js
index 46dbe226b..af13c8a92 100644
--- a/collector/utils/WhisperProviders/localWhisper.js
+++ b/collector/utils/WhisperProviders/localWhisper.js
@@ -1,19 +1,23 @@
 const fs = require("fs");
 const path = require("path");
 const { v4 } = require("uuid");
+const defaultWhisper = "Xenova/whisper-small"; // Model Card: https://huggingface.co/Xenova/whisper-small
+const fileSize = {
+  "Xenova/whisper-small": "250MB",
+  "Xenova/whisper-large": "1.56GB",
+};
 
 class LocalWhisper {
-  constructor() {
-    // Model Card: https://huggingface.co/Xenova/whisper-small
-    this.model = "Xenova/whisper-small";
+  constructor({ options }) {
+    this.model = options?.WhisperModelPref ?? defaultWhisper;
+    this.fileSize = fileSize[this.model];
     this.cacheDir = path.resolve(
       process.env.STORAGE_DIR
         ? path.resolve(process.env.STORAGE_DIR, `models`)
         : path.resolve(__dirname, `../../../server/storage/models`)
     );
 
-    this.modelPath = path.resolve(this.cacheDir, "Xenova", "whisper-small");
-
+    this.modelPath = path.resolve(this.cacheDir, ...this.model.split("/"));
     // Make directory when it does not exist in existing installations
     if (!fs.existsSync(this.cacheDir))
       fs.mkdirSync(this.cacheDir, { recursive: true });
@@ -104,7 +108,7 @@ class LocalWhisper {
   async client() {
     if (!fs.existsSync(this.modelPath)) {
       this.#log(
-        `The native whisper model has never been run and will be downloaded right now. Subsequent runs will be faster. (~250MB)`
+        `The native whisper model has never been run and will be downloaded right now. Subsequent runs will be faster. (~${this.fileSize})`
       );
     }
 
diff --git a/frontend/src/components/TranscriptionSelection/NativeTranscriptionOptions/index.jsx b/frontend/src/components/TranscriptionSelection/NativeTranscriptionOptions/index.jsx
index 07ee12126..d2e81a68a 100644
--- a/frontend/src/components/TranscriptionSelection/NativeTranscriptionOptions/index.jsx
+++ b/frontend/src/components/TranscriptionSelection/NativeTranscriptionOptions/index.jsx
@@ -1,38 +1,89 @@
 import { Gauge } from "@phosphor-icons/react";
-export default function NativeTranscriptionOptions() {
+import { useState } from "react";
+
+export default function NativeTranscriptionOptions({ settings }) {
+  const [model, setModel] = useState(settings?.WhisperModelPref);
+
   return (
     <div className="w-full flex flex-col gap-y-4">
-      <div className="flex flex-col md:flex-row md:items-center gap-x-2 text-white mb-4 bg-blue-800/30 w-fit rounded-lg px-4 py-2">
-        <div className="gap-x-2 flex items-center">
-          <Gauge size={25} />
-          <p className="text-sm">
-            Using the local whisper model on machines with limited RAM or CPU
-            can stall AnythingLLM when processing media files.
-            <br />
-            We recommend at least 2GB of RAM and upload files &lt;10Mb.
-            <br />
-            <br />
-            <i>
-              The built-in model will automatically download on the first use.
-            </i>
-          </p>
-        </div>
-      </div>
+      <LocalWarning model={model} />
       <div className="w-full flex items-center gap-4">
         <div className="flex flex-col w-60">
           <label className="text-white text-sm font-semibold block mb-4">
             Model Selection
           </label>
           <select
-            disabled={true}
+            name="WhisperModelPref"
+            defaultValue={model}
+            onChange={(e) => setModel(e.target.value)}
             className="bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
           >
-            <option disabled={true} selected={true}>
-              Xenova/whisper-small
-            </option>
+            {["Xenova/whisper-small", "Xenova/whisper-large"].map(
+              (value, i) => {
+                return (
+                  <option key={i} value={value}>
+                    {value}
+                  </option>
+                );
+              }
+            )}
           </select>
         </div>
       </div>
     </div>
   );
 }
+
+function LocalWarning({ model }) {
+  switch (model) {
+    case "Xenova/whisper-small":
+      return <WhisperSmall />;
+    case "Xenova/whisper-large":
+      return <WhisperLarge />;
+    default:
+      return <WhisperSmall />;
+  }
+}
+
+function WhisperSmall() {
+  return (
+    <div className="flex flex-col md:flex-row md:items-center gap-x-2 text-white mb-4 bg-blue-800/30 w-fit rounded-lg px-4 py-2">
+      <div className="gap-x-2 flex items-center">
+        <Gauge size={25} />
+        <p className="text-sm">
+          Running the <b>whisper-small</b> model on a machine with limited RAM
+          or CPU can stall AnythingLLM when processing media files.
+          <br />
+          We recommend at least 2GB of RAM and upload files &lt;10Mb.
+          <br />
+          <br />
+          <i>
+            This model will automatically download on the first use. (250MB)
+          </i>
+        </p>
+      </div>
+    </div>
+  );
+}
+
+function WhisperLarge() {
+  return (
+    <div className="flex flex-col md:flex-row md:items-center gap-x-2 text-white mb-4 bg-blue-800/30 w-fit rounded-lg px-4 py-2">
+      <div className="gap-x-2 flex items-center">
+        <Gauge size={25} />
+        <p className="text-sm">
+          Using the <b>whisper-large</b> model on machines with limited RAM or
+          CPU can stall AnythingLLM when processing media files. This model is
+          substantially larger than the whisper-small.
+          <br />
+          We recommend at least 8GB of RAM and upload files &lt;10Mb.
+          <br />
+          <br />
+          <i>
+            This model will automatically download on the first use. (1.56GB)
+          </i>
+        </p>
+      </div>
+    </div>
+  );
+}
diff --git a/frontend/src/pages/GeneralSettings/TranscriptionPreference/index.jsx b/frontend/src/pages/GeneralSettings/TranscriptionPreference/index.jsx
index 5fbd196c3..07907af72 100644
--- a/frontend/src/pages/GeneralSettings/TranscriptionPreference/index.jsx
+++ b/frontend/src/pages/GeneralSettings/TranscriptionPreference/index.jsx
@@ -12,6 +12,23 @@ import LLMItem from "@/components/LLMSelection/LLMItem";
 import { CaretUpDown, MagnifyingGlass, X } from "@phosphor-icons/react";
 import CTAButton from "@/components/lib/CTAButton";
 
+const PROVIDERS = [
+  {
+    name: "OpenAI",
+    value: "openai",
+    logo: OpenAiLogo,
+    options: (settings) => <OpenAiWhisperOptions settings={settings} />,
+    description: "Leverage the OpenAI Whisper-large model using your API key.",
+  },
+  {
+    name: "AnythingLLM Built-In",
+    value: "local",
+    logo: AnythingLLMIcon,
+    options: (settings) => <NativeTranscriptionOptions settings={settings} />,
+    description: "Run a built-in whisper model on this instance privately.",
+  },
+];
+
 export default function TranscriptionModelPreference() {
   const [saving, setSaving] = useState(false);
   const [hasChanges, setHasChanges] = useState(false);
@@ -68,24 +85,6 @@ export default function TranscriptionModelPreference() {
     fetchKeys();
   }, []);
 
-  const PROVIDERS = [
-    {
-      name: "OpenAI",
-      value: "openai",
-      logo: OpenAiLogo,
-      options: <OpenAiWhisperOptions settings={settings} />,
-      description:
-        "Leverage the OpenAI Whisper-large model using your API key.",
-    },
-    {
-      name: "AnythingLLM Built-In",
-      value: "local",
-      logo: AnythingLLMIcon,
-      options: <NativeTranscriptionOptions settings={settings} />,
-      description: "Run a built-in whisper model on this instance privately.",
-    },
-  ];
-
   useEffect(() => {
     const filtered = PROVIDERS.filter((provider) =>
       provider.name.toLowerCase().includes(searchQuery.toLowerCase())
@@ -228,7 +227,7 @@ export default function TranscriptionModelPreference() {
                 {selectedProvider &&
                   PROVIDERS.find(
                     (provider) => provider.value === selectedProvider
-                  )?.options}
+                  )?.options(settings)}
               </div>
             </div>
           </form>
diff --git a/server/models/systemSettings.js b/server/models/systemSettings.js
index 68d1d0dde..c8e239f15 100644
--- a/server/models/systemSettings.js
+++ b/server/models/systemSettings.js
@@ -150,6 +150,8 @@ const SystemSettings = {
       // - then it can be shared.
       // --------------------------------------------------------
       WhisperProvider: process.env.WHISPER_PROVIDER || "local",
+      WhisperModelPref:
+        process.env.WHISPER_MODEL_PREF || "Xenova/whisper-small",
 
       // --------------------------------------------------------
       // TTS/STT  Selection Settings & Configs
diff --git a/server/utils/collectorApi/index.js b/server/utils/collectorApi/index.js
index 1a1431ac1..6971f640d 100644
--- a/server/utils/collectorApi/index.js
+++ b/server/utils/collectorApi/index.js
@@ -17,6 +17,7 @@ class CollectorApi {
   #attachOptions() {
     return {
       whisperProvider: process.env.WHISPER_PROVIDER || "local",
+      WhisperModelPref: process.env.WHISPER_MODEL_PREF,
       openAiKey: process.env.OPEN_AI_KEY || null,
     };
   }
diff --git a/server/utils/helpers/updateENV.js b/server/utils/helpers/updateENV.js
index 8630d85a1..48c98e957 100644
--- a/server/utils/helpers/updateENV.js
+++ b/server/utils/helpers/updateENV.js
@@ -356,6 +356,11 @@ const KEY_MAPPING = {
     checks: [isNotEmpty, supportedTranscriptionProvider],
     postUpdate: [],
   },
+  WhisperModelPref: {
+    envKey: "WHISPER_MODEL_PREF",
+    checks: [validLocalWhisper],
+    postUpdate: [],
+  },
 
   // System Settings
   AuthToken: {
@@ -468,6 +473,16 @@ function supportedTTSProvider(input = "") {
   return validSelection ? null : `${input} is not a valid TTS provider.`;
 }
 
+function validLocalWhisper(input = "") {
+  const validSelection = [
+    "Xenova/whisper-small",
+    "Xenova/whisper-large",
+  ].includes(input);
+  return validSelection
+    ? null
+    : `${input} is not a valid Whisper model selection.`;
+}
+
 function supportedLLM(input = "") {
   const validSelection = [
     "openai",
-- 
GitLab