diff --git a/.changeset/real-zebras-attack.md b/.changeset/real-zebras-attack.md
new file mode 100644
index 0000000000000000000000000000000000000000..4ec6b66da314a0d17aec531fb2a23d1d284ce7d8
--- /dev/null
+++ b/.changeset/real-zebras-attack.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Dynamically select model for Groq
diff --git a/helpers/providers/groq.ts b/helpers/providers/groq.ts
index f1f2905166fedf28dc6365c84826859f19281cf7..075475e3998d7199a6f98bb4d874313041d28998 100644
--- a/helpers/providers/groq.ts
+++ b/helpers/providers/groq.ts
@@ -3,8 +3,55 @@ import prompts from "prompts";
 import { ModelConfigParams } from ".";
 import { questionHandlers, toChoice } from "../../questions";
 
-const MODELS = ["llama3-8b", "llama3-70b", "mixtral-8x7b"];
-const DEFAULT_MODEL = MODELS[0];
+import got from "got";
+import ora from "ora";
+import { red } from "picocolors";
+
+const GROQ_API_URL = "https://api.groq.com/openai/v1";
+
+async function getAvailableModelChoicesGroq(apiKey: string) {
+  if (!apiKey) {
+    throw new Error("Need Groq API key to retrieve model choices");
+  }
+
+  const spinner = ora("Fetching available models from Groq").start();
+  try {
+    const response = await got(`${GROQ_API_URL}/models`, {
+      headers: {
+        Authorization: `Bearer ${apiKey}`,
+      },
+      timeout: 5000,
+      responseType: "json",
+    });
+    const data: any = await response.body;
+    spinner.stop();
+
+    // Filter out the Whisper models
+    return data.data
+      .filter((model: any) => !model.id.toLowerCase().includes("whisper"))
+      .map((el: any) => {
+        return {
+          title: el.id,
+          value: el.id,
+        };
+      });
+  } catch (error: unknown) {
+    spinner.stop();
+    console.log(error);
+    if ((error as any).response?.statusCode === 401) {
+      console.log(
+        red(
+          "Invalid Groq API key provided! Please provide a valid key and try again!",
+        ),
+      );
+    } else {
+      console.log(red("Request failed: " + error));
+    }
+    process.exit(1);
+  }
+}
+
+const DEFAULT_MODEL = "llama3-70b-8192";
 
 // Use huggingface embedding models for now as Groq doesn't support embedding models
 enum HuggingFaceEmbeddingModelType {
@@ -66,12 +113,14 @@ export async function askGroqQuestions({
   // use default model values in CI or if user should not be asked
   const useDefaults = ciInfo.isCI || !askModels;
   if (!useDefaults) {
+    const modelChoices = await getAvailableModelChoicesGroq(config.apiKey!);
+
     const { model } = await prompts(
       {
         type: "select",
         name: "model",
         message: "Which LLM model would you like to use?",
-        choices: MODELS.map(toChoice),
+        choices: modelChoices,
         initial: 0,
       },
       questionHandlers,
diff --git a/templates/components/settings/python/settings.py b/templates/components/settings/python/settings.py
index bb8287059c948b00b5b3af5e812794872795d9b9..620a437938153688c1fed0da46f97fbce08eb167 100644
--- a/templates/components/settings/python/settings.py
+++ b/templates/components/settings/python/settings.py
@@ -126,13 +126,7 @@ def init_fastembed():
 def init_groq():
     from llama_index.llms.groq import Groq
 
-    model_map: Dict[str, str] = {
-        "llama3-8b": "llama3-8b-8192",
-        "llama3-70b": "llama3-70b-8192",
-        "mixtral-8x7b": "mixtral-8x7b-32768",
-    }
-
-    Settings.llm = Groq(model=model_map[os.getenv("MODEL")])
+    Settings.llm = Groq(model=os.getenv("MODEL"))
 
     # Groq does not provide embeddings, so we use FastEmbed instead
     init_fastembed()
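For context on the `helpers/providers/groq.ts` change above: the hard-coded `MODELS` list is replaced by a live lookup against Groq's OpenAI-compatible `/models` endpoint, so newly released models show up in the prompt without a code change. Below is a minimal standalone sketch of that flow — it assumes Node 18+ and uses the built-in `fetch` instead of the repo's `got`/`ora` dependencies, with the response shape (`{ data: [{ id: string }] }`) taken from what the diff itself consumes:

```ts
// Minimal sketch of the dynamic model lookup (assumes Node 18+ global fetch;
// the actual helper uses got with an ora spinner and exits on failure instead).
const GROQ_API_URL = "https://api.groq.com/openai/v1";

async function listGroqModelChoices(
  apiKey: string,
): Promise<{ title: string; value: string }[]> {
  const res = await fetch(`${GROQ_API_URL}/models`, {
    headers: { Authorization: `Bearer ${apiKey}` },
  });
  if (!res.ok) {
    // 401 means an invalid API key; the real helper prints a hint and exits.
    throw new Error(`Groq /models request failed with status ${res.status}`);
  }
  const data = (await res.json()) as { data: { id: string }[] };
  return (
    data.data
      // Whisper models are speech-to-text, not chat LLMs, so drop them.
      .filter((model) => !model.id.toLowerCase().includes("whisper"))
      .map((model) => ({ title: model.id, value: model.id }))
  );
}
```

Each `{ title, value }` pair feeds directly into the `prompts` select question, which is why the `MODELS.map(toChoice)` call in `askGroqQuestions` can be dropped.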
diff --git a/templates/types/streaming/express/package.json b/templates/types/streaming/express/package.json
index 1dec40f4de46699590f5cd1e9e212a01746720d8..e860060f2b324ffc4ac2712cf49d2b816374c387 100644
--- a/templates/types/streaming/express/package.json
+++ b/templates/types/streaming/express/package.json
@@ -20,7 +20,7 @@
     "dotenv": "^16.3.1",
     "duck-duck-scrape": "^2.2.5",
     "express": "^4.18.2",
-    "llamaindex": "0.5.24",
+    "llamaindex": "0.6.2",
     "pdf2json": "3.0.5",
     "ajv": "^8.12.0",
     "@e2b/code-interpreter": "^0.0.5",
diff --git a/templates/types/streaming/express/src/controllers/engine/settings.ts b/templates/types/streaming/express/src/controllers/engine/settings.ts
index 28761d26fdc95d19f32d56e0c09c8af113e5fe81..5691fe7730a9cc02c1e067deb12738bfb1d3b360 100644
--- a/templates/types/streaming/express/src/controllers/engine/settings.ts
+++ b/templates/types/streaming/express/src/controllers/engine/settings.ts
@@ -138,14 +138,8 @@ function initGroq() {
     "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
   };
 
-  const modelMap: Record<string, string> = {
-    "llama3-8b": "llama3-8b-8192",
-    "llama3-70b": "llama3-70b-8192",
-    "mixtral-8x7b": "mixtral-8x7b-32768",
-  };
-
   Settings.llm = new Groq({
-    model: modelMap[process.env.MODEL!],
+    model: process.env.MODEL!,
   });
 
   Settings.embedModel = new HuggingFaceEmbedding({
diff --git a/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts b/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
index 28761d26fdc95d19f32d56e0c09c8af113e5fe81..5691fe7730a9cc02c1e067deb12738bfb1d3b360 100644
--- a/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
+++ b/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
@@ -138,14 +138,8 @@ function initGroq() {
     "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
   };
 
-  const modelMap: Record<string, string> = {
-    "llama3-8b": "llama3-8b-8192",
-    "llama3-70b": "llama3-70b-8192",
-    "mixtral-8x7b": "mixtral-8x7b-32768",
-  };
-
   Settings.llm = new Groq({
-    model: modelMap[process.env.MODEL!],
+    model: process.env.MODEL!,
   });
 
   Settings.embedModel = new HuggingFaceEmbedding({
diff --git a/templates/types/streaming/nextjs/package.json b/templates/types/streaming/nextjs/package.json
index 26df8b49b0a48d0589696d41eedbb6d665fe1d8b..b244f641a1b6839653641a502eab672a6d02ed0d 100644
--- a/templates/types/streaming/nextjs/package.json
+++ b/templates/types/streaming/nextjs/package.json
@@ -25,7 +25,7 @@
     "duck-duck-scrape": "^2.2.5",
     "formdata-node": "^6.0.3",
     "got": "^14.4.1",
-    "llamaindex": "0.5.24",
+    "llamaindex": "0.6.2",
     "lucide-react": "^0.294.0",
     "next": "^14.2.4",
     "react": "^18.2.0",
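One migration note that follows from the settings changes in both templates: since the `modelMap`/`model_map` alias tables are gone, the `MODEL` environment variable must now hold a full Groq model id exactly as returned by `/models` (e.g. `llama3-70b-8192`, the new default), not a short alias like `llama3-70b`. The shim below is hypothetical — it is not part of the diff — and only illustrates how a generated project could cope with an older `.env` that still carries the aliases:

```ts
// Hypothetical compatibility shim, not part of the diff: upgrade the short
// aliases that the removed mapping tables used to expand into full model ids.
const LEGACY_ALIASES: Record<string, string> = {
  "llama3-8b": "llama3-8b-8192",
  "llama3-70b": "llama3-70b-8192",
  "mixtral-8x7b": "mixtral-8x7b-32768",
};

function resolveGroqModelId(model: string | undefined): string {
  if (!model) {
    throw new Error("MODEL environment variable is not set");
  }
  // Pass full ids through unchanged; silently upgrade a legacy alias.
  return LEGACY_ALIASES[model] ?? model;
}

// Usage (illustrative):
// Settings.llm = new Groq({ model: resolveGroqModelId(process.env.MODEL) });
```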