diff --git a/.changeset/pink-terms-cheer.md b/.changeset/pink-terms-cheer.md
new file mode 100644
index 0000000000000000000000000000000000000000..fead10c7c4e321fdb06f90c69cd4d44a2560b930
--- /dev/null
+++ b/.changeset/pink-terms-cheer.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Add Groq as a model provider
diff --git a/helpers/env-variables.ts b/helpers/env-variables.ts
index 2965381fef49086bd5b6f8c2b14b91f5822edd08..a1458ac3a5b51046a9d2d9c9b33a183144d3ed99 100644
--- a/helpers/env-variables.ts
+++ b/helpers/env-variables.ts
@@ -215,6 +215,15 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
           },
         ]
       : []),
+    ...(modelConfig.provider === "groq"
+      ? [
+          {
+            name: "GROQ_API_KEY",
+            description: "The Groq API key to use.",
+            value: modelConfig.apiKey,
+          },
+        ]
+      : []),
     ...(modelConfig.provider === "gemini"
       ? [
           {
diff --git a/helpers/providers/groq.ts b/helpers/providers/groq.ts
new file mode 100644
index 0000000000000000000000000000000000000000..e3289ae024794d259211b0897480921f443c2993
--- /dev/null
+++ b/helpers/providers/groq.ts
@@ -0,0 +1,103 @@
+import ciInfo from "ci-info";
+import prompts from "prompts";
+import { ModelConfigParams } from ".";
+import { questionHandlers, toChoice } from "../../questions";
+
+const MODELS = [
+  "llama3-8b",
+  "llama3-70b",
+  "mixtral-8x7b",
+];
+const DEFAULT_MODEL = MODELS[0];
+
+// Use huggingface embedding models for now as Groq doesn't support embedding models
+enum HuggingFaceEmbeddingModelType {
+  XENOVA_ALL_MINILM_L6_V2 = "all-MiniLM-L6-v2",
+  XENOVA_ALL_MPNET_BASE_V2 = "all-mpnet-base-v2",
+}
+type ModelData = {
+  dimensions: number;
+};
+const EMBEDDING_MODELS: Record<HuggingFaceEmbeddingModelType, ModelData> = {
+  [HuggingFaceEmbeddingModelType.XENOVA_ALL_MINILM_L6_V2]: {
+    dimensions: 384,
+  },
+  [HuggingFaceEmbeddingModelType.XENOVA_ALL_MPNET_BASE_V2]: {
+    dimensions: 768,
+  },
+};
+const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
+const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;
+
+type GroqQuestionsParams = {
+  apiKey?: string;
+  askModels: boolean;
+};
+
+export async function askGroqQuestions({
+  askModels,
+  apiKey,
+}: GroqQuestionsParams): Promise<ModelConfigParams> {
+  const config: ModelConfigParams = {
+    apiKey,
+    model: DEFAULT_MODEL,
+    embeddingModel: DEFAULT_EMBEDDING_MODEL,
+    dimensions: DEFAULT_DIMENSIONS,
+    isConfigured(): boolean {
+      if (config.apiKey) {
+        return true;
+      }
+      if (process.env["GROQ_API_KEY"]) {
+        return true;
+      }
+      return false;
+    },
+  };
+
+  if (!config.apiKey) {
+    const { key } = await prompts(
+      {
+        type: "text",
+        name: "key",
+        message:
+          "Please provide your Groq API key (or leave blank to use GROQ_API_KEY env variable):",
+      },
+      questionHandlers,
+    );
+    config.apiKey = key || process.env.GROQ_API_KEY;
+  }
+
+  // use default model values in CI or if user should not be asked
+  const useDefaults = ciInfo.isCI || !askModels;
+  if (!useDefaults) {
+    const { model } = await prompts(
+      {
+        type: "select",
+        name: "model",
+        message: "Which LLM model would you like to use?",
+        choices: MODELS.map(toChoice),
+        initial: 0,
+      },
+      questionHandlers,
+    );
+    config.model = model;
+
+    const { embeddingModel } = await prompts(
+      {
+        type: "select",
+        name: "embeddingModel",
+        message: "Which embedding model would you like to use?",
+        choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
+        initial: 0,
+      },
+      questionHandlers,
+    );
+    config.embeddingModel = embeddingModel;
+    config.dimensions =
+      EMBEDDING_MODELS[
+        embeddingModel as HuggingFaceEmbeddingModelType
+      ].dimensions;
+  }
+
+  return config;
+}
diff --git a/helpers/providers/index.ts b/helpers/providers/index.ts
index 62202c0f6f9f4aed96de17668f9760ddc260cf15..9e4b61474c6178ff584e13efc28b4a937e001c38 100644
--- a/helpers/providers/index.ts
+++ b/helpers/providers/index.ts
@@ -3,6 +3,7 @@ import prompts from "prompts";
 import { questionHandlers } from "../../questions";
 import { ModelConfig, ModelProvider } from "../types";
 import { askAnthropicQuestions } from "./anthropic";
+import { askGroqQuestions } from "./groq";
 import { askGeminiQuestions } from "./gemini";
 import { askOllamaQuestions } from "./ollama";
 import { askOpenAIQuestions } from "./openai";
@@ -32,6 +33,7 @@ export async function askModelConfig({
       title: "OpenAI",
       value: "openai",
     },
+    { title: "Groq", value: "groq" },
     { title: "Ollama", value: "ollama" },
     { title: "Anthropic", value: "anthropic" },
     { title: "Gemini", value: "gemini" },
@@ -48,6 +50,9 @@ export async function askModelConfig({
     case "ollama":
       modelConfig = await askOllamaQuestions({ askModels });
       break;
+    case "groq":
+      modelConfig = await askGroqQuestions({ askModels });
+      break;
     case "anthropic":
       modelConfig = await askAnthropicQuestions({ askModels });
       break;
diff --git a/helpers/types.ts b/helpers/types.ts
index b26e708864e9af408814d8010b463003f587e261..5a7bd4847b5f3b282287ed02761ebc165165e78e 100644
--- a/helpers/types.ts
+++ b/helpers/types.ts
@@ -1,7 +1,7 @@
 import { PackageManager } from "../helpers/get-pkg-manager";
 import { Tool } from "./tools";
 
-export type ModelProvider = "openai" | "ollama" | "anthropic" | "gemini";
+export type ModelProvider = "openai" | "groq" | "ollama" | "anthropic" | "gemini";
 export type ModelConfig = {
   provider: ModelProvider;
   apiKey?: string;
diff --git a/templates/types/streaming/express/src/controllers/engine/settings.ts b/templates/types/streaming/express/src/controllers/engine/settings.ts
index d2ccf190139c13ab43d6e6abf50f2e2bee279dfc..3dd52c04f42934acec33aff1a1c674dac8f73d00 100644
--- a/templates/types/streaming/express/src/controllers/engine/settings.ts
+++ b/templates/types/streaming/express/src/controllers/engine/settings.ts
@@ -5,12 +5,14 @@ import {
   Gemini,
   GeminiEmbedding,
   OpenAI,
+  Groq,
   OpenAIEmbedding,
   Settings,
 } from "llamaindex";
 import { HuggingFaceEmbedding } from "llamaindex/embeddings/HuggingFaceEmbedding";
 import { OllamaEmbedding } from "llamaindex/embeddings/OllamaEmbedding";
 import { ALL_AVAILABLE_ANTHROPIC_MODELS } from "llamaindex/llm/anthropic";
+import { ALL_AVAILABLE_GROQ_MODELS } from "llamaindex/llm/groq";
 import { Ollama } from "llamaindex/llm/ollama";
 
 const CHUNK_SIZE = 512;
@@ -28,6 +30,9 @@ export const initSettings = async () => {
     case "ollama":
       initOllama();
       break;
+    case "groq":
+      initGroq();
+      break;
     case "anthropic":
       initAnthropic();
       break;
@@ -85,6 +90,28 @@ function initAnthropic() {
   });
 }
+
+function initGroq() {
+  const embedModelMap: Record<string, string> = {
+    "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
+    "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
+  };
+
+  const modelMap: Record<string, string> = {
+    "llama3-8b": "llama3-8b-8192",
+    "llama3-70b": "llama3-70b-8192",
+    "mixtral-8x7b": "mixtral-8x7b-32768",
+  };
+
+  Settings.llm = new Groq({
+    model: modelMap[process.env.MODEL!],
+  });
+
+  Settings.embedModel = new HuggingFaceEmbedding({
+    modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
+  });
+}
+
 function initGemini() {
   Settings.llm = new Gemini({
     model: process.env.MODEL as GEMINI_MODEL,
diff --git a/templates/types/streaming/fastapi/app/settings.py b/templates/types/streaming/fastapi/app/settings.py
index 87c591af46e5eacf5360fe98c3eb2fa90f95ccb3..c37c13b7ba7e103a48ccda51ccf7ea65e950c65f 100644
--- a/templates/types/streaming/fastapi/app/settings.py
+++ b/templates/types/streaming/fastapi/app/settings.py
@@ -1,5 +1,6 @@
 import os
 from typing import Dict
+
 from llama_index.core.settings import Settings
 
 
@@ -8,6 +9,8 @@ def init_settings():
     match model_provider:
         case "openai":
             init_openai()
+        case "groq":
+            init_groq()
         case "ollama":
             init_ollama()
         case "anthropic":
@@ -23,8 +26,8 @@ def init_ollama():
-    from llama_index.llms.ollama.base import Ollama, DEFAULT_REQUEST_TIMEOUT
     from llama_index.embeddings.ollama import OllamaEmbedding
+    from llama_index.llms.ollama.base import DEFAULT_REQUEST_TIMEOUT, Ollama
 
     base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
     request_timeout = float(
@@ -40,9 +43,9 @@ def init_openai():
-    from llama_index.llms.openai import OpenAI
-    from llama_index.embeddings.openai import OpenAIEmbedding
     from llama_index.core.constants import DEFAULT_TEMPERATURE
+    from llama_index.embeddings.openai import OpenAIEmbedding
+    from llama_index.llms.openai import OpenAI
 
     max_tokens = os.getenv("LLM_MAX_TOKENS")
     config = {
@@ -61,9 +64,9 @@ def init_azure_openai():
-    from llama_index.llms.azure_openai import AzureOpenAI
-    from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
     from llama_index.core.constants import DEFAULT_TEMPERATURE
+    from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
+    from llama_index.llms.azure_openai import AzureOpenAI
 
     llm_deployment = os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT")
     embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
@@ -88,9 +91,30 @@ def init_azure_openai():
     Settings.embed_model = AzureOpenAIEmbedding(**embedding_config)
 
 
+def init_groq():
+    from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+    from llama_index.llms.groq import Groq
+
+    model_map: Dict[str, str] = {
+        "llama3-8b": "llama3-8b-8192",
+        "llama3-70b": "llama3-70b-8192",
+        "mixtral-8x7b": "mixtral-8x7b-32768",
+    }
+
+    embed_model_map: Dict[str, str] = {
+        "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
+        "all-mpnet-base-v2": "sentence-transformers/all-mpnet-base-v2",
+    }
+
+    Settings.llm = Groq(model=model_map[os.getenv("MODEL")])
+    Settings.embed_model = HuggingFaceEmbedding(
+        model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
+    )
+
+
 def init_anthropic():
-    from llama_index.llms.anthropic import Anthropic
     from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+    from llama_index.llms.anthropic import Anthropic
 
     model_map: Dict[str, str] = {
         "claude-3-opus": "claude-3-opus-20240229",
@@ -112,8 +136,8 @@ def init_gemini():
-    from llama_index.llms.gemini import Gemini
     from llama_index.embeddings.gemini import GeminiEmbedding
+    from llama_index.llms.gemini import Gemini
 
     model_name = f"models/{os.getenv('MODEL')}"
     embed_model_name = f"models/{os.getenv('EMBEDDING_MODEL')}"
diff --git a/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts b/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
index f8bfd7bec219b6d0e8720128809acbb171131ed1..b2fa4d3a7b920c743890d42471708da869b66339 100644
--- a/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
+++ b/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
@@ -4,6 +4,7 @@ import {
   GEMINI_MODEL,
   Gemini,
   GeminiEmbedding,
+  Groq,
   OpenAI,
   OpenAIEmbedding,
   Settings,
@@ -28,6 +29,9 @@ export const initSettings = async () => {
     case "ollama":
       initOllama();
       break;
+    case "groq":
+      initGroq();
+      break;
     case "anthropic":
       initAnthropic();
       break;
@@ -71,6 +75,27 @@ function initOllama() {
   });
 }
 
+function initGroq() {
+  const embedModelMap: Record<string, string> = {
+    "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
+    "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
+  };
+
+  const modelMap: Record<string, string> = {
+    "llama3-8b": "llama3-8b-8192",
+    "llama3-70b": "llama3-70b-8192",
+    "mixtral-8x7b": "mixtral-8x7b-32768",
+  };
+
+  Settings.llm = new Groq({
+    model: modelMap[process.env.MODEL!],
+  });
+
+  Settings.embedModel = new HuggingFaceEmbedding({
+    modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
+  });
+}
+
 function initAnthropic() {
   const embedModelMap: Record<string, string> = {
     "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
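
Note: as a reference for the change above, a project generated with the Groq provider is expected to run with roughly the following environment. This is a minimal sketch, not part of the diff: GROQ_API_KEY is written by helpers/env-variables.ts, MODEL and EMBEDDING_MODEL hold the short aliases that initGroq()/init_groq() map to concrete model IDs, and MODEL_PROVIDER is assumed to be the provider switch read elsewhere in the templates, outside the hunks shown here.

    MODEL_PROVIDER=groq
    MODEL=llama3-70b                  # mapped to "llama3-70b-8192"
    EMBEDDING_MODEL=all-MiniLM-L6-v2  # mapped to the corresponding HuggingFace embedding model
    GROQ_API_KEY=<your Groq API key>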