Skip to content
Snippets Groups Projects
Unverified Commit b082c8e4 authored by Timothy Carambat's avatar Timothy Carambat Committed by GitHub
Browse files

Add support for gemini authenticated models endpoint (#2868)

* Add support for gemini authenticated models endpoint
add customModels entry
add un-authed fallback to default listing
separate models by experimental status
resolves #2866

* add back improved logic for apiVersion decision making
parent 71cd5e5b
No related branches found
No related tags found
No related merge requests found
import System from "@/models/system";
import { useEffect, useState } from "react";
export default function GeminiLLMOptions({ settings }) { export default function GeminiLLMOptions({ settings }) {
const [inputValue, setInputValue] = useState(settings?.GeminiLLMApiKey);
const [geminiApiKey, setGeminiApiKey] = useState(settings?.GeminiLLMApiKey);
return ( return (
<div className="w-full flex flex-col"> <div className="w-full flex flex-col">
<div className="w-full flex items-center gap-[36px] mt-1.5"> <div className="w-full flex items-center gap-[36px] mt-1.5">
...@@ -15,56 +21,14 @@ export default function GeminiLLMOptions({ settings }) { ...@@ -15,56 +21,14 @@ export default function GeminiLLMOptions({ settings }) {
required={true} required={true}
autoComplete="off" autoComplete="off"
spellCheck={false} spellCheck={false}
onChange={(e) => setInputValue(e.target.value)}
onBlur={() => setGeminiApiKey(inputValue)}
/> />
</div> </div>
{!settings?.credentialsOnly && ( {!settings?.credentialsOnly && (
<> <>
<div className="flex flex-col w-60"> <GeminiModelSelection apiKey={geminiApiKey} settings={settings} />
<label className="text-white text-sm font-semibold block mb-3">
Chat Model Selection
</label>
<select
name="GeminiLLMModelPref"
defaultValue={settings?.GeminiLLMModelPref || "gemini-pro"}
required={true}
className="border-none bg-theme-settings-input-bg border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
>
<optgroup label="Stable Models">
{[
"gemini-pro",
"gemini-1.0-pro",
"gemini-1.5-pro-latest",
"gemini-1.5-flash-latest",
].map((model) => {
return (
<option key={model} value={model}>
{model}
</option>
);
})}
</optgroup>
<optgroup label="Experimental Models">
{[
"gemini-1.5-pro-exp-0801",
"gemini-1.5-pro-exp-0827",
"gemini-1.5-flash-exp-0827",
"gemini-1.5-flash-8b-exp-0827",
"gemini-exp-1114",
"gemini-exp-1121",
"gemini-exp-1206",
"learnlm-1.5-pro-experimental",
"gemini-2.0-flash-exp",
].map((model) => {
return (
<option key={model} value={model}>
{model}
</option>
);
})}
</optgroup>
</select>
</div>
<div className="flex flex-col w-60"> <div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3"> <label className="text-white text-sm font-semibold block mb-3">
Safety Setting Safety Setting
...@@ -91,3 +55,79 @@ export default function GeminiLLMOptions({ settings }) { ...@@ -91,3 +55,79 @@ export default function GeminiLLMOptions({ settings }) {
</div> </div>
); );
} }
function GeminiModelSelection({ apiKey, settings }) {
const [groupedModels, setGroupedModels] = useState({});
const [loading, setLoading] = useState(true);
useEffect(() => {
async function findCustomModels() {
setLoading(true);
const { models } = await System.customModels("gemini", apiKey);
if (models?.length > 0) {
const modelsByOrganization = models.reduce((acc, model) => {
acc[model.experimental ? "Experimental" : "Stable"] =
acc[model.experimental ? "Experimental" : "Stable"] || [];
acc[model.experimental ? "Experimental" : "Stable"].push(model);
return acc;
}, {});
setGroupedModels(modelsByOrganization);
}
setLoading(false);
}
findCustomModels();
}, [apiKey]);
if (loading) {
return (
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
Chat Model Selection
</label>
<select
name="GeminiLLMModelPref"
disabled={true}
className="border-none bg-theme-settings-input-bg border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
>
<option disabled={true} selected={true}>
-- loading available models --
</option>
</select>
</div>
);
}
return (
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
Chat Model Selection
</label>
<select
name="GeminiLLMModelPref"
required={true}
className="border-none bg-theme-settings-input-bg border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
>
{Object.keys(groupedModels)
.sort((a, b) => {
if (a === "Stable") return -1;
if (b === "Stable") return 1;
return a.localeCompare(b);
})
.map((organization) => (
<optgroup key={organization} label={organization}>
{groupedModels[organization].map((model) => (
<option
key={model.id}
value={model.id}
selected={settings?.GeminiLLMModelPref === model.id}
>
{model.name}
</option>
))}
</optgroup>
))}
</select>
</div>
);
}
const { MODEL_MAP } = require("../modelMap");
const stableModels = [
"gemini-pro",
"gemini-1.0-pro",
"gemini-1.5-pro-latest",
"gemini-1.5-flash-latest",
];
const experimentalModels = [
"gemini-1.5-pro-exp-0801",
"gemini-1.5-pro-exp-0827",
"gemini-1.5-flash-exp-0827",
"gemini-1.5-flash-8b-exp-0827",
"gemini-exp-1114",
"gemini-exp-1121",
"gemini-exp-1206",
"learnlm-1.5-pro-experimental",
"gemini-2.0-flash-exp",
];
// There are some models that are only available in the v1beta API
// and some models that are only available in the v1 API
// generally, v1beta models have `exp` in the name, but not always
// so we check for both against a static list as well.
const v1BetaModels = ["gemini-1.5-pro-latest", "gemini-1.5-flash-latest"];
const defaultGeminiModels = [
...stableModels.map((model) => ({
id: model,
name: model,
contextWindow: MODEL_MAP.gemini[model],
experimental: false,
})),
...experimentalModels.map((model) => ({
id: model,
name: model,
contextWindow: MODEL_MAP.gemini[model],
experimental: true,
})),
];
module.exports = {
defaultGeminiModels,
v1BetaModels,
};
...@@ -7,6 +7,7 @@ const { ...@@ -7,6 +7,7 @@ const {
clientAbortedHandler, clientAbortedHandler,
} = require("../../helpers/chat/responses"); } = require("../../helpers/chat/responses");
const { MODEL_MAP } = require("../modelMap"); const { MODEL_MAP } = require("../modelMap");
const { defaultGeminiModels, v1BetaModels } = require("./defaultModals");
class GeminiLLM { class GeminiLLM {
constructor(embedder = null, modelPreference = null) { constructor(embedder = null, modelPreference = null) {
...@@ -21,22 +22,17 @@ class GeminiLLM { ...@@ -21,22 +22,17 @@ class GeminiLLM {
this.gemini = genAI.getGenerativeModel( this.gemini = genAI.getGenerativeModel(
{ model: this.model }, { model: this.model },
{ {
// Gemini-1.5-pro-* and Gemini-1.5-flash are only available on the v1beta API. apiVersion:
apiVersion: [ /**
"gemini-1.5-pro-latest", * There are some models that are only available in the v1beta API
"gemini-1.5-flash-latest", * and some models that are only available in the v1 API
"gemini-1.5-pro-exp-0801", * generally, v1beta models have `exp` in the name, but not always
"gemini-1.5-pro-exp-0827", * so we check for both against a static list as well.
"gemini-1.5-flash-exp-0827", * @see {v1BetaModels}
"gemini-1.5-flash-8b-exp-0827", */
"gemini-exp-1114", this.model.includes("exp") || v1BetaModels.includes(this.model)
"gemini-exp-1121", ? "v1beta"
"gemini-exp-1206", : "v1",
"learnlm-1.5-pro-experimental",
"gemini-2.0-flash-exp",
].includes(this.model)
? "v1beta"
: "v1",
} }
); );
this.limits = { this.limits = {
...@@ -48,6 +44,11 @@ class GeminiLLM { ...@@ -48,6 +44,11 @@ class GeminiLLM {
this.embedder = embedder ?? new NativeEmbedder(); this.embedder = embedder ?? new NativeEmbedder();
this.defaultTemp = 0.7; // not used for Gemini this.defaultTemp = 0.7; // not used for Gemini
this.safetyThreshold = this.#fetchSafetyThreshold(); this.safetyThreshold = this.#fetchSafetyThreshold();
this.#log(`Initialized with model: ${this.model}`);
}
#log(text, ...args) {
console.log(`\x1b[32m[GeminiLLM]\x1b[0m ${text}`, ...args);
} }
#appendContext(contextTexts = []) { #appendContext(contextTexts = []) {
...@@ -109,25 +110,63 @@ class GeminiLLM { ...@@ -109,25 +110,63 @@ class GeminiLLM {
return MODEL_MAP.gemini[this.model] ?? 30_720; return MODEL_MAP.gemini[this.model] ?? 30_720;
} }
isValidChatCompletionModel(modelName = "") { /**
const validModels = [ * Fetches Gemini models from the Google Generative AI API
"gemini-pro", * @param {string} apiKey - The API key to use for the request
"gemini-1.0-pro", * @param {number} limit - The maximum number of models to fetch
"gemini-1.5-pro-latest", * @param {string} pageToken - The page token to use for pagination
"gemini-1.5-flash-latest", * @returns {Promise<[{id: string, name: string, contextWindow: number, experimental: boolean}]>} A promise that resolves to an array of Gemini models
"gemini-1.5-pro-exp-0801", */
"gemini-1.5-pro-exp-0827", static async fetchModels(apiKey, limit = 1_000, pageToken = null) {
"gemini-1.5-flash-exp-0827", const url = new URL(
"gemini-1.5-flash-8b-exp-0827", "https://generativelanguage.googleapis.com/v1beta/models"
"gemini-exp-1114", );
"gemini-exp-1121", url.searchParams.set("pageSize", limit);
"gemini-exp-1206", url.searchParams.set("key", apiKey);
"learnlm-1.5-pro-experimental", if (pageToken) url.searchParams.set("pageToken", pageToken);
"gemini-2.0-flash-exp",
]; return fetch(url.toString(), {
return validModels.includes(modelName); method: "GET",
headers: { "Content-Type": "application/json" },
})
.then((res) => res.json())
.then((data) => {
if (data.error) throw new Error(data.error.message);
return data.models ?? [];
})
.then((models) =>
models
.filter(
(model) => !model.displayName.toLowerCase().includes("tuning")
)
.filter((model) =>
model.supportedGenerationMethods.includes("generateContent")
) // Only generateContent is supported
.map((model) => {
return {
id: model.name.split("/").pop(),
name: model.displayName,
contextWindow: model.inputTokenLimit,
experimental: model.name.includes("exp"),
};
})
)
.catch((e) => {
console.error(`Gemini:getGeminiModels`, e.message);
return defaultGeminiModels;
});
} }
/**
* Checks if a model is valid for chat completion (unused)
* @deprecated
* @param {string} modelName - The name of the model to check
* @returns {Promise<boolean>} A promise that resolves to a boolean indicating if the model is valid
*/
async isValidChatCompletionModel(modelName = "") {
const models = await this.fetchModels(true);
return models.some((model) => model.id === modelName);
}
/** /**
* Generates appropriate content array for a message + attachments. * Generates appropriate content array for a message + attachments.
* @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}} * @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}}
...@@ -218,11 +257,6 @@ class GeminiLLM { ...@@ -218,11 +257,6 @@ class GeminiLLM {
} }
async getChatCompletion(messages = [], _opts = {}) { async getChatCompletion(messages = [], _opts = {}) {
if (!this.isValidChatCompletionModel(this.model))
throw new Error(
`Gemini chat: ${this.model} is not valid for chat completion!`
);
const prompt = messages.find( const prompt = messages.find(
(chat) => chat.role === "USER_PROMPT" (chat) => chat.role === "USER_PROMPT"
)?.content; )?.content;
...@@ -256,11 +290,6 @@ class GeminiLLM { ...@@ -256,11 +290,6 @@ class GeminiLLM {
} }
async streamGetChatCompletion(messages = [], _opts = {}) { async streamGetChatCompletion(messages = [], _opts = {}) {
if (!this.isValidChatCompletionModel(this.model))
throw new Error(
`Gemini chat: ${this.model} is not valid for chat completion!`
);
const prompt = messages.find( const prompt = messages.find(
(chat) => chat.role === "USER_PROMPT" (chat) => chat.role === "USER_PROMPT"
)?.content; )?.content;
......
...@@ -7,6 +7,7 @@ const { ElevenLabsTTS } = require("../TextToSpeech/elevenLabs"); ...@@ -7,6 +7,7 @@ const { ElevenLabsTTS } = require("../TextToSpeech/elevenLabs");
const { fetchNovitaModels } = require("../AiProviders/novita"); const { fetchNovitaModels } = require("../AiProviders/novita");
const { parseLMStudioBasePath } = require("../AiProviders/lmStudio"); const { parseLMStudioBasePath } = require("../AiProviders/lmStudio");
const { parseNvidiaNimBasePath } = require("../AiProviders/nvidiaNim"); const { parseNvidiaNimBasePath } = require("../AiProviders/nvidiaNim");
const { GeminiLLM } = require("../AiProviders/gemini");
const SUPPORT_CUSTOM_MODELS = [ const SUPPORT_CUSTOM_MODELS = [
"openai", "openai",
...@@ -28,6 +29,7 @@ const SUPPORT_CUSTOM_MODELS = [ ...@@ -28,6 +29,7 @@ const SUPPORT_CUSTOM_MODELS = [
"apipie", "apipie",
"novita", "novita",
"xai", "xai",
"gemini",
]; ];
async function getCustomModels(provider = "", apiKey = null, basePath = null) { async function getCustomModels(provider = "", apiKey = null, basePath = null) {
...@@ -73,6 +75,8 @@ async function getCustomModels(provider = "", apiKey = null, basePath = null) { ...@@ -73,6 +75,8 @@ async function getCustomModels(provider = "", apiKey = null, basePath = null) {
return await getXAIModels(apiKey); return await getXAIModels(apiKey);
case "nvidia-nim": case "nvidia-nim":
return await getNvidiaNimModels(basePath); return await getNvidiaNimModels(basePath);
case "gemini":
return await getGeminiModels(apiKey);
default: default:
return { models: [], error: "Invalid provider for custom models" }; return { models: [], error: "Invalid provider for custom models" };
} }
...@@ -572,6 +576,17 @@ async function getNvidiaNimModels(basePath = null) { ...@@ -572,6 +576,17 @@ async function getNvidiaNimModels(basePath = null) {
} }
} }
/**
 * Fetches the list of Gemini models for the custom-models endpoint.
 * @param {string|boolean|null} _apiKey - `true` means "reuse the stored env
 *   key"; a string is an explicit override; otherwise fall back to the
 *   stored key (or null for the unauthenticated default listing).
 * @returns {Promise<{models: Array, error: null}>}
 */
async function getGeminiModels(_apiKey = null) {
  let apiKey;
  if (_apiKey === true) {
    apiKey = process.env.GEMINI_API_KEY;
  } else {
    apiKey = _apiKey || process.env.GEMINI_API_KEY || null;
  }

  const models = await GeminiLLM.fetchModels(apiKey);
  // A non-empty model list means the key worked — persist it for future use.
  if (models.length > 0 && !!apiKey) process.env.GEMINI_API_KEY = apiKey;
  return { models, error: null };
}
module.exports = { module.exports = {
getCustomModels, getCustomModels,
}; };
...@@ -52,7 +52,7 @@ const KEY_MAPPING = { ...@@ -52,7 +52,7 @@ const KEY_MAPPING = {
}, },
GeminiLLMModelPref: { GeminiLLMModelPref: {
envKey: "GEMINI_LLM_MODEL_PREF", envKey: "GEMINI_LLM_MODEL_PREF",
checks: [isNotEmpty, validGeminiModel], checks: [isNotEmpty],
}, },
GeminiSafetySetting: { GeminiSafetySetting: {
envKey: "GEMINI_SAFETY_SETTING", envKey: "GEMINI_SAFETY_SETTING",
...@@ -724,27 +724,6 @@ function supportedTranscriptionProvider(input = "") { ...@@ -724,27 +724,6 @@ function supportedTranscriptionProvider(input = "") {
: `${input} is not a valid transcription model provider.`; : `${input} is not a valid transcription model provider.`;
} }
/**
 * Validates that `input` names a known Gemini chat model.
 * @param {string} input - The model name to validate.
 * @returns {string|null} `null` when valid, otherwise an error message
 *   listing the accepted model names.
 */
function validGeminiModel(input = "") {
  const validModels = [
    "gemini-pro",
    "gemini-1.0-pro",
    "gemini-1.5-pro-latest",
    "gemini-1.5-flash-latest",
    "gemini-1.5-pro-exp-0801",
    "gemini-1.5-pro-exp-0827",
    "gemini-1.5-flash-exp-0827",
    "gemini-1.5-flash-8b-exp-0827",
    "gemini-exp-1114",
    "gemini-exp-1121",
    "gemini-exp-1206",
    "learnlm-1.5-pro-experimental",
    "gemini-2.0-flash-exp",
  ];
  // Guard-clause form: a recognized model passes validation immediately.
  if (validModels.includes(input)) return null;
  return `Invalid Model type. Must be one of ${validModels.join(", ")}.`;
}
function validGeminiSafetySetting(input = "") { function validGeminiSafetySetting(input = "") {
const validModes = [ const validModes = [
"BLOCK_NONE", "BLOCK_NONE",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment