Skip to content
Snippets Groups Projects
Unverified Commit cf3ec97a authored by Marcus Schiesser's avatar Marcus Schiesser Committed by GitHub
Browse files

Dynamically select model for Groq (#278)


---------
Co-authored-by: Jac-Zac <jacopozac@icloud.com>
Co-authored-by: Thuc Pham <51660321+thucpn@users.noreply.github.com>
parent 505b8e94
No related branches found
No related tags found
No related merge requests found
---
"create-llama": patch
---
Dynamically select model for Groq
...@@ -3,8 +3,55 @@ import prompts from "prompts"; ...@@ -3,8 +3,55 @@ import prompts from "prompts";
import { ModelConfigParams } from "."; import { ModelConfigParams } from ".";
import { questionHandlers, toChoice } from "../../questions"; import { questionHandlers, toChoice } from "../../questions";
const MODELS = ["llama3-8b", "llama3-70b", "mixtral-8x7b"]; import got from "got";
const DEFAULT_MODEL = MODELS[0]; import ora from "ora";
import { red } from "picocolors";
const GROQ_API_URL = "https://api.groq.com/openai/v1";

/**
 * Fetch the models available for the given Groq API key and map them to
 * `prompts` select choices ({ title, value }).
 *
 * Whisper (speech-to-text) models are filtered out because they cannot be
 * used as chat LLMs.
 *
 * On request failure (e.g. invalid key, network error) the error is printed
 * and the process exits with code 1.
 */
async function getAvailableModelChoicesGroq(apiKey: string) {
  if (!apiKey) {
    throw new Error("Need Groq API key to retrieve model choices");
  }
  const spinner = ora("Fetching available models from Groq").start();
  try {
    const response = await got(`${GROQ_API_URL}/models`, {
      headers: {
        Authorization: `Bearer ${apiKey}`,
      },
      // got v12+ requires `timeout` to be an object; a bare number
      // throws "The `timeout` option must be an object" at runtime.
      timeout: { request: 5000 },
      responseType: "json",
    });
    spinner.stop();
    // With responseType: "json", `response.body` is already-parsed JSON
    // (not a promise) — no `await` needed.
    const data = response.body as { data: Array<{ id: string }> };
    // Filter out the Whisper models — they are transcription models, not LLMs.
    return data.data
      .filter((model) => !model.id.toLowerCase().includes("whisper"))
      .map((model) => {
        return {
          title: model.id,
          value: model.id,
        };
      });
  } catch (error: unknown) {
    spinner.stop();
    const statusCode = (error as { response?: { statusCode?: number } })
      .response?.statusCode;
    if (statusCode === 401) {
      console.log(
        red(
          "Invalid Groq API key provided! Please provide a valid key and try again!",
        ),
      );
    } else {
      console.log(red("Request failed: " + error));
    }
    process.exit(1);
  }
}
const DEFAULT_MODEL = "llama3-70b-8192";
// Use huggingface embedding models for now as Groq doesn't support embedding models // Use huggingface embedding models for now as Groq doesn't support embedding models
enum HuggingFaceEmbeddingModelType { enum HuggingFaceEmbeddingModelType {
...@@ -66,12 +113,14 @@ export async function askGroqQuestions({ ...@@ -66,12 +113,14 @@ export async function askGroqQuestions({
// use default model values in CI or if user should not be asked // use default model values in CI or if user should not be asked
const useDefaults = ciInfo.isCI || !askModels; const useDefaults = ciInfo.isCI || !askModels;
if (!useDefaults) { if (!useDefaults) {
const modelChoices = await getAvailableModelChoicesGroq(config.apiKey!);
const { model } = await prompts( const { model } = await prompts(
{ {
type: "select", type: "select",
name: "model", name: "model",
message: "Which LLM model would you like to use?", message: "Which LLM model would you like to use?",
choices: MODELS.map(toChoice), choices: modelChoices,
initial: 0, initial: 0,
}, },
questionHandlers, questionHandlers,
......
...@@ -126,13 +126,7 @@ def init_fastembed(): ...@@ -126,13 +126,7 @@ def init_fastembed():
def init_groq(): def init_groq():
from llama_index.llms.groq import Groq from llama_index.llms.groq import Groq
model_map: Dict[str, str] = { Settings.llm = Groq(model=os.getenv("MODEL"))
"llama3-8b": "llama3-8b-8192",
"llama3-70b": "llama3-70b-8192",
"mixtral-8x7b": "mixtral-8x7b-32768",
}
Settings.llm = Groq(model=model_map[os.getenv("MODEL")])
# Groq does not provide embeddings, so we use FastEmbed instead # Groq does not provide embeddings, so we use FastEmbed instead
init_fastembed() init_fastembed()
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
"dotenv": "^16.3.1", "dotenv": "^16.3.1",
"duck-duck-scrape": "^2.2.5", "duck-duck-scrape": "^2.2.5",
"express": "^4.18.2", "express": "^4.18.2",
"llamaindex": "0.5.24", "llamaindex": "0.6.2",
"pdf2json": "3.0.5", "pdf2json": "3.0.5",
"ajv": "^8.12.0", "ajv": "^8.12.0",
"@e2b/code-interpreter": "^0.0.5", "@e2b/code-interpreter": "^0.0.5",
......
...@@ -138,14 +138,8 @@ function initGroq() { ...@@ -138,14 +138,8 @@ function initGroq() {
"all-mpnet-base-v2": "Xenova/all-mpnet-base-v2", "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
}; };
const modelMap: Record<string, string> = {
"llama3-8b": "llama3-8b-8192",
"llama3-70b": "llama3-70b-8192",
"mixtral-8x7b": "mixtral-8x7b-32768",
};
Settings.llm = new Groq({ Settings.llm = new Groq({
model: modelMap[process.env.MODEL!], model: process.env.MODEL!,
}); });
Settings.embedModel = new HuggingFaceEmbedding({ Settings.embedModel = new HuggingFaceEmbedding({
......
...@@ -138,14 +138,8 @@ function initGroq() { ...@@ -138,14 +138,8 @@ function initGroq() {
"all-mpnet-base-v2": "Xenova/all-mpnet-base-v2", "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
}; };
const modelMap: Record<string, string> = {
"llama3-8b": "llama3-8b-8192",
"llama3-70b": "llama3-70b-8192",
"mixtral-8x7b": "mixtral-8x7b-32768",
};
Settings.llm = new Groq({ Settings.llm = new Groq({
model: modelMap[process.env.MODEL!], model: process.env.MODEL!,
}); });
Settings.embedModel = new HuggingFaceEmbedding({ Settings.embedModel = new HuggingFaceEmbedding({
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
"duck-duck-scrape": "^2.2.5", "duck-duck-scrape": "^2.2.5",
"formdata-node": "^6.0.3", "formdata-node": "^6.0.3",
"got": "^14.4.1", "got": "^14.4.1",
"llamaindex": "0.5.24", "llamaindex": "0.6.2",
"lucide-react": "^0.294.0", "lucide-react": "^0.294.0",
"next": "^14.2.4", "next": "^14.2.4",
"react": "^18.2.0", "react": "^18.2.0",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment