From 67c85f15507f60c9d0899a4f00c3093ed1e40216 Mon Sep 17 00:00:00 2001 From: Timothy Carambat <rambat1010@gmail.com> Date: Tue, 31 Oct 2023 11:38:28 -0700 Subject: [PATCH] Implement retrieval and use of fine-tune models (#314) * Implement retrieval and use of fine-tune models Cleanup LLM selection code resolves #311 * Cleanup from PR bot --- .../LLMSelection/AnthropicAiOptions/index.jsx | 63 +++++++ .../LLMSelection/AzureAiOptions/index.jsx | 69 +++++++ .../LLMProviderOption/index.jsx | 0 .../LLMSelection/OpenAiOptions/index.jsx | 117 ++++++++++++ frontend/src/models/system.js | 20 ++ .../EmbeddingPreference/index.jsx | 2 +- .../GeneralSettings/LLMPreference/index.jsx | 174 +----------------- .../Steps/EmbeddingSelection/index.jsx | 4 +- .../Steps/LLMSelection/index.jsx | 160 +--------------- server/endpoints/system.js | 21 ++- server/utils/AiProviders/openAi/index.js | 21 ++- server/utils/helpers/customModels.js | 38 ++++ server/utils/helpers/updateENV.js | 11 +- 13 files changed, 360 insertions(+), 340 deletions(-) create mode 100644 frontend/src/components/LLMSelection/AnthropicAiOptions/index.jsx create mode 100644 frontend/src/components/LLMSelection/AzureAiOptions/index.jsx rename frontend/src/components/{ => LLMSelection}/LLMProviderOption/index.jsx (100%) create mode 100644 frontend/src/components/LLMSelection/OpenAiOptions/index.jsx create mode 100644 server/utils/helpers/customModels.js diff --git a/frontend/src/components/LLMSelection/AnthropicAiOptions/index.jsx b/frontend/src/components/LLMSelection/AnthropicAiOptions/index.jsx new file mode 100644 index 000000000..c2b29468c --- /dev/null +++ b/frontend/src/components/LLMSelection/AnthropicAiOptions/index.jsx @@ -0,0 +1,63 @@ +import { Info } from "@phosphor-icons/react"; +import paths from "../../../utils/paths"; + +export default function AnthropicAiOptions({ settings, showAlert = false }) { + return ( + <div className="w-full flex flex-col"> + {showAlert && ( + <div className="flex flex-col md:flex-row md:items-center gap-x-2 text-white mb-6 bg-blue-800/30 w-fit rounded-lg px-4 py-2"> + <div className="gap-x-2 flex items-center"> + <Info size={12} className="hidden md:visible" /> + <p className="text-sm md:text-base"> + Anthropic as your LLM requires you to set an embedding service to + use. + </p> + </div> + <a + href={paths.general.embeddingPreference()} + className="text-sm md:text-base my-2 underline" + > + Manage embedding → + </a> + </div> + )} + <div className="w-full flex items-center gap-4"> + <div className="flex flex-col w-60"> + <label className="text-white text-sm font-semibold block mb-4"> + Anthropic Claude-2 API Key + </label> + <input + type="password" + name="AnthropicApiKey" + className="bg-zinc-900 text-white placeholder-white placeholder-opacity-60 text-sm rounded-lg focus:border-white block w-full p-2.5" + placeholder="Anthropic Claude-2 API Key" + defaultValue={settings?.AnthropicApiKey ? 
"*".repeat(20) : ""} + required={true} + autoComplete="off" + spellCheck={false} + /> + </div> + + <div className="flex flex-col w-60"> + <label className="text-white text-sm font-semibold block mb-4"> + Chat Model Selection + </label> + <select + name="AnthropicModelPref" + defaultValue={settings?.AnthropicModelPref || "claude-2"} + required={true} + className="bg-zinc-900 border border-gray-500 text-white text-sm rounded-lg block w-full p-2.5" + > + {["claude-2"].map((model) => { + return ( + <option key={model} value={model}> + {model} + </option> + ); + })} + </select> + </div> + </div> + </div> + ); +} diff --git a/frontend/src/components/LLMSelection/AzureAiOptions/index.jsx b/frontend/src/components/LLMSelection/AzureAiOptions/index.jsx new file mode 100644 index 000000000..99b04fa8e --- /dev/null +++ b/frontend/src/components/LLMSelection/AzureAiOptions/index.jsx @@ -0,0 +1,69 @@ +export default function AzureAiOptions({ settings }) { + return ( + <> + <div className="flex flex-col w-60"> + <label className="text-white text-sm font-semibold block mb-4"> + Azure Service Endpoint + </label> + <input + type="url" + name="AzureOpenAiEndpoint" + className="bg-zinc-900 text-white placeholder-white placeholder-opacity-60 text-sm rounded-lg focus:border-white block w-full p-2.5" + placeholder="https://my-azure.openai.azure.com" + defaultValue={settings?.AzureOpenAiEndpoint} + required={true} + autoComplete="off" + spellCheck={false} + /> + </div> + + <div className="flex flex-col w-60"> + <label className="text-white text-sm font-semibold block mb-4"> + API Key + </label> + <input + type="password" + name="AzureOpenAiKey" + className="bg-zinc-900 text-white placeholder-white placeholder-opacity-60 text-sm rounded-lg focus:border-white block w-full p-2.5" + placeholder="Azure OpenAI API Key" + defaultValue={settings?.AzureOpenAiKey ? 
"*".repeat(20) : ""} + required={true} + autoComplete="off" + spellCheck={false} + /> + </div> + + <div className="flex flex-col w-60"> + <label className="text-white text-sm font-semibold block mb-4"> + Chat Deployment Name + </label> + <input + type="text" + name="AzureOpenAiModelPref" + className="bg-zinc-900 text-white placeholder-white placeholder-opacity-60 text-sm rounded-lg focus:border-white block w-full p-2.5" + placeholder="Azure OpenAI chat model deployment name" + defaultValue={settings?.AzureOpenAiModelPref} + required={true} + autoComplete="off" + spellCheck={false} + /> + </div> + + <div className="flex flex-col w-60"> + <label className="text-white text-sm font-semibold block mb-4"> + Embedding Deployment Name + </label> + <input + type="text" + name="AzureOpenAiEmbeddingModelPref" + className="bg-zinc-900 text-white placeholder-white placeholder-opacity-60 text-sm rounded-lg focus:border-white block w-full p-2.5" + placeholder="Azure OpenAI embedding model deployment name" + defaultValue={settings?.AzureOpenAiEmbeddingModelPref} + required={true} + autoComplete="off" + spellCheck={false} + /> + </div> + </> + ); +} diff --git a/frontend/src/components/LLMProviderOption/index.jsx b/frontend/src/components/LLMSelection/LLMProviderOption/index.jsx similarity index 100% rename from frontend/src/components/LLMProviderOption/index.jsx rename to frontend/src/components/LLMSelection/LLMProviderOption/index.jsx diff --git a/frontend/src/components/LLMSelection/OpenAiOptions/index.jsx b/frontend/src/components/LLMSelection/OpenAiOptions/index.jsx new file mode 100644 index 000000000..b6957419e --- /dev/null +++ b/frontend/src/components/LLMSelection/OpenAiOptions/index.jsx @@ -0,0 +1,117 @@ +import { useState, useEffect } from "react"; +import System from "../../../models/system"; + +export default function OpenAiOptions({ settings }) { + const [inputValue, setInputValue] = useState(settings?.OpenAiKey); + const [openAIKey, setOpenAIKey] = useState(settings?.OpenAiKey); + function updateOpenAiKey() { + setOpenAIKey(inputValue); + } + + return ( + <> + <div className="flex flex-col w-60"> + <label className="text-white text-sm font-semibold block mb-4"> + API Key + </label> + <input + type="password" + name="OpenAiKey" + className="bg-zinc-900 text-white placeholder-white placeholder-opacity-60 text-sm rounded-lg focus:border-white block w-full p-2.5" + placeholder="OpenAI API Key" + defaultValue={settings?.OpenAiKey ? "*".repeat(20) : ""} + required={true} + autoComplete="off" + spellCheck={false} + onChange={(e) => setInputValue(e.target.value)} + onBlur={updateOpenAiKey} + /> + </div> + <OpenAIModelSelection settings={settings} apiKey={openAIKey} /> + </> + ); +} + +function OpenAIModelSelection({ apiKey, settings }) { + const [customModels, setCustomModels] = useState([]); + const [loading, setLoading] = useState(true); + + useEffect(() => { + async function findCustomModels() { + if (!apiKey) { + setCustomModels([]); + setLoading(false); + return; + } + setLoading(true); + const { models } = await System.customModels( + "openai", + typeof apiKey === "boolean" ? 
null : apiKey + ); + setCustomModels(models || []); + setLoading(false); + } + findCustomModels(); + }, [apiKey]); + + if (loading) { + return ( + <div className="flex flex-col w-60"> + <label className="text-white text-sm font-semibold block mb-4"> + Chat Model Selection + </label> + <select + name="OpenAiModelPref" + disabled={true} + className="bg-zinc-900 border border-gray-500 text-white text-sm rounded-lg block w-full p-2.5" + > + <option disabled={true} selected={true}> + -- loading available models -- + </option> + </select> + </div> + ); + } + + return ( + <div className="flex flex-col w-60"> + <label className="text-white text-sm font-semibold block mb-4"> + Chat Model Selection + </label> + <select + name="OpenAiModelPref" + required={true} + className="bg-zinc-900 border border-gray-500 text-white text-sm rounded-lg block w-full p-2.5" + > + <optgroup label="General LLM models"> + {["gpt-3.5-turbo", "gpt-4"].map((model) => { + return ( + <option + key={model} + value={model} + selected={settings.OpenAiModelPref === model} + > + {model} + </option> + ); + })} + </optgroup> + {customModels.length > 0 && ( + <optgroup label="Your fine-tuned models"> + {customModels.map((model) => { + return ( + <option + key={model.id} + value={model.id} + selected={settings.OpenAiModelPref === model.id} + > + {model.id} + </option> + ); + })} + </optgroup> + )} + </select> + </div> + ); +} diff --git a/frontend/src/models/system.js b/frontend/src/models/system.js index 1b2e34e31..a90bddb23 100644 --- a/frontend/src/models/system.js +++ b/frontend/src/models/system.js @@ -319,6 +319,26 @@ const System = { return false; }); }, + customModels: async function (provider, apiKey) { + return fetch(`${API_BASE}/system/custom-models`, { + method: "POST", + headers: baseHeaders(), + body: JSON.stringify({ + provider, + apiKey, + }), + }) + .then((res) => { + if (!res.ok) { + throw new Error(res.statusText || "Error finding custom models."); + } + return res.json(); + }) + .catch((e) => { + console.error(e); + return { models: [], error: e.message }; + }); + }, }; export default System; diff --git a/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx b/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx index 39206e61f..dd5b14339 100644 --- a/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx +++ b/frontend/src/pages/GeneralSettings/EmbeddingPreference/index.jsx @@ -8,7 +8,7 @@ import showToast from "../../../utils/toast"; import OpenAiLogo from "../../../media/llmprovider/openai.png"; import AzureOpenAiLogo from "../../../media/llmprovider/azure.png"; import PreLoader from "../../../components/Preloader"; -import LLMProviderOption from "../../../components/LLMProviderOption"; +import LLMProviderOption from "../../../components/LLMSelection/LLMProviderOption"; export default function GeneralEmbeddingPreference() { const [saving, setSaving] = useState(false); diff --git a/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx b/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx index 0d4df449c..e933ab5ee 100644 --- a/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx +++ b/frontend/src/pages/GeneralSettings/LLMPreference/index.jsx @@ -9,9 +9,10 @@ import OpenAiLogo from "../../../media/llmprovider/openai.png"; import AzureOpenAiLogo from "../../../media/llmprovider/azure.png"; import AnthropicLogo from "../../../media/llmprovider/anthropic.png"; import PreLoader from "../../../components/Preloader"; -import LLMProviderOption from 
"../../../components/LLMProviderOption"; -import { Info } from "@phosphor-icons/react"; -import paths from "../../../utils/paths"; +import LLMProviderOption from "../../../components/LLMSelection/LLMProviderOption"; +import OpenAiOptions from "../../../components/LLMSelection/OpenAiOptions"; +import AzureAiOptions from "../../../components/LLMSelection/AzureAiOptions"; +import AnthropicAiOptions from "../../../components/LLMSelection/AnthropicAiOptions"; export default function GeneralLLMPreference() { const [saving, setSaving] = useState(false); @@ -132,174 +133,13 @@ export default function GeneralLLMPreference() { </div> <div className="mt-10 flex flex-wrap gap-4 max-w-[800px]"> {llmChoice === "openai" && ( - <> - <div className="flex flex-col w-60"> - <label className="text-white text-sm font-semibold block mb-4"> - API Key - </label> - <input - type="text" - name="OpenAiKey" - className="bg-zinc-900 text-white placeholder-white placeholder-opacity-60 text-sm rounded-lg focus:border-white block w-full p-2.5" - placeholder="OpenAI API Key" - defaultValue={settings?.OpenAiKey ? "*".repeat(20) : ""} - required={true} - autoComplete="off" - spellCheck={false} - /> - </div> - - <div className="flex flex-col w-60"> - <label className="text-white text-sm font-semibold block mb-4"> - Chat Model Selection - </label> - <select - name="OpenAiModelPref" - defaultValue={settings?.OpenAiModelPref} - required={true} - className="bg-zinc-900 border border-gray-500 text-white text-sm rounded-lg block w-full p-2.5" - > - {["gpt-3.5-turbo", "gpt-4"].map((model) => { - return ( - <option key={model} value={model}> - {model} - </option> - ); - })} - </select> - </div> - </> + <OpenAiOptions settings={settings} /> )} - {llmChoice === "azure" && ( - <> - <div className="flex flex-col w-60"> - <label className="text-white text-sm font-semibold block mb-4"> - Azure Service Endpoint - </label> - <input - type="url" - name="AzureOpenAiEndpoint" - className="bg-zinc-900 text-white placeholder-white placeholder-opacity-60 text-sm rounded-lg focus:border-white block w-full p-2.5" - placeholder="https://my-azure.openai.azure.com" - defaultValue={settings?.AzureOpenAiEndpoint} - required={true} - autoComplete="off" - spellCheck={false} - /> - </div> - - <div className="flex flex-col w-60"> - <label className="text-white text-sm font-semibold block mb-4"> - API Key - </label> - <input - type="password" - name="AzureOpenAiKey" - className="bg-zinc-900 text-white placeholder-white placeholder-opacity-60 text-sm rounded-lg focus:border-white block w-full p-2.5" - placeholder="Azure OpenAI API Key" - defaultValue={ - settings?.AzureOpenAiKey ? 
"*".repeat(20) : "" - } - required={true} - autoComplete="off" - spellCheck={false} - /> - </div> - - <div className="flex flex-col w-60"> - <label className="text-white text-sm font-semibold block mb-4"> - Chat Deployment Name - </label> - <input - type="text" - name="AzureOpenAiModelPref" - className="bg-zinc-900 text-white placeholder-white placeholder-opacity-60 text-sm rounded-lg focus:border-white block w-full p-2.5" - placeholder="Azure OpenAI chat model deployment name" - defaultValue={settings?.AzureOpenAiModelPref} - required={true} - autoComplete="off" - spellCheck={false} - /> - </div> - - <div className="flex flex-col w-60"> - <label className="text-white text-sm font-semibold block mb-4"> - Embedding Deployment Name - </label> - <input - type="text" - name="AzureOpenAiEmbeddingModelPref" - className="bg-zinc-900 text-white placeholder-white placeholder-opacity-60 text-sm rounded-lg focus:border-white block w-full p-2.5" - placeholder="Azure OpenAI embedding model deployment name" - defaultValue={settings?.AzureOpenAiEmbeddingModelPref} - required={true} - autoComplete="off" - spellCheck={false} - /> - </div> - </> + <AzureAiOptions settings={settings} /> )} - {llmChoice === "anthropic" && ( - <div className="w-full flex flex-col"> - <div className="flex flex-col md:flex-row md:items-center gap-x-2 text-white mb-6 bg-blue-800/30 w-fit rounded-lg px-4 py-2"> - <div className="gap-x-2 flex items-center"> - <Info size={12} className="hidden md:visible" /> - <p className="text-sm md:text-base"> - Anthropic as your LLM requires you to set an embedding - service to use. - </p> - </div> - <a - href={paths.general.embeddingPreference()} - className="text-sm md:text-base my-2 underline" - > - Manage embedding → - </a> - </div> - <div className="w-full flex items-center gap-4"> - <div className="flex flex-col w-60"> - <label className="text-white text-sm font-semibold block mb-4"> - Anthropic Claude-2 API Key - </label> - <input - type="text" - name="AnthropicApiKey" - className="bg-zinc-900 text-white placeholder-white placeholder-opacity-60 text-sm rounded-lg focus:border-white block w-full p-2.5" - placeholder="Anthropic Claude-2 API Key" - defaultValue={ - settings?.AnthropicApiKey ? 
"*".repeat(20) : "" - } - required={true} - autoComplete="off" - spellCheck={false} - /> - </div> - - <div className="flex flex-col w-60"> - <label className="text-white text-sm font-semibold block mb-4"> - Chat Model Selection - </label> - <select - name="AnthropicModelPref" - defaultValue={ - settings?.AnthropicModelPref || "claude-2" - } - required={true} - className="bg-zinc-900 border border-gray-500 text-white text-sm rounded-lg block w-full p-2.5" - > - {["claude-2"].map((model) => { - return ( - <option key={model} value={model}> - {model} - </option> - ); - })} - </select> - </div> - </div> - </div> + <AnthropicAiOptions settings={settings} showAlert={true} /> )} </div> </div> diff --git a/frontend/src/pages/OnboardingFlow/OnboardingModal/Steps/EmbeddingSelection/index.jsx b/frontend/src/pages/OnboardingFlow/OnboardingModal/Steps/EmbeddingSelection/index.jsx index aadf37980..27fe27525 100644 --- a/frontend/src/pages/OnboardingFlow/OnboardingModal/Steps/EmbeddingSelection/index.jsx +++ b/frontend/src/pages/OnboardingFlow/OnboardingModal/Steps/EmbeddingSelection/index.jsx @@ -3,12 +3,10 @@ import OpenAiLogo from "../../../../../media/llmprovider/openai.png"; import AzureOpenAiLogo from "../../../../../media/llmprovider/azure.png"; import System from "../../../../../models/system"; import PreLoader from "../../../../../components/Preloader"; -import LLMProviderOption from "../../../../../components/LLMProviderOption"; +import LLMProviderOption from "../../../../../components/LLMSelection/LLMProviderOption"; function EmbeddingSelection({ nextStep, prevStep, currentStep, goToStep }) { const [embeddingChoice, setEmbeddingChoice] = useState("openai"); - const [llmChoice, setLLMChoice] = useState("openai"); - const [settings, setSettings] = useState(null); const [loading, setLoading] = useState(true); diff --git a/frontend/src/pages/OnboardingFlow/OnboardingModal/Steps/LLMSelection/index.jsx b/frontend/src/pages/OnboardingFlow/OnboardingModal/Steps/LLMSelection/index.jsx index b3d141f6d..a813dbd9e 100644 --- a/frontend/src/pages/OnboardingFlow/OnboardingModal/Steps/LLMSelection/index.jsx +++ b/frontend/src/pages/OnboardingFlow/OnboardingModal/Steps/LLMSelection/index.jsx @@ -1,11 +1,13 @@ import React, { memo, useEffect, useState } from "react"; - import OpenAiLogo from "../../../../../media/llmprovider/openai.png"; import AzureOpenAiLogo from "../../../../../media/llmprovider/azure.png"; import AnthropicLogo from "../../../../../media/llmprovider/anthropic.png"; import System from "../../../../../models/system"; import PreLoader from "../../../../../components/Preloader"; -import LLMProviderOption from "../../../../../components/LLMProviderOption"; +import LLMProviderOption from "../../../../../components/LLMSelection/LLMProviderOption"; +import OpenAiOptions from "../../../../../components/LLMSelection/OpenAiOptions"; +import AzureAiOptions from "../../../../../components/LLMSelection/AzureAiOptions"; +import AnthropicAiOptions from "../../../../../components/LLMSelection/AnthropicAiOptions"; function LLMSelection({ nextStep, prevStep, currentStep, goToStep }) { const [llmChoice, setLLMChoice] = useState("openai"); @@ -95,158 +97,10 @@ function LLMSelection({ nextStep, prevStep, currentStep, goToStep }) { /> </div> <div className="mt-10 flex flex-wrap gap-4 max-w-[800px]"> - {llmChoice === "openai" && ( - <> - <div className="flex flex-col w-60"> - <label className="text-white text-sm font-semibold block mb-4"> - API Key - </label> - <input - type="password" - name="OpenAiKey" - 
className="bg-zinc-900 text-white placeholder-white placeholder-opacity-60 text-sm rounded-lg focus:border-white block w-full p-2.5" - placeholder="OpenAI API Key" - defaultValue={settings?.OpenAiKey ? "*".repeat(20) : ""} - required={true} - autoComplete="off" - spellCheck={false} - /> - </div> - - <div className="flex flex-col w-60"> - <label className="text-white text-sm font-semibold block mb-4"> - Chat Model Selection - </label> - <select - name="OpenAiModelPref" - defaultValue={settings?.OpenAiModelPref} - required={true} - className="bg-zinc-900 border border-gray-500 text-white text-sm rounded-lg block w-full p-2.5" - > - {["gpt-3.5-turbo", "gpt-4"].map((model) => { - return ( - <option key={model} value={model}> - {model} - </option> - ); - })} - </select> - </div> - </> - )} - - {llmChoice === "azure" && ( - <> - <div className="flex flex-col w-60"> - <label className="text-white text-sm font-semibold block mb-4"> - Azure Service Endpoint - </label> - <input - type="url" - name="AzureOpenAiEndpoint" - className="bg-zinc-900 text-white placeholder-white placeholder-opacity-60 text-sm rounded-lg focus:border-white block w-full p-2.5" - placeholder="https://my-azure.openai.azure.com" - defaultValue={settings?.AzureOpenAiEndpoint} - required={true} - autoComplete="off" - spellCheck={false} - /> - </div> - - <div className="flex flex-col w-60"> - <label className="text-white text-sm font-semibold block mb-4"> - API Key - </label> - <input - type="password" - name="AzureOpenAiKey" - className="bg-zinc-900 text-white placeholder-white placeholder-opacity-60 text-sm rounded-lg focus:border-white block w-full p-2.5" - placeholder="Azure OpenAI API Key" - defaultValue={ - settings?.AzureOpenAiKey ? "*".repeat(20) : "" - } - required={true} - autoComplete="off" - spellCheck={false} - /> - </div> - - <div className="flex flex-col w-60"> - <label className="text-white text-sm font-semibold block mb-4"> - Chat Deployment Name - </label> - <input - type="text" - name="AzureOpenAiModelPref" - className="bg-zinc-900 text-white placeholder-white placeholder-opacity-60 text-sm rounded-lg focus:border-white block w-full p-2.5" - placeholder="Azure OpenAI chat model deployment name" - defaultValue={settings?.AzureOpenAiModelPref} - required={true} - autoComplete="off" - spellCheck={false} - /> - </div> - - <div className="flex flex-col w-60"> - <label className="text-white text-sm font-semibold block mb-4"> - Embedding Deployment Name - </label> - <input - type="text" - name="AzureOpenAiEmbeddingModelPref" - className="bg-zinc-900 text-white placeholder-white placeholder-opacity-60 text-sm rounded-lg focus:border-white block w-full p-2.5" - placeholder="Azure OpenAI embedding model deployment name" - defaultValue={settings?.AzureOpenAiEmbeddingModelPref} - required={true} - autoComplete="off" - spellCheck={false} - /> - </div> - </> - )} - + {llmChoice === "openai" && <OpenAiOptions settings={settings} />} + {llmChoice === "azure" && <AzureAiOptions settings={settings} />} {llmChoice === "anthropic" && ( - <div className="w-full flex flex-col"> - <div className="w-full flex items-center gap-4"> - <div className="flex flex-col w-60"> - <label className="text-white text-sm font-semibold block mb-4"> - Anthropic Claude-2 API Key - </label> - <input - type="text" - name="AnthropicApiKey" - className="bg-zinc-900 text-white placeholder-white placeholder-opacity-60 text-sm rounded-lg focus:border-white block w-full p-2.5" - placeholder="Anthropic Claude-2 API Key" - defaultValue={ - 
settings?.AnthropicApiKey ? "*".repeat(20) : "" - } - required={true} - autoComplete="off" - spellCheck={false} - /> - </div> - - <div className="flex flex-col w-60"> - <label className="text-white text-sm font-semibold block mb-4"> - Chat Model Selection - </label> - <select - name="AnthropicModelPref" - defaultValue={settings?.AnthropicModelPref || "claude-2"} - required={true} - className="bg-zinc-900 border border-gray-500 text-white text-sm rounded-lg block w-full p-2.5" - > - {["claude-2"].map((model) => { - return ( - <option key={model} value={model}> - {model} - </option> - ); - })} - </select> - </div> - </div> - </div> + <AnthropicAiOptions settings={settings} /> )} </div> </div> diff --git a/server/endpoints/system.js b/server/endpoints/system.js index 915d31d19..7e565874d 100644 --- a/server/endpoints/system.js +++ b/server/endpoints/system.js @@ -8,7 +8,7 @@ const { acceptedFileTypes, } = require("../utils/files/documentProcessor"); const { purgeDocument } = require("../utils/files/purgeDocument"); -const { getVectorDbClass } = require("../utils/helpers"); +const { getVectorDbClass, getLLMProvider } = require("../utils/helpers"); const { updateENV, dumpENV } = require("../utils/helpers/updateENV"); const { reqBody, @@ -37,6 +37,7 @@ const { const { Telemetry } = require("../models/telemetry"); const { WelcomeMessages } = require("../models/welcomeMessages"); const { ApiKey } = require("../models/apiKeys"); +const { getCustomModels } = require("../utils/helpers/customModels"); function systemEndpoints(app) { if (!app) return; @@ -627,6 +628,24 @@ function systemEndpoints(app) { response.status(500).end(); } }); + + app.post( + "/system/custom-models", + [validatedRequest], + async (request, response) => { + try { + const { provider, apiKey } = reqBody(request); + const { models, error } = await getCustomModels(provider, apiKey); + return response.status(200).json({ + models, + error, + }); + } catch (error) { + console.error(error); + response.status(500).end(); + } + } + ); } module.exports = { systemEndpoints }; diff --git a/server/utils/AiProviders/openAi/index.js b/server/utils/AiProviders/openAi/index.js index d4d54bd6b..1efaa7466 100644 --- a/server/utils/AiProviders/openAi/index.js +++ b/server/utils/AiProviders/openAi/index.js @@ -12,9 +12,16 @@ class OpenAiLLM extends OpenAiEmbedder { this.openai = new OpenAIApi(config); } - isValidChatModel(modelName = "") { + async isValidChatCompletionModel(modelName = "") { const validModels = ["gpt-4", "gpt-3.5-turbo"]; - return validModels.includes(modelName); + const isPreset = validModels.some((model) => modelName === model); + if (isPreset) return true; + + const model = await this.openai + .retrieveModel(modelName) + .then((res) => res.data) + .catch(() => null); + return !!model; } constructPrompt({ @@ -70,7 +77,7 @@ class OpenAiLLM extends OpenAiEmbedder { async sendChat(chatHistory = [], prompt, workspace = {}) { const model = process.env.OPEN_MODEL_PREF; - if (!this.isValidChatModel(model)) + if (!(await this.isValidChatCompletionModel(model))) throw new Error( `OpenAI chat: ${model} is not valid for chat completion!` ); @@ -95,7 +102,6 @@ class OpenAiLLM extends OpenAiEmbedder { return res.choices[0].message.content; }) .catch((error) => { - console.log(error); throw new Error( `OpenAI::createChatCompletion failed with: ${error.message}` ); @@ -104,8 +110,13 @@ class OpenAiLLM extends OpenAiEmbedder { return textResponse; } - async getChatCompletion(messages = [], { temperature = 0.7 }) { + async 
getChatCompletion(messages = null, { temperature = 0.7 }) { const model = process.env.OPEN_MODEL_PREF || "gpt-3.5-turbo"; + if (!(await this.isValidChatCompletionModel(model))) + throw new Error( + `OpenAI chat: ${model} is not valid for chat completion!` + ); + const { data } = await this.openai.createChatCompletion({ model, messages, diff --git a/server/utils/helpers/customModels.js b/server/utils/helpers/customModels.js new file mode 100644 index 000000000..a1b2eda88 --- /dev/null +++ b/server/utils/helpers/customModels.js @@ -0,0 +1,38 @@ +const SUPPORT_CUSTOM_MODELS = ["openai"]; + +async function getCustomModels(provider = "", apiKey = null) { + if (!SUPPORT_CUSTOM_MODELS.includes(provider)) + return { models: [], error: "Invalid provider for custom models" }; + + switch (provider) { + case "openai": + return await openAiModels(apiKey); + default: + return { models: [], error: "Invalid provider for custom models" }; + } +} + +async function openAiModels(apiKey = null) { + const { Configuration, OpenAIApi } = require("openai"); + const config = new Configuration({ + apiKey: apiKey || process.env.OPEN_AI_KEY, + }); + const openai = new OpenAIApi(config); + const models = ( + await openai + .listModels() + .then((res) => res.data.data) + .catch((e) => { + console.error(`OpenAI:listModels`, e.message); + return []; + }) + ).filter( + (model) => !model.owned_by.includes("openai") && model.owned_by !== "system" + ); + + return { models, error: null }; +} + +module.exports = { + getCustomModels, +}; diff --git a/server/utils/helpers/updateENV.js b/server/utils/helpers/updateENV.js index 3e6d6429e..9cfb243ff 100644 --- a/server/utils/helpers/updateENV.js +++ b/server/utils/helpers/updateENV.js @@ -10,7 +10,7 @@ const KEY_MAPPING = { }, OpenAiModelPref: { envKey: "OPEN_MODEL_PREF", - checks: [isNotEmpty, validOpenAIModel], + checks: [isNotEmpty], }, // Azure OpenAI Settings AzureOpenAiEndpoint: { @@ -107,8 +107,6 @@ const KEY_MAPPING = { envKey: "JWT_SECRET", checks: [requiresForceMode], }, - // Not supported yet. - // 'StorageDir': 'STORAGE_DIR', }; function isNotEmpty(input = "") { @@ -138,13 +136,6 @@ function supportedLLM(input = "") { return ["openai", "azure", "anthropic"].includes(input); } -function validOpenAIModel(input = "") { - const validModels = ["gpt-4", "gpt-3.5-turbo"]; - return validModels.includes(input) - ? null - : `Invalid Model type. Must be one of ${validModels.join(", ")}.`; -} - function validAnthropicModel(input = "") { const validModels = ["claude-2"]; return validModels.includes(input) -- GitLab
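Usage note: the model-discovery path added in this patch (the POST /system/custom-models endpoint, the getCustomModels() helper, and the frontend System.customModels("openai", apiKey) call) can be exercised with a few lines of Node. The sketch below is illustrative only: it assumes the same openai v3 SDK (Configuration/OpenAIApi) the patch already depends on and a valid key in OPEN_AI_KEY, and the helper name listFineTunedModelIds does not appear anywhere in the PR.

// Mirrors the filter in server/utils/helpers/customModels.js: list every model
// visible to the key, then keep only models not owned by OpenAI or "system",
// i.e. the account's own fine-tuned models.
const { Configuration, OpenAIApi } = require("openai");

async function listFineTunedModelIds(apiKey) {
  const openai = new OpenAIApi(new Configuration({ apiKey }));
  const all = await openai
    .listModels()
    .then((res) => res.data.data)
    .catch(() => []); // on auth/network failure, fall back to an empty list like the patch does

  return all
    .filter((m) => !m.owned_by.includes("openai") && m.owned_by !== "system")
    .map((m) => m.id);
}

listFineTunedModelIds(process.env.OPEN_AI_KEY).then((ids) =>
  console.log("fine-tuned models:", ids)
);

Any id printed here is what OpenAiOptions renders under the "Your fine-tuned models" optgroup, and what isValidChatCompletionModel() later re-verifies via openai.retrieveModel() before a chat completion is sent.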