Skip to content
Snippets Groups Projects
Unverified Commit 61e214aa authored by Timothy Carambat's avatar Timothy Carambat Committed by GitHub
Browse files

Add support for Groq /models endpoint (#1957)

* Add support for Groq /models endpoint

* linting
parent 23de85a3
Branches
Tags
No related merge requests found
import { useState, useEffect } from "react";
import System from "@/models/system";
export default function GroqAiOptions({ settings }) { export default function GroqAiOptions({ settings }) {
const [inputValue, setInputValue] = useState(settings?.GroqApiKey);
const [apiKey, setApiKey] = useState(settings?.GroqApiKey);
return ( return (
<div className="flex gap-[36px] mt-1.5"> <div className="flex gap-[36px] mt-1.5">
<div className="flex flex-col w-60"> <div className="flex flex-col w-60">
...@@ -8,41 +14,98 @@ export default function GroqAiOptions({ settings }) { ...@@ -8,41 +14,98 @@ export default function GroqAiOptions({ settings }) {
<input <input
type="password" type="password"
name="GroqApiKey" name="GroqApiKey"
className="bg-zinc-900 text-white placeholder:text-white/20 text-sm rounded-lg focus:outline-primary-button active:outline-primary-button outline-none block w-full p-2.5" className="border-none bg-zinc-900 text-white placeholder:text-white/20 text-sm rounded-lg focus:outline-primary-button active:outline-primary-button outline-none block w-full p-2.5"
placeholder="Groq API Key" placeholder="Groq API Key"
defaultValue={settings?.GroqApiKey ? "*".repeat(20) : ""} defaultValue={settings?.GroqApiKey ? "*".repeat(20) : ""}
required={true} required={true}
autoComplete="off" autoComplete="off"
spellCheck={false} spellCheck={false}
onChange={(e) => setInputValue(e.target.value)}
onBlur={() => setApiKey(inputValue)}
/> />
</div> </div>
{!settings?.credentialsOnly && ( {!settings?.credentialsOnly && (
<div className="flex flex-col w-60"> <GroqAIModelSelection settings={settings} apiKey={apiKey} />
<label className="text-white text-sm font-semibold block mb-3"> )}
Chat Model Selection </div>
</label> );
<select }
name="GroqModelPref"
defaultValue={settings?.GroqModelPref || "llama3-8b-8192"} function GroqAIModelSelection({ apiKey, settings }) {
required={true} const [customModels, setCustomModels] = useState([]);
className="bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5" const [loading, setLoading] = useState(true);
>
{[ useEffect(() => {
"mixtral-8x7b-32768", async function findCustomModels() {
"llama3-8b-8192", if (!apiKey) {
"llama3-70b-8192", setCustomModels([]);
"gemma-7b-it", setLoading(true);
].map((model) => { return;
}
try {
setLoading(true);
const { models } = await System.customModels("groq", apiKey);
setCustomModels(models || []);
} catch (error) {
console.error("Failed to fetch custom models:", error);
setCustomModels([]);
} finally {
setLoading(false);
}
}
findCustomModels();
}, [apiKey]);
if (loading) {
return (
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
Chat Model Selection
</label>
<select
name="GroqModelPref"
disabled={true}
className="border-none bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
>
<option disabled={true} selected={true}>
--loading available models--
</option>
</select>
<p className="text-xs leading-[18px] font-base text-white text-opacity-60 mt-2">
Enter a valid API key to view all available models for your account.
</p>
</div>
);
}
return (
<div className="flex flex-col w-60">
<label className="text-white text-sm font-semibold block mb-3">
Chat Model Selection
</label>
<select
name="GroqModelPref"
required={true}
className="border-none bg-zinc-900 border-gray-500 text-white text-sm rounded-lg block w-full p-2.5"
defaultValue={settings?.GroqModelPref}
>
{customModels.length > 0 && (
<optgroup label="Available models">
{customModels.map((model) => {
return ( return (
<option key={model} value={model}> <option key={model.id} value={model.id}>
{model} {model.id}
</option> </option>
); );
})} })}
</select> </optgroup>
</div> )}
)} </select>
<p className="text-xs leading-[18px] font-base text-white text-opacity-60 mt-2">
Select the GroqAI model you want to use for your conversations.
</p>
</div> </div>
); );
} }
...@@ -32,12 +32,7 @@ const PROVIDER_DEFAULT_MODELS = { ...@@ -32,12 +32,7 @@ const PROVIDER_DEFAULT_MODELS = {
localai: [], localai: [],
ollama: [], ollama: [],
togetherai: [], togetherai: [],
groq: [ groq: [],
"mixtral-8x7b-32768",
"llama3-8b-8192",
"llama3-70b-8192",
"gemma-7b-it",
],
native: [], native: [],
cohere: [ cohere: [
"command-r", "command-r",
......
...@@ -13,7 +13,7 @@ class GroqLLM { ...@@ -13,7 +13,7 @@ class GroqLLM {
apiKey: process.env.GROQ_API_KEY, apiKey: process.env.GROQ_API_KEY,
}); });
this.model = this.model =
modelPreference || process.env.GROQ_MODEL_PREF || "llama3-8b-8192"; modelPreference || process.env.GROQ_MODEL_PREF || "llama-3.1-8b-instant";
this.limits = { this.limits = {
history: this.promptWindowLimit() * 0.15, history: this.promptWindowLimit() * 0.15,
system: this.promptWindowLimit() * 0.15, system: this.promptWindowLimit() * 0.15,
...@@ -42,34 +42,24 @@ class GroqLLM { ...@@ -42,34 +42,24 @@ class GroqLLM {
promptWindowLimit() { promptWindowLimit() {
switch (this.model) { switch (this.model) {
case "mixtral-8x7b-32768": case "gemma2-9b-it":
return 32_768;
case "llama3-8b-8192":
return 8192;
case "llama3-70b-8192":
return 8192;
case "gemma-7b-it": case "gemma-7b-it":
case "llama3-70b-8192":
case "llama3-8b-8192":
return 8192; return 8192;
case "llama-3.1-70b-versatile":
case "llama-3.1-8b-instant":
case "llama-3.1-8b-instant":
return 131072;
case "mixtral-8x7b-32768":
return 32768;
default: default:
return 8192; return 8192;
} }
} }
async isValidChatCompletionModel(modelName = "") { async isValidChatCompletionModel(modelName = "") {
const validModels = [ return !!modelName; // name just needs to exist
"mixtral-8x7b-32768",
"llama3-8b-8192",
"llama3-70b-8192",
"gemma-7b-it",
];
const isPreset = validModels.some((model) => modelName === model);
if (isPreset) return true;
const model = await this.openai.models
.retrieve(modelName)
.then((modelObj) => modelObj)
.catch(() => null);
return !!model;
} }
constructPrompt({ constructPrompt({
......
const { const { fetchOpenRouterModels } = require("../AiProviders/openRouter");
OpenRouterLLM,
fetchOpenRouterModels,
} = require("../AiProviders/openRouter");
const { perplexityModels } = require("../AiProviders/perplexity"); const { perplexityModels } = require("../AiProviders/perplexity");
const { togetherAiModels } = require("../AiProviders/togetherAi"); const { togetherAiModels } = require("../AiProviders/togetherAi");
const { ElevenLabsTTS } = require("../TextToSpeech/elevenLabs"); const { ElevenLabsTTS } = require("../TextToSpeech/elevenLabs");
...@@ -18,6 +15,7 @@ const SUPPORT_CUSTOM_MODELS = [ ...@@ -18,6 +15,7 @@ const SUPPORT_CUSTOM_MODELS = [
"koboldcpp", "koboldcpp",
"litellm", "litellm",
"elevenlabs-tts", "elevenlabs-tts",
"groq",
]; ];
async function getCustomModels(provider = "", apiKey = null, basePath = null) { async function getCustomModels(provider = "", apiKey = null, basePath = null) {
...@@ -49,6 +47,8 @@ async function getCustomModels(provider = "", apiKey = null, basePath = null) { ...@@ -49,6 +47,8 @@ async function getCustomModels(provider = "", apiKey = null, basePath = null) {
return await liteLLMModels(basePath, apiKey); return await liteLLMModels(basePath, apiKey);
case "elevenlabs-tts": case "elevenlabs-tts":
return await getElevenLabsModels(apiKey); return await getElevenLabsModels(apiKey);
case "groq":
return await getGroqAiModels(apiKey);
default: default:
return { models: [], error: "Invalid provider for custom models" }; return { models: [], error: "Invalid provider for custom models" };
} }
...@@ -167,6 +167,33 @@ async function localAIModels(basePath = null, apiKey = null) { ...@@ -167,6 +167,33 @@ async function localAIModels(basePath = null, apiKey = null) {
return { models, error: null }; return { models, error: null };
} }
async function getGroqAiModels(_apiKey = null) {
  const { OpenAI: OpenAIApi } = require("openai");
  // `true` is a sentinel meaning "reuse the key already stored in the env";
  // otherwise prefer the explicit key, then the env, then null.
  const apiKey =
    _apiKey === true
      ? process.env.GROQ_API_KEY
      : _apiKey || process.env.GROQ_API_KEY || null;

  // Groq exposes an OpenAI-compatible API surface, so we reuse the OpenAI
  // SDK pointed at Groq's base URL.
  const client = new OpenAIApi({
    baseURL: "https://api.groq.com/openai/v1",
    apiKey,
  });

  let discovered = [];
  try {
    const results = await client.models.list();
    discovered = results.data;
  } catch (e) {
    console.error(`GroqAi:listModels`, e.message);
  }

  // Drop non-chat models (speech-to-text and tool-use variants).
  const models = discovered.filter(
    (model) => !model.id.includes("whisper") && !model.id.includes("tool-use")
  );

  // A non-empty listing proves the key works — persist it for future calls.
  if (models.length > 0 && !!apiKey) process.env.GROQ_API_KEY = apiKey;
  return { models, error: null };
}
async function liteLLMModels(basePath = null, apiKey = null) { async function liteLLMModels(basePath = null, apiKey = null) {
const { OpenAI: OpenAIApi } = require("openai"); const { OpenAI: OpenAIApi } = require("openai");
const openai = new OpenAIApi({ const openai = new OpenAIApi({
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment