diff --git a/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx b/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx index 79c72fffa6b17e65cc5a35ffb89c7778a6661705..105487bf3b299ab561fcb20d092c4e00647c2d14 100644 --- a/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx +++ b/frontend/src/components/LLMSelection/GeminiLLMOptions/index.jsx @@ -1,4 +1,10 @@ +import System from "@/models/system"; +import { useEffect, useState } from "react"; + export default function GeminiLLMOptions({ settings }) { + const [inputValue, setInputValue] = useState(settings?.GeminiLLMApiKey); + const [geminiApiKey, setGeminiApiKey] = useState(settings?.GeminiLLMApiKey); + return ( <div className="w-full flex flex-col"> <div className="w-full flex items-center gap-[36px] mt-1.5"> @@ -15,56 +21,14 @@ export default function GeminiLLMOptions({ settings }) { required={true} autoComplete="off" spellCheck={false} + onChange={(e) => setInputValue(e.target.value)} + onBlur={() => setGeminiApiKey(inputValue)} /> </div> {!settings?.credentialsOnly && ( <> - <div className="flex flex-col w-60"> - <label className="text-white text-sm font-semibold block mb-3"> - Chat Model Selection - </label> - <select - name="GeminiLLMModelPref" - defaultValue={settings?.GeminiLLMModelPref || "gemini-pro"} - required={true} - className="border-none bg-theme-settings-input-bg border-gray-500 text-white text-sm rounded-lg block w-full p-2.5" - > - <optgroup label="Stable Models"> - {[ - "gemini-pro", - "gemini-1.0-pro", - "gemini-1.5-pro-latest", - "gemini-1.5-flash-latest", - ].map((model) => { - return ( - <option key={model} value={model}> - {model} - </option> - ); - })} - </optgroup> - <optgroup label="Experimental Models"> - {[ - "gemini-1.5-pro-exp-0801", - "gemini-1.5-pro-exp-0827", - "gemini-1.5-flash-exp-0827", - "gemini-1.5-flash-8b-exp-0827", - "gemini-exp-1114", - "gemini-exp-1121", - "gemini-exp-1206", - "learnlm-1.5-pro-experimental", - "gemini-2.0-flash-exp", - 
].map((model) => { - return ( - <option key={model} value={model}> - {model} - </option> - ); - })} - </optgroup> - </select> - </div> + <GeminiModelSelection apiKey={geminiApiKey} settings={settings} /> <div className="flex flex-col w-60"> <label className="text-white text-sm font-semibold block mb-3"> Safety Setting @@ -91,3 +55,79 @@ export default function GeminiLLMOptions({ settings }) { </div> ); } + +function GeminiModelSelection({ apiKey, settings }) { + const [groupedModels, setGroupedModels] = useState({}); + const [loading, setLoading] = useState(true); + + useEffect(() => { + async function findCustomModels() { + setLoading(true); + const { models } = await System.customModels("gemini", apiKey); + + if (models?.length > 0) { + const modelsByOrganization = models.reduce((acc, model) => { + acc[model.experimental ? "Experimental" : "Stable"] = + acc[model.experimental ? "Experimental" : "Stable"] || []; + acc[model.experimental ? "Experimental" : "Stable"].push(model); + return acc; + }, {}); + setGroupedModels(modelsByOrganization); + } + setLoading(false); + } + findCustomModels(); + }, [apiKey]); + + if (loading) { + return ( + <div className="flex flex-col w-60"> + <label className="text-white text-sm font-semibold block mb-3"> + Chat Model Selection + </label> + <select + name="GeminiLLMModelPref" + disabled={true} + className="border-none bg-theme-settings-input-bg border-gray-500 text-white text-sm rounded-lg block w-full p-2.5" + > + <option disabled={true} selected={true}> + -- loading available models -- + </option> + </select> + </div> + ); + } + + return ( + <div className="flex flex-col w-60"> + <label className="text-white text-sm font-semibold block mb-3"> + Chat Model Selection + </label> + <select + name="GeminiLLMModelPref" + required={true} + className="border-none bg-theme-settings-input-bg border-gray-500 text-white text-sm rounded-lg block w-full p-2.5" + > + {Object.keys(groupedModels) + .sort((a, b) => { + if (a === "Stable") 
return -1; + if (b === "Stable") return 1; + return a.localeCompare(b); + }) + .map((organization) => ( + <optgroup key={organization} label={organization}> + {groupedModels[organization].map((model) => ( + <option + key={model.id} + value={model.id} + selected={settings?.GeminiLLMModelPref === model.id} + > + {model.name} + </option> + ))} + </optgroup> + ))} + </select> + </div> + ); +} diff --git a/server/utils/AiProviders/gemini/defaultModals.js b/server/utils/AiProviders/gemini/defaultModals.js new file mode 100644 index 0000000000000000000000000000000000000000..303a0aafff364ace4ef75ee219dd216a468190f0 --- /dev/null +++ b/server/utils/AiProviders/gemini/defaultModals.js @@ -0,0 +1,46 @@ +const { MODEL_MAP } = require("../modelMap"); + +const stableModels = [ + "gemini-pro", + "gemini-1.0-pro", + "gemini-1.5-pro-latest", + "gemini-1.5-flash-latest", +]; + +const experimentalModels = [ + "gemini-1.5-pro-exp-0801", + "gemini-1.5-pro-exp-0827", + "gemini-1.5-flash-exp-0827", + "gemini-1.5-flash-8b-exp-0827", + "gemini-exp-1114", + "gemini-exp-1121", + "gemini-exp-1206", + "learnlm-1.5-pro-experimental", + "gemini-2.0-flash-exp", +]; + +// There are some models that are only available in the v1beta API +// and some models that are only available in the v1 API +// generally, v1beta models have `exp` in the name, but not always +// so we check for both against a static list as well. 
+const v1BetaModels = ["gemini-1.5-pro-latest", "gemini-1.5-flash-latest"]; + +const defaultGeminiModels = [ + ...stableModels.map((model) => ({ + id: model, + name: model, + contextWindow: MODEL_MAP.gemini[model], + experimental: false, + })), + ...experimentalModels.map((model) => ({ + id: model, + name: model, + contextWindow: MODEL_MAP.gemini[model], + experimental: true, + })), +]; + +module.exports = { + defaultGeminiModels, + v1BetaModels, +}; diff --git a/server/utils/AiProviders/gemini/index.js b/server/utils/AiProviders/gemini/index.js index f658b3c5f11651a264f555727b4d010794850815..3554a51fcfc81c1cbc81bc1f05035ff2800a40a3 100644 --- a/server/utils/AiProviders/gemini/index.js +++ b/server/utils/AiProviders/gemini/index.js @@ -7,6 +7,7 @@ const { clientAbortedHandler, } = require("../../helpers/chat/responses"); const { MODEL_MAP } = require("../modelMap"); +const { defaultGeminiModels, v1BetaModels } = require("./defaultModals"); class GeminiLLM { constructor(embedder = null, modelPreference = null) { @@ -21,22 +22,17 @@ class GeminiLLM { this.gemini = genAI.getGenerativeModel( { model: this.model }, { - // Gemini-1.5-pro-* and Gemini-1.5-flash are only available on the v1beta API. - apiVersion: [ - "gemini-1.5-pro-latest", - "gemini-1.5-flash-latest", - "gemini-1.5-pro-exp-0801", - "gemini-1.5-pro-exp-0827", - "gemini-1.5-flash-exp-0827", - "gemini-1.5-flash-8b-exp-0827", - "gemini-exp-1114", - "gemini-exp-1121", - "gemini-exp-1206", - "learnlm-1.5-pro-experimental", - "gemini-2.0-flash-exp", - ].includes(this.model) - ? "v1beta" - : "v1", + apiVersion: + /** + * There are some models that are only available in the v1beta API + * and some models that are only available in the v1 API + * generally, v1beta models have `exp` in the name, but not always + * so we check for both against a static list as well. + * @see {v1BetaModels} + */ + this.model.includes("exp") || v1BetaModels.includes(this.model) + ? 
"v1beta" + : "v1", } ); this.limits = { @@ -48,6 +44,11 @@ class GeminiLLM { this.embedder = embedder ?? new NativeEmbedder(); this.defaultTemp = 0.7; // not used for Gemini this.safetyThreshold = this.#fetchSafetyThreshold(); + this.#log(`Initialized with model: ${this.model}`); + } + + #log(text, ...args) { + console.log(`\x1b[32m[GeminiLLM]\x1b[0m ${text}`, ...args); } #appendContext(contextTexts = []) { @@ -109,25 +110,63 @@ class GeminiLLM { return MODEL_MAP.gemini[this.model] ?? 30_720; } - isValidChatCompletionModel(modelName = "") { - const validModels = [ - "gemini-pro", - "gemini-1.0-pro", - "gemini-1.5-pro-latest", - "gemini-1.5-flash-latest", - "gemini-1.5-pro-exp-0801", - "gemini-1.5-pro-exp-0827", - "gemini-1.5-flash-exp-0827", - "gemini-1.5-flash-8b-exp-0827", - "gemini-exp-1114", - "gemini-exp-1121", - "gemini-exp-1206", - "learnlm-1.5-pro-experimental", - "gemini-2.0-flash-exp", - ]; - return validModels.includes(modelName); + /** + * Fetches Gemini models from the Google Generative AI API + * @param {string} apiKey - The API key to use for the request + * @param {number} limit - The maximum number of models to fetch + * @param {string} pageToken - The page token to use for pagination + * @returns {Promise<[{id: string, name: string, contextWindow: number, experimental: boolean}]>} A promise that resolves to an array of Gemini models + */ + static async fetchModels(apiKey, limit = 1_000, pageToken = null) { + const url = new URL( + "https://generativelanguage.googleapis.com/v1beta/models" + ); + url.searchParams.set("pageSize", limit); + url.searchParams.set("key", apiKey); + if (pageToken) url.searchParams.set("pageToken", pageToken); + + return fetch(url.toString(), { + method: "GET", + headers: { "Content-Type": "application/json" }, + }) + .then((res) => res.json()) + .then((data) => { + if (data.error) throw new Error(data.error.message); + return data.models ?? 
[]; + }) + .then((models) => + models + .filter( + (model) => !model.displayName.toLowerCase().includes("tuning") + ) + .filter((model) => + model.supportedGenerationMethods.includes("generateContent") + ) // Only generateContent is supported + .map((model) => { + return { + id: model.name.split("/").pop(), + name: model.displayName, + contextWindow: model.inputTokenLimit, + experimental: model.name.includes("exp"), + }; + }) + ) + .catch((e) => { + console.error(`Gemini:getGeminiModels`, e.message); + return defaultGeminiModels; + }); } + /** + * Checks if a model is valid for chat completion (unused) + * @deprecated + * @param {string} modelName - The name of the model to check + * @returns {Promise<boolean>} A promise that resolves to a boolean indicating if the model is valid + */ + async isValidChatCompletionModel(modelName = "") { + // fetchModels is static, so it must be called on the class — `this.fetchModels` is undefined. + const models = await GeminiLLM.fetchModels(process.env.GEMINI_API_KEY); + return models.some((model) => model.id === modelName); + } /** * Generates appropriate content array for a message
* @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}} @@ -218,11 +257,6 @@ class GeminiLLM { } async getChatCompletion(messages = [], _opts = {}) { - if (!this.isValidChatCompletionModel(this.model)) - throw new Error( - `Gemini chat: ${this.model} is not valid for chat completion!` - ); - const prompt = messages.find( (chat) => chat.role === "USER_PROMPT" )?.content; @@ -256,11 +290,6 @@ class GeminiLLM { } async streamGetChatCompletion(messages = [], _opts = {}) { - if (!this.isValidChatCompletionModel(this.model)) - throw new Error( - `Gemini chat: ${this.model} is not valid for chat completion!` - ); - const prompt = messages.find( (chat) => chat.role === "USER_PROMPT" )?.content; diff --git a/server/utils/helpers/customModels.js b/server/utils/helpers/customModels.js index a763635fb52a746361c3a199e4c6da98e3d5d912..7adb276261762c4b0b238e1a4926b2fa6bc5b3bc 100644 --- a/server/utils/helpers/customModels.js +++ b/server/utils/helpers/customModels.js @@ -7,6 +7,7 @@ const { ElevenLabsTTS } = require("../TextToSpeech/elevenLabs"); const { fetchNovitaModels } = require("../AiProviders/novita"); const { parseLMStudioBasePath } = require("../AiProviders/lmStudio"); const { parseNvidiaNimBasePath } = require("../AiProviders/nvidiaNim"); +const { GeminiLLM } = require("../AiProviders/gemini"); const SUPPORT_CUSTOM_MODELS = [ "openai", @@ -28,6 +29,7 @@ const SUPPORT_CUSTOM_MODELS = [ "apipie", "novita", "xai", + "gemini", ]; async function getCustomModels(provider = "", apiKey = null, basePath = null) { @@ -73,6 +75,8 @@ async function getCustomModels(provider = "", apiKey = null, basePath = null) { return await getXAIModels(apiKey); case "nvidia-nim": return await getNvidiaNimModels(basePath); + case "gemini": + return await getGeminiModels(apiKey); default: return { models: [], error: "Invalid provider for custom models" }; } @@ -572,6 +576,17 @@ async function getNvidiaNimModels(basePath = null) { } } +async function getGeminiModels(_apiKey 
= null) { + const apiKey = + _apiKey === true + ? process.env.GEMINI_API_KEY + : _apiKey || process.env.GEMINI_API_KEY || null; + const models = await GeminiLLM.fetchModels(apiKey); + // Api Key was successful so lets save it for future uses + if (models.length > 0 && !!apiKey) process.env.GEMINI_API_KEY = apiKey; + return { models, error: null }; +} + module.exports = { getCustomModels, }; diff --git a/server/utils/helpers/updateENV.js b/server/utils/helpers/updateENV.js index 948703dca218da76166c943b0569ca34a98ca36a..da30b6ee0dde1ea295dc6a288a567039dc9a55bd 100644 --- a/server/utils/helpers/updateENV.js +++ b/server/utils/helpers/updateENV.js @@ -52,7 +52,7 @@ const KEY_MAPPING = { }, GeminiLLMModelPref: { envKey: "GEMINI_LLM_MODEL_PREF", - checks: [isNotEmpty, validGeminiModel], + checks: [isNotEmpty], }, GeminiSafetySetting: { envKey: "GEMINI_SAFETY_SETTING", @@ -724,27 +724,6 @@ function supportedTranscriptionProvider(input = "") { : `${input} is not a valid transcription model provider.`; } -function validGeminiModel(input = "") { - const validModels = [ - "gemini-pro", - "gemini-1.0-pro", - "gemini-1.5-pro-latest", - "gemini-1.5-flash-latest", - "gemini-1.5-pro-exp-0801", - "gemini-1.5-pro-exp-0827", - "gemini-1.5-flash-exp-0827", - "gemini-1.5-flash-8b-exp-0827", - "gemini-exp-1114", - "gemini-exp-1121", - "gemini-exp-1206", - "learnlm-1.5-pro-experimental", - "gemini-2.0-flash-exp", - ]; - return validModels.includes(input) - ? null - : `Invalid Model type. Must be one of ${validModels.join(", ")}.`; -} - function validGeminiSafetySetting(input = "") { const validModes = [ "BLOCK_NONE",