diff --git a/frontend/src/components/LLMSelection/GroqAiOptions/index.jsx b/frontend/src/components/LLMSelection/GroqAiOptions/index.jsx
index 9f6f38f2718aaf887a74fc8f39ce82e10348836c..c9b38ed275e23387f039ec581619a5286f0e306e 100644
--- a/frontend/src/components/LLMSelection/GroqAiOptions/index.jsx
+++ b/frontend/src/components/LLMSelection/GroqAiOptions/index.jsx
@@ -1,4 +1,10 @@
+import { useState, useEffect } from "react";
+import System from "@/models/system";
+
 export default function GroqAiOptions({ settings }) {
+  const [inputValue, setInputValue] = useState(settings?.GroqApiKey);
+  const [apiKey, setApiKey] = useState(settings?.GroqApiKey);
+
   return (
     <div className="flex gap-[36px] mt-1.5">
       <div className="flex flex-col w-60">
@@ -8,41 +14,98 @@ export default function GroqAiOptions({ settings }) {
         <input
           type="password"
           name="GroqApiKey"
-          className="bg-zinc-900 text-white placeholder:text-white/20 text-sm rounded-lg focus:outline-primary-button active:outline-primary-button outline-none block w-full p-2.5"
+          className="border-none bg-zinc-900 text-white placeholder:text-white/20 text-sm rounded-lg focus:outline-primary-button active:outline-primary-button outline-none block w-full p-2.5"
           placeholder="Groq API Key"
           defaultValue={settings?.GroqApiKey ? "*".repeat(20) : ""}
           required={true}
           autoComplete="off"
           spellCheck={false}
+          onChange={(e) => setInputValue(e.target.value)}
+          onBlur={() => setApiKey(inputValue)}
         />
       </div>
       {!settings?.credentialsOnly && (
-        <div className="flex flex-col w-60">
-          <label className="text-white text-sm font-semibold block mb-3">
-            Chat Model Selection
-          </label>
-          <select
-            name="GroqModelPref"
-            defaultValue={settings?.GroqModelPref || "llama3-8b-8192"}
-            required={true}
-            className="bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
-          >
-            {[
-              "mixtral-8x7b-32768",
-              "llama3-8b-8192",
-              "llama3-70b-8192",
-              "gemma-7b-it",
-            ].map((model) => {
+        <GroqAIModelSelection settings={settings} apiKey={apiKey} />
+      )}
+    </div>
+  );
+}
+
+function GroqAIModelSelection({ apiKey, settings }) {
+  const [customModels, setCustomModels] = useState([]);
+  const [loading, setLoading] = useState(true);
+
+  useEffect(() => {
+    async function findCustomModels() {
+      if (!apiKey) {
+        setCustomModels([]);
+        setLoading(true);
+        return;
+      }
+
+      try {
+        setLoading(true);
+        const { models } = await System.customModels("groq", apiKey);
+        setCustomModels(models || []);
+      } catch (error) {
+        console.error("Failed to fetch custom models:", error);
+        setCustomModels([]);
+      } finally {
+        setLoading(false);
+      }
+    }
+    findCustomModels();
+  }, [apiKey]);
+
+  if (loading) {
+    return (
+      <div className="flex flex-col w-60">
+        <label className="text-white text-sm font-semibold block mb-3">
+          Chat Model Selection
+        </label>
+        <select
+          name="GroqModelPref"
+          disabled={true}
+          className="border-none bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
+        >
+          <option disabled={true} selected={true}>
+            --loading available models--
+          </option>
+        </select>
+        <p className="text-xs leading-[18px] font-base text-white text-opacity-60 mt-2">
+          Enter a valid API key to view all available models for your
+          account.
+        </p>
+      </div>
+    );
+  }
+
+  return (
+    <div className="flex flex-col w-60">
+      <label className="text-white text-sm font-semibold block mb-3">
+        Chat Model Selection
+      </label>
+      <select
+        name="GroqModelPref"
+        required={true}
+        className="border-none bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
+        defaultValue={settings?.GroqModelPref}
+      >
+        {customModels.length > 0 && (
+          <optgroup label="Available models">
+            {customModels.map((model) => {
               return (
-                <option key={model} value={model}>
-                  {model}
+                <option key={model.id} value={model.id}>
+                  {model.id}
                 </option>
               );
             })}
-          </select>
-        </div>
-      )}
+          </optgroup>
+        )}
+      </select>
+      <p className="text-xs leading-[18px] font-base text-white text-opacity-60 mt-2">
+        Select the GroqAI model you want to use for your conversations.
+      </p>
     </div>
   );
 }
diff --git a/frontend/src/hooks/useGetProvidersModels.js b/frontend/src/hooks/useGetProvidersModels.js
index e118f5e742ba12253bd9392be931c26e260e1e6b..064ad17c88994571bffa09a8b2c64249d662a544 100644
--- a/frontend/src/hooks/useGetProvidersModels.js
+++ b/frontend/src/hooks/useGetProvidersModels.js
@@ -32,12 +32,7 @@ const PROVIDER_DEFAULT_MODELS = {
   localai: [],
   ollama: [],
   togetherai: [],
-  groq: [
-    "mixtral-8x7b-32768",
-    "llama3-8b-8192",
-    "llama3-70b-8192",
-    "gemma-7b-it",
-  ],
+  groq: [],
   native: [],
   cohere: [
     "command-r",
diff --git a/server/utils/AiProviders/groq/index.js b/server/utils/AiProviders/groq/index.js
index 067c60c7acb175488db781b10ef37308284b7fbc..ccfc647a0416a81e532e41d7282bbbb9c406c7b9 100644
--- a/server/utils/AiProviders/groq/index.js
+++ b/server/utils/AiProviders/groq/index.js
@@ -13,7 +13,7 @@ class GroqLLM {
       apiKey: process.env.GROQ_API_KEY,
     });
     this.model =
-      modelPreference || process.env.GROQ_MODEL_PREF || "llama3-8b-8192";
+      modelPreference || process.env.GROQ_MODEL_PREF || "llama-3.1-8b-instant";
     this.limits = {
       history: this.promptWindowLimit() * 0.15,
       system: this.promptWindowLimit() * 0.15,
@@ -42,34 +42,23 @@ class GroqLLM {
 
   promptWindowLimit() {
     switch (this.model) {
-      case "mixtral-8x7b-32768":
-        return 32_768;
-      case "llama3-8b-8192":
-        return 8192;
-      case "llama3-70b-8192":
-        return 8192;
+      case "gemma2-9b-it":
       case "gemma-7b-it":
+      case "llama3-70b-8192":
+      case "llama3-8b-8192":
         return 8192;
+      case "llama-3.1-70b-versatile":
+      case "llama-3.1-8b-instant":
+        return 131072;
+      case "mixtral-8x7b-32768":
+        return 32768;
       default:
         return 8192;
     }
   }
 
   async isValidChatCompletionModel(modelName = "") {
-    const validModels = [
-      "mixtral-8x7b-32768",
-      "llama3-8b-8192",
-      "llama3-70b-8192",
-      "gemma-7b-it",
-    ];
-    const isPreset = validModels.some((model) => modelName === model);
-    if (isPreset) return true;
-
-    const model = await this.openai.models
-      .retrieve(modelName)
-      .then((modelObj) => modelObj)
-      .catch(() => null);
-    return !!model;
+    return !!modelName; // name just needs to exist
   }
 
   constructPrompt({
diff --git a/server/utils/helpers/customModels.js b/server/utils/helpers/customModels.js
index 31a3eb2c029c140e23f99b832e71e3e0f18f53cd..27afa150f5a1f6c7b1015bed21017e016f801bb2 100644
--- a/server/utils/helpers/customModels.js
+++ b/server/utils/helpers/customModels.js
@@ -1,7 +1,4 @@
-const {
-  OpenRouterLLM,
-  fetchOpenRouterModels,
-} = require("../AiProviders/openRouter");
+const { fetchOpenRouterModels } = require("../AiProviders/openRouter");
 const { perplexityModels } = require("../AiProviders/perplexity");
 const { togetherAiModels } = require("../AiProviders/togetherAi");
 const { ElevenLabsTTS } = require("../TextToSpeech/elevenLabs");
@@ -18,6 +15,7 @@ const SUPPORT_CUSTOM_MODELS = [
   "koboldcpp",
   "litellm",
   "elevenlabs-tts",
+  "groq",
 ];
 
 async function getCustomModels(provider = "", apiKey = null, basePath = null) {
@@ -49,6 +47,8 @@ async function getCustomModels(provider = "", apiKey = null, basePath = null) {
       return await liteLLMModels(basePath, apiKey);
     case "elevenlabs-tts":
       return await getElevenLabsModels(apiKey);
+    case "groq":
+      return await getGroqAiModels(apiKey);
     default:
       return { models: [], error: "Invalid provider for custom models" };
   }
@@ -167,6 +167,33 @@ async function localAIModels(basePath = null, apiKey = null) {
   return { models, error: null };
 }
 
+async function getGroqAiModels(_apiKey = null) {
+  const { OpenAI: OpenAIApi } = require("openai");
+  const apiKey =
+    _apiKey === true
+      ? process.env.GROQ_API_KEY
+      : _apiKey || process.env.GROQ_API_KEY || null;
+  const openai = new OpenAIApi({
+    baseURL: "https://api.groq.com/openai/v1",
+    apiKey,
+  });
+  const models = (
+    await openai.models
+      .list()
+      .then((results) => results.data)
+      .catch((e) => {
+        console.error(`GroqAi:listModels`, e.message);
+        return [];
+      })
+  ).filter(
+    (model) => !model.id.includes("whisper") && !model.id.includes("tool-use")
+  );
+
+  // API key was successful, so save it for future use
+  if (models.length > 0 && !!apiKey) process.env.GROQ_API_KEY = apiKey;
+  return { models, error: null };
+}
+
 async function liteLLMModels(basePath = null, apiKey = null) {
   const { OpenAI: OpenAIApi } = require("openai");
   const openai = new OpenAIApi({