From 90df37582bcccba53282420eb61c8038c4699609 Mon Sep 17 00:00:00 2001
From: Sean Hatfield <seanhatfield5@gmail.com>
Date: Wed, 17 Jan 2024 12:59:25 -0800
Subject: [PATCH] Per workspace model selection (#582)

* WIP model selection per workspace (migrations and OpenAI save properly)

* revert OpenAiOption

* add support for models per workspace for anthropic, localAi, ollama, openAi, and togetherAi

* remove unneeded comments

* update logic for when LLMProvider is reset; reset AI provider files to match master

* remove frontend/api reset of workspace chat and move logic to updateENV
add postUpdate callbacks to envs

* set preferred model for chat on class instantiation

* remove extra param

* linting

* remove unused var

* refactor chat model selection on workspace

* linting

* add fallback for base path to localai models

---------

Co-authored-by: timothycarambat <rambat1010@gmail.com>
---
 .../Settings/ChatModelPreference/index.jsx    | 120 ++++++++++++++++++
 .../useGetProviderModels.js                   |  49 +++++++
 .../Modals/MangeWorkspace/Settings/index.jsx  |   8 +-
 .../Modals/MangeWorkspace/index.jsx           |   1 +
 .../GeneralSettings/LLMPreference/index.jsx   |   6 +-
 server/endpoints/api/system/index.js          |   2 +-
 server/endpoints/system.js                    |   6 +-
 server/models/workspace.js                    |  15 +++
 .../20240113013409_init/migration.sql         |   2 +
 server/prisma/schema.prisma                   |   1 +
 server/utils/AiProviders/anthropic/index.js   |   5 +-
 server/utils/AiProviders/azureOpenAi/index.js |   2 +-
 server/utils/AiProviders/gemini/index.js      |   5 +-
 server/utils/AiProviders/lmStudio/index.js    |   4 +-
 server/utils/AiProviders/localAi/index.js     |   4 +-
 server/utils/AiProviders/native/index.js      |   4 +-
 server/utils/AiProviders/ollama/index.js      |   4 +-
 server/utils/AiProviders/openAi/index.js      |   5 +-
 server/utils/AiProviders/togetherAi/index.js  |   4 +-
 server/utils/chats/index.js                   |   2 +-
 server/utils/chats/stream.js                  |   2 +-
 server/utils/helpers/customModels.js          |  13 +-
 server/utils/helpers/index.js                 |  20 +--
 server/utils/helpers/updateENV.js             |  32 +++--
 24 files changed, 263 insertions(+), 53 deletions(-)
 create mode 100644 frontend/src/components/Modals/MangeWorkspace/Settings/ChatModelPreference/index.jsx
 create mode 100644 frontend/src/components/Modals/MangeWorkspace/Settings/ChatModelPreference/useGetProviderModels.js
 create mode 100644 server/prisma/migrations/20240113013409_init/migration.sql
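
Note on the overall flow: the chat entrypoints now pass the workspace's saved model down to the LLM connector, which falls back to the provider's env preference when no per-workspace model is set. A minimal sketch of that resolution order (the surrounding call site is illustrative, not part of this diff):

    // workspace.chatModel (per-workspace) -> provider env pref -> provider hard default
    const { getLLMProvider } = require("./server/utils/helpers");
    const LLMConnector = getLLMProvider(workspace?.chatModel); // null/undefined => env/default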

diff --git a/frontend/src/components/Modals/MangeWorkspace/Settings/ChatModelPreference/index.jsx b/frontend/src/components/Modals/MangeWorkspace/Settings/ChatModelPreference/index.jsx
new file mode 100644
index 000000000..ea03c09a9
--- /dev/null
+++ b/frontend/src/components/Modals/MangeWorkspace/Settings/ChatModelPreference/index.jsx
@@ -0,0 +1,120 @@
+import useGetProviderModels, {
+  DISABLED_PROVIDERS,
+} from "./useGetProviderModels";
+
+export default function ChatModelSelection({
+  settings,
+  workspace,
+  setHasChanges,
+}) {
+  const { defaultModels, customModels, loading } = useGetProviderModels(
+    settings?.LLMProvider
+  );
+  if (DISABLED_PROVIDERS.includes(settings?.LLMProvider)) return null;
+
+  if (loading) {
+    return (
+      <div>
+        <div className="flex flex-col">
+          <label
+            htmlFor="name"
+            className="block text-sm font-medium text-white"
+          >
+            Chat model
+          </label>
+          <p className="text-white text-opacity-60 text-xs font-medium py-1.5">
+          The specific chat model that will be used for this workspace. If
+          empty, the system LLM preference will be used.
+          </p>
+        </div>
+        <select
+          name="chatModel"
+          required={true}
+          disabled={true}
+          className="bg-zinc-900 text-white text-sm rounded-lg focus:ring-blue-500 focus:border-blue-500 block w-full p-2.5"
+        >
+          <option disabled={true} selected={true}>
+            -- waiting for models --
+          </option>
+        </select>
+      </div>
+    );
+  }
+
+  return (
+    <div>
+      <div className="flex flex-col">
+        <label htmlFor="name" className="block text-sm font-medium text-white">
+          Chat model{" "}
+          <span className="font-normal">({settings?.LLMProvider})</span>
+        </label>
+        <p className="text-white text-opacity-60 text-xs font-medium py-1.5">
+          The specific chat model that will be used for this workspace. If
+          empty, the system LLM preference will be used.
+        </p>
+      </div>
+
+      <select
+        name="chatModel"
+        required={true}
+        onChange={() => {
+          setHasChanges(true);
+        }}
+        className="bg-zinc-900 text-white text-sm rounded-lg focus:ring-blue-500 focus:border-blue-500 block w-full p-2.5"
+      >
+        <option disabled={true} selected={workspace?.chatModel === null}>
+          System default
+        </option>
+        {defaultModels.length > 0 && (
+          <optgroup label="General models">
+            {defaultModels.map((model) => {
+              return (
+                <option
+                  key={model}
+                  value={model}
+                  selected={workspace?.chatModel === model}
+                >
+                  {model}
+                </option>
+              );
+            })}
+          </optgroup>
+        )}
+        {Array.isArray(customModels) && customModels.length > 0 && (
+          <optgroup label="Custom models">
+            {customModels.map((model) => {
+              return (
+                <option
+                  key={model.id}
+                  value={model.id}
+                  selected={workspace?.chatModel === model.id}
+                >
+                  {model.id}
+                </option>
+              );
+            })}
+          </optgroup>
+        )}
+        {/* For providers like TogetherAi where we partition model by creator entity. */}
+        {!Array.isArray(customModels) &&
+          Object.keys(customModels).length > 0 && (
+            <>
+              {Object.entries(customModels).map(([organization, models]) => (
+                <optgroup key={organization} label={organization}>
+                  {models.map((model) => (
+                    <option
+                      key={model.id}
+                      value={model.id}
+                      selected={workspace?.chatModel === model.id}
+                    >
+                      {model.name}
+                    </option>
+                  ))}
+                </optgroup>
+              ))}
+            </>
+          )}
+      </select>
+    </div>
+  );
+}
diff --git a/frontend/src/components/Modals/MangeWorkspace/Settings/ChatModelPreference/useGetProviderModels.js b/frontend/src/components/Modals/MangeWorkspace/Settings/ChatModelPreference/useGetProviderModels.js
new file mode 100644
index 000000000..eae1b4adc
--- /dev/null
+++ b/frontend/src/components/Modals/MangeWorkspace/Settings/ChatModelPreference/useGetProviderModels.js
@@ -0,0 +1,49 @@
+import System from "@/models/system";
+import { useEffect, useState } from "react";
+
+// Providers which cannot use this feature for workspace<>model selection
+export const DISABLED_PROVIDERS = ["azure", "lmstudio"];
+const PROVIDER_DEFAULT_MODELS = {
+  openai: ["gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview", "gpt-4-32k"],
+  gemini: ["gemini-pro"],
+  anthropic: ["claude-2", "claude-instant-1"],
+  azure: [],
+  lmstudio: [],
+  localai: [],
+  ollama: [],
+  togetherai: [],
+  native: [],
+};
+
+// For togetherAi, which has a large model list, we subgroup the options
+// by their creator organization (e.g. Meta, Mistral, etc.),
+// which makes selection easier to read.
+function groupModels(models) {
+  return models.reduce((acc, model) => {
+    acc[model.organization] = acc[model.organization] || [];
+    acc[model.organization].push(model);
+    return acc;
+  }, {});
+}
+
+export default function useGetProviderModels(provider = null) {
+  const [defaultModels, setDefaultModels] = useState([]);
+  const [customModels, setCustomModels] = useState([]);
+  const [loading, setLoading] = useState(true);
+
+  useEffect(() => {
+    async function fetchProviderModels() {
+      if (!provider) return;
+      const { models = [] } = await System.customModels(provider);
+      if (PROVIDER_DEFAULT_MODELS.hasOwnProperty(provider))
+        setDefaultModels(PROVIDER_DEFAULT_MODELS[provider]);
+      provider === "togetherai"
+        ? setCustomModels(groupModels(models))
+        : setCustomModels(models);
+      setLoading(false);
+    }
+    fetchProviderModels();
+  }, [provider]);
+
+  return { defaultModels, customModels, loading };
+}
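
For reference, a sketch of the shapes this hook returns; the provider and model ids below are illustrative, not fetched from a live endpoint:

    import useGetProviderModels from "./useGetProviderModels";

    // Inside a component (illustrative):
    const { defaultModels, customModels, loading } = useGetProviderModels("togetherai");
    // For togetherai, customModels is grouped by organization, e.g.
    //   { Meta: [{ id: "...", name: "...", organization: "Meta" }], Mistral: [...] }
    // For flat providers like localai, customModels is an array of { id } objects.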
diff --git a/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx b/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx
index 2fce91e1f..a3089d688 100644
--- a/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx
+++ b/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx
@@ -6,6 +6,7 @@ import System from "../../../../models/system";
 import PreLoader from "../../../Preloader";
 import { useParams } from "react-router-dom";
 import showToast from "../../../../utils/toast";
+import ChatModelPreference from "./ChatModelPreference";
 
 // Ensure that a type is correct before sending the body
 // to the backend.
@@ -26,7 +27,7 @@ function castToType(key, value) {
   return definitions[key].cast(value);
 }
 
-export default function WorkspaceSettings({ active, workspace }) {
+export default function WorkspaceSettings({ active, workspace, settings }) {
   const { slug } = useParams();
   const formEl = useRef(null);
   const [saving, setSaving] = useState(false);
@@ -99,6 +100,11 @@ export default function WorkspaceSettings({ active, workspace }) {
           <div className="flex">
             <div className="flex flex-col gap-y-4 w-1/2">
               <div className="w-3/4 flex flex-col gap-y-4">
+                <ChatModelPreference
+                  settings={settings}
+                  workspace={workspace}
+                  setHasChanges={setHasChanges}
+                />
                 <div>
                   <div className="flex flex-col">
                     <label
diff --git a/frontend/src/components/Modals/MangeWorkspace/index.jsx b/frontend/src/components/Modals/MangeWorkspace/index.jsx
index 9092d0d51..8fb67a499 100644
--- a/frontend/src/components/Modals/MangeWorkspace/index.jsx
+++ b/frontend/src/components/Modals/MangeWorkspace/index.jsx
@@ -117,6 +117,7 @@ const ManageWorkspace = ({ hideModal = noop, providedSlug = null }) => {
               <WorkspaceSettings
                 active={selectedTab === "settings"} // To force reload live sub-components like VectorCount
                 workspace={workspace}
+                settings={settings}
               />
             </div>
           </Suspense>
diff --git a/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx b/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx
index 287716222..bd6ae511d 100644
--- a/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx
+++ b/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx
@@ -30,19 +30,17 @@ export default function GeneralLLMPreference() {
   const [hasChanges, setHasChanges] = useState(false);
   const [settings, setSettings] = useState(null);
   const [loading, setLoading] = useState(true);
-
   const [searchQuery, setSearchQuery] = useState("");
   const [filteredLLMs, setFilteredLLMs] = useState([]);
   const [selectedLLM, setSelectedLLM] = useState(null);
-
   const isHosted = window.location.hostname.includes("useanything.com");
 
   const handleSubmit = async (e) => {
     e.preventDefault();
     const form = e.target;
-    const data = {};
+    const data = { LLMProvider: selectedLLM };
     const formData = new FormData(form);
-    data.LLMProvider = selectedLLM;
+
     for (var [key, value] of formData.entries()) data[key] = value;
     const { error } = await System.updateSystem(data);
     setSaving(true);
diff --git a/server/endpoints/api/system/index.js b/server/endpoints/api/system/index.js
index 3548c3068..b18019b14 100644
--- a/server/endpoints/api/system/index.js
+++ b/server/endpoints/api/system/index.js
@@ -139,7 +139,7 @@ function apiSystemEndpoints(app) {
       */
       try {
         const body = reqBody(request);
-        const { newValues, error } = updateENV(body);
+        const { newValues, error } = await updateENV(body);
         if (process.env.NODE_ENV === "production") await dumpENV();
         response.status(200).json({ newValues, error });
       } catch (e) {
diff --git a/server/endpoints/system.js b/server/endpoints/system.js
index 15db895ad..e699cf84c 100644
--- a/server/endpoints/system.js
+++ b/server/endpoints/system.js
@@ -290,7 +290,7 @@ function systemEndpoints(app) {
         }
 
         const body = reqBody(request);
-        const { newValues, error } = updateENV(body);
+        const { newValues, error } = await updateENV(body);
         if (process.env.NODE_ENV === "production") await dumpENV();
         response.status(200).json({ newValues, error });
       } catch (e) {
@@ -312,7 +312,7 @@ function systemEndpoints(app) {
         }
 
         const { usePassword, newPassword } = reqBody(request);
-        const { error } = updateENV(
+        const { error } = await updateENV(
           {
             AuthToken: usePassword ? newPassword : "",
             JWTSecret: usePassword ? v4() : "",
@@ -355,7 +355,7 @@ function systemEndpoints(app) {
           message_limit: 25,
         });
 
-        updateENV(
+        await updateENV(
           {
             AuthToken: "",
             JWTSecret: process.env.JWT_SECRET || v4(),
diff --git a/server/models/workspace.js b/server/models/workspace.js
index 9139c25e9..6de8053e9 100644
--- a/server/models/workspace.js
+++ b/server/models/workspace.js
@@ -14,6 +14,7 @@ const Workspace = {
     "lastUpdatedAt",
     "openAiPrompt",
     "similarityThreshold",
+    "chatModel",
   ],
 
   new: async function (name = null, creatorId = null) {
@@ -191,6 +192,20 @@ const Workspace = {
       return { success: false, error: error.message };
     }
   },
+
+  resetWorkspaceChatModels: async () => {
+    try {
+      await prisma.workspaces.updateMany({
+        data: {
+          chatModel: null,
+        },
+      });
+      return { success: true, error: null };
+    } catch (error) {
+      console.error("Error resetting workspace chat models:", error.message);
+      return { success: false, error: error.message };
+    }
+  },
 };
 
 module.exports = { Workspace };
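
A usage sketch for the new model helper; the require path is illustrative and mirrors how other callers in this repo import the model:

    const { Workspace } = require("../../models/workspace");
    // Nulls chatModel on every workspace so each falls back to the system LLM preference.
    const { success, error } = await Workspace.resetWorkspaceChatModels();
    if (!success) console.error("Failed to reset workspace chat models:", error);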
diff --git a/server/prisma/migrations/20240113013409_init/migration.sql b/server/prisma/migrations/20240113013409_init/migration.sql
new file mode 100644
index 000000000..09b9448ec
--- /dev/null
+++ b/server/prisma/migrations/20240113013409_init/migration.sql
@@ -0,0 +1,2 @@
+-- AlterTable
+ALTER TABLE "workspaces" ADD COLUMN "chatModel" TEXT;
diff --git a/server/prisma/schema.prisma b/server/prisma/schema.prisma
index e9aa8a8a5..2f632a46a 100644
--- a/server/prisma/schema.prisma
+++ b/server/prisma/schema.prisma
@@ -93,6 +93,7 @@ model workspaces {
   lastUpdatedAt       DateTime              @default(now())
   openAiPrompt        String?
   similarityThreshold Float?                @default(0.25)
+  chatModel           String?
   workspace_users     workspace_users[]
   documents           workspace_documents[]
 }
diff --git a/server/utils/AiProviders/anthropic/index.js b/server/utils/AiProviders/anthropic/index.js
index 709333231..17f2abc4a 100644
--- a/server/utils/AiProviders/anthropic/index.js
+++ b/server/utils/AiProviders/anthropic/index.js
@@ -2,7 +2,7 @@ const { v4 } = require("uuid");
 const { chatPrompt } = require("../../chats");
 
 class AnthropicLLM {
-  constructor(embedder = null) {
+  constructor(embedder = null, modelPreference = null) {
     if (!process.env.ANTHROPIC_API_KEY)
       throw new Error("No Anthropic API key was set.");
 
@@ -12,7 +12,8 @@ class AnthropicLLM {
       apiKey: process.env.ANTHROPIC_API_KEY,
     });
     this.anthropic = anthropic;
-    this.model = process.env.ANTHROPIC_MODEL_PREF || "claude-2";
+    this.model =
+      modelPreference || process.env.ANTHROPIC_MODEL_PREF || "claude-2";
     this.limits = {
       history: this.promptWindowLimit() * 0.15,
       system: this.promptWindowLimit() * 0.15,
diff --git a/server/utils/AiProviders/azureOpenAi/index.js b/server/utils/AiProviders/azureOpenAi/index.js
index 185dac021..f59fc51fa 100644
--- a/server/utils/AiProviders/azureOpenAi/index.js
+++ b/server/utils/AiProviders/azureOpenAi/index.js
@@ -2,7 +2,7 @@ const { AzureOpenAiEmbedder } = require("../../EmbeddingEngines/azureOpenAi");
 const { chatPrompt } = require("../../chats");
 
 class AzureOpenAiLLM {
-  constructor(embedder = null) {
+  constructor(embedder = null, _modelPreference = null) {
     const { OpenAIClient, AzureKeyCredential } = require("@azure/openai");
     if (!process.env.AZURE_OPENAI_ENDPOINT)
       throw new Error("No Azure API endpoint was set.");
diff --git a/server/utils/AiProviders/gemini/index.js b/server/utils/AiProviders/gemini/index.js
index 03388e3e2..348c8f5ed 100644
--- a/server/utils/AiProviders/gemini/index.js
+++ b/server/utils/AiProviders/gemini/index.js
@@ -1,14 +1,15 @@
 const { chatPrompt } = require("../../chats");
 
 class GeminiLLM {
-  constructor(embedder = null) {
+  constructor(embedder = null, modelPreference = null) {
     if (!process.env.GEMINI_API_KEY)
       throw new Error("No Gemini API key was set.");
 
     // Docs: https://ai.google.dev/tutorials/node_quickstart
     const { GoogleGenerativeAI } = require("@google/generative-ai");
     const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY);
-    this.model = process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro";
+    this.model =
+      modelPreference || process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro";
     this.gemini = genAI.getGenerativeModel({ model: this.model });
     this.limits = {
       history: this.promptWindowLimit() * 0.15,
diff --git a/server/utils/AiProviders/lmStudio/index.js b/server/utils/AiProviders/lmStudio/index.js
index 28c107df0..614808034 100644
--- a/server/utils/AiProviders/lmStudio/index.js
+++ b/server/utils/AiProviders/lmStudio/index.js
@@ -2,7 +2,7 @@ const { chatPrompt } = require("../../chats");
 
 //  hybrid of openAi LLM chat completion for LMStudio
 class LMStudioLLM {
-  constructor(embedder = null) {
+  constructor(embedder = null, _modelPreference = null) {
     if (!process.env.LMSTUDIO_BASE_PATH)
       throw new Error("No LMStudio API Base Path was set.");
 
@@ -12,7 +12,7 @@ class LMStudioLLM {
     });
     this.lmstudio = new OpenAIApi(config);
     // When using LMStudio's inference server - the model param is not required so
-    // we can stub it here.
+    // we can stub it here. LMStudio can only run one model at a time.
     this.model = "model-placeholder";
     this.limits = {
       history: this.promptWindowLimit() * 0.15,
diff --git a/server/utils/AiProviders/localAi/index.js b/server/utils/AiProviders/localAi/index.js
index 84954c994..6623ac88e 100644
--- a/server/utils/AiProviders/localAi/index.js
+++ b/server/utils/AiProviders/localAi/index.js
@@ -1,7 +1,7 @@
 const { chatPrompt } = require("../../chats");
 
 class LocalAiLLM {
-  constructor(embedder = null) {
+  constructor(embedder = null, modelPreference = null) {
     if (!process.env.LOCAL_AI_BASE_PATH)
       throw new Error("No LocalAI Base Path was set.");
 
@@ -15,7 +15,7 @@ class LocalAiLLM {
         : {}),
     });
     this.openai = new OpenAIApi(config);
-    this.model = process.env.LOCAL_AI_MODEL_PREF;
+    this.model = modelPreference || process.env.LOCAL_AI_MODEL_PREF;
     this.limits = {
       history: this.promptWindowLimit() * 0.15,
       system: this.promptWindowLimit() * 0.15,
diff --git a/server/utils/AiProviders/native/index.js b/server/utils/AiProviders/native/index.js
index faac4fa03..66cc84d0c 100644
--- a/server/utils/AiProviders/native/index.js
+++ b/server/utils/AiProviders/native/index.js
@@ -10,11 +10,11 @@ const ChatLlamaCpp = (...args) =>
   );
 
 class NativeLLM {
-  constructor(embedder = null) {
+  constructor(embedder = null, modelPreference = null) {
     if (!process.env.NATIVE_LLM_MODEL_PREF)
       throw new Error("No local Llama model was set.");
 
-    this.model = process.env.NATIVE_LLM_MODEL_PREF || null;
+    this.model = modelPreference || process.env.NATIVE_LLM_MODEL_PREF || null;
     this.limits = {
       history: this.promptWindowLimit() * 0.15,
       system: this.promptWindowLimit() * 0.15,
diff --git a/server/utils/AiProviders/ollama/index.js b/server/utils/AiProviders/ollama/index.js
index 55205c23d..fce96f369 100644
--- a/server/utils/AiProviders/ollama/index.js
+++ b/server/utils/AiProviders/ollama/index.js
@@ -3,12 +3,12 @@ const { StringOutputParser } = require("langchain/schema/output_parser");
 
 // Docs: https://github.com/jmorganca/ollama/blob/main/docs/api.md
 class OllamaAILLM {
-  constructor(embedder = null) {
+  constructor(embedder = null, modelPreference = null) {
     if (!process.env.OLLAMA_BASE_PATH)
       throw new Error("No Ollama Base Path was set.");
 
     this.basePath = process.env.OLLAMA_BASE_PATH;
-    this.model = process.env.OLLAMA_MODEL_PREF;
+    this.model = modelPreference || process.env.OLLAMA_MODEL_PREF;
     this.limits = {
       history: this.promptWindowLimit() * 0.15,
       system: this.promptWindowLimit() * 0.15,
diff --git a/server/utils/AiProviders/openAi/index.js b/server/utils/AiProviders/openAi/index.js
index ccc7ba0e9..038d201d1 100644
--- a/server/utils/AiProviders/openAi/index.js
+++ b/server/utils/AiProviders/openAi/index.js
@@ -2,7 +2,7 @@ const { OpenAiEmbedder } = require("../../EmbeddingEngines/openAi");
 const { chatPrompt } = require("../../chats");
 
 class OpenAiLLM {
-  constructor(embedder = null) {
+  constructor(embedder = null, modelPreference = null) {
     const { Configuration, OpenAIApi } = require("openai");
     if (!process.env.OPEN_AI_KEY) throw new Error("No OpenAI API key was set.");
 
@@ -10,7 +10,8 @@ class OpenAiLLM {
       apiKey: process.env.OPEN_AI_KEY,
     });
     this.openai = new OpenAIApi(config);
-    this.model = process.env.OPEN_MODEL_PREF || "gpt-3.5-turbo";
+    this.model =
+      modelPreference || process.env.OPEN_MODEL_PREF || "gpt-3.5-turbo";
     this.limits = {
       history: this.promptWindowLimit() * 0.15,
       system: this.promptWindowLimit() * 0.15,
diff --git a/server/utils/AiProviders/togetherAi/index.js b/server/utils/AiProviders/togetherAi/index.js
index df64c413e..44061dd0a 100644
--- a/server/utils/AiProviders/togetherAi/index.js
+++ b/server/utils/AiProviders/togetherAi/index.js
@@ -6,7 +6,7 @@ function togetherAiModels() {
 }
 
 class TogetherAiLLM {
-  constructor(embedder = null) {
+  constructor(embedder = null, modelPreference = null) {
     const { Configuration, OpenAIApi } = require("openai");
     if (!process.env.TOGETHER_AI_API_KEY)
       throw new Error("No TogetherAI API key was set.");
@@ -16,7 +16,7 @@ class TogetherAiLLM {
       apiKey: process.env.TOGETHER_AI_API_KEY,
     });
     this.openai = new OpenAIApi(config);
-    this.model = process.env.TOGETHER_AI_MODEL_PREF;
+    this.model = modelPreference || process.env.TOGETHER_AI_MODEL_PREF;
     this.limits = {
       history: this.promptWindowLimit() * 0.15,
       system: this.promptWindowLimit() * 0.15,
diff --git a/server/utils/chats/index.js b/server/utils/chats/index.js
index 7fdb47344..d63de47d5 100644
--- a/server/utils/chats/index.js
+++ b/server/utils/chats/index.js
@@ -71,7 +71,7 @@ async function chatWithWorkspace(
     return await VALID_COMMANDS[command](workspace, message, uuid, user);
   }
 
-  const LLMConnector = getLLMProvider();
+  const LLMConnector = getLLMProvider(workspace?.chatModel);
   const VectorDb = getVectorDbClass();
   const { safe, reasons = [] } = await LLMConnector.isSafe(message);
   if (!safe) {
diff --git a/server/utils/chats/stream.js b/server/utils/chats/stream.js
index 04bb72b90..ceea8d7d2 100644
--- a/server/utils/chats/stream.js
+++ b/server/utils/chats/stream.js
@@ -30,7 +30,7 @@ async function streamChatWithWorkspace(
     return;
   }
 
-  const LLMConnector = getLLMProvider();
+  const LLMConnector = getLLMProvider(workspace?.chatModel);
   const VectorDb = getVectorDbClass();
   const { safe, reasons = [] } = await LLMConnector.isSafe(message);
   if (!safe) {
diff --git a/server/utils/helpers/customModels.js b/server/utils/helpers/customModels.js
index 54976895e..87fe976ec 100644
--- a/server/utils/helpers/customModels.js
+++ b/server/utils/helpers/customModels.js
@@ -17,7 +17,7 @@ async function getCustomModels(provider = "", apiKey = null, basePath = null) {
     case "localai":
       return await localAIModels(basePath, apiKey);
     case "ollama":
-      return await ollamaAIModels(basePath, apiKey);
+      return await ollamaAIModels(basePath);
     case "togetherai":
       return await getTogetherAiModels();
     case "native-llm":
@@ -53,7 +53,7 @@ async function openAiModels(apiKey = null) {
 async function localAIModels(basePath = null, apiKey = null) {
   const { Configuration, OpenAIApi } = require("openai");
   const config = new Configuration({
-    basePath,
+    basePath: basePath || process.env.LOCAL_AI_BASE_PATH,
     apiKey: apiKey || process.env.LOCAL_AI_API_KEY,
   });
   const openai = new OpenAIApi(config);
@@ -70,13 +70,14 @@ async function localAIModels(basePath = null, apiKey = null) {
   return { models, error: null };
 }
 
-async function ollamaAIModels(basePath = null, _apiKey = null) {
+async function ollamaAIModels(basePath = null) {
   let url;
   try {
-    new URL(basePath);
-    if (basePath.split("").slice(-1)?.[0] === "/")
+    let urlPath = basePath ?? process.env.OLLAMA_BASE_PATH;
+    new URL(urlPath);
+    if (urlPath.split("").slice(-1)?.[0] === "/")
       throw new Error("BasePath Cannot end in /!");
-    url = basePath;
+    url = urlPath;
   } catch {
     return { models: [], error: "Not a valid URL." };
   }
diff --git a/server/utils/helpers/index.js b/server/utils/helpers/index.js
index 1685acc1a..2b1f3dacf 100644
--- a/server/utils/helpers/index.js
+++ b/server/utils/helpers/index.js
@@ -24,37 +24,37 @@ function getVectorDbClass() {
   }
 }
 
-function getLLMProvider() {
+function getLLMProvider(modelPreference = null) {
   const vectorSelection = process.env.LLM_PROVIDER || "openai";
   const embedder = getEmbeddingEngineSelection();
   switch (vectorSelection) {
     case "openai":
       const { OpenAiLLM } = require("../AiProviders/openAi");
-      return new OpenAiLLM(embedder);
+      return new OpenAiLLM(embedder, modelPreference);
     case "azure":
       const { AzureOpenAiLLM } = require("../AiProviders/azureOpenAi");
-      return new AzureOpenAiLLM(embedder);
+      return new AzureOpenAiLLM(embedder, modelPreference);
     case "anthropic":
       const { AnthropicLLM } = require("../AiProviders/anthropic");
-      return new AnthropicLLM(embedder);
+      return new AnthropicLLM(embedder, modelPreference);
     case "gemini":
       const { GeminiLLM } = require("../AiProviders/gemini");
-      return new GeminiLLM(embedder);
+      return new GeminiLLM(embedder, modelPreference);
     case "lmstudio":
       const { LMStudioLLM } = require("../AiProviders/lmStudio");
-      return new LMStudioLLM(embedder);
+      return new LMStudioLLM(embedder, modelPreference);
     case "localai":
       const { LocalAiLLM } = require("../AiProviders/localAi");
-      return new LocalAiLLM(embedder);
+      return new LocalAiLLM(embedder, modelPreference);
     case "ollama":
       const { OllamaAILLM } = require("../AiProviders/ollama");
-      return new OllamaAILLM(embedder);
+      return new OllamaAILLM(embedder, modelPreference);
     case "togetherai":
       const { TogetherAiLLM } = require("../AiProviders/togetherAi");
-      return new TogetherAiLLM(embedder);
+      return new TogetherAiLLM(embedder, modelPreference);
     case "native":
       const { NativeLLM } = require("../AiProviders/native");
-      return new NativeLLM(embedder);
+      return new NativeLLM(embedder, modelPreference);
     default:
       throw new Error("ENV: No LLM_PROVIDER value found in environment!");
   }
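
Every provider constructor now accepts the optional second argument; a direct-construction sketch (the embedder variable and model id are placeholders):

    const { OpenAiLLM } = require("../AiProviders/openAi");
    const llm = new OpenAiLLM(embedder, "gpt-4-1106-preview"); // hypothetical workspace.chatModel
    // llm.model === "gpt-4-1106-preview"; passing null keeps OPEN_MODEL_PREF || "gpt-3.5-turbo".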
diff --git a/server/utils/helpers/updateENV.js b/server/utils/helpers/updateENV.js
index c699cf2df..5c43da519 100644
--- a/server/utils/helpers/updateENV.js
+++ b/server/utils/helpers/updateENV.js
@@ -2,6 +2,7 @@ const KEY_MAPPING = {
   LLMProvider: {
     envKey: "LLM_PROVIDER",
     checks: [isNotEmpty, supportedLLM],
+    postUpdate: [wipeWorkspaceModelPreference],
   },
   // OpenAI Settings
   OpenAiKey: {
@@ -362,11 +363,20 @@ function validDockerizedUrl(input = "") {
   return null;
 }
 
+// If the LLMProvider has changed, we need to reset all workspace model preferences
+// to null, since the existing provider<>model combinations would be invalid for the
+// new provider.
+async function wipeWorkspaceModelPreference(key, prev, next) {
+  if (prev === next) return;
+  const { Workspace } = require("../../models/workspace");
+  await Workspace.resetWorkspaceChatModels();
+}
+
 // This will force update .env variables which for any which reason were not able to be parsed or
 // read from an ENV file as this seems to be a complicating step for many so allowing people to write
 // to the process will at least alleviate that issue. It does not perform comprehensive validity checks or sanity checks
 // and is simply for debugging when the .env not found issue many come across.
-function updateENV(newENVs = {}, force = false) {
+async function updateENV(newENVs = {}, force = false) {
   let error = "";
   const validKeys = Object.keys(KEY_MAPPING);
   const ENV_KEYS = Object.keys(newENVs).filter(
@@ -374,21 +384,25 @@ function updateENV(newENVs = {}, force = false) {
   );
   const newValues = {};
 
-  ENV_KEYS.forEach((key) => {
-    const { envKey, checks } = KEY_MAPPING[key];
-    const value = newENVs[key];
+  for (const key of ENV_KEYS) {
+    const { envKey, checks, postUpdate = [] } = KEY_MAPPING[key];
+    const prevValue = process.env[envKey];
+    const nextValue = newENVs[key];
     const errors = checks
-      .map((validityCheck) => validityCheck(value, force))
+      .map((validityCheck) => validityCheck(nextValue, force))
       .filter((err) => typeof err === "string");
 
     if (errors.length > 0) {
       error += errors.join("\n");
-      return;
+      break;
     }
 
-    newValues[key] = value;
-    process.env[envKey] = value;
-  });
+    newValues[key] = nextValue;
+    process.env[envKey] = nextValue;
+
+    for (const postUpdateFunc of postUpdate)
+      await postUpdateFunc(key, prevValue, nextValue);
+  }
 
   return { newValues, error: error?.length > 0 ? error : false };
 }
-- 
GitLab