From 02ed277dd0ee4239ce044996739fc9093a5caf51 Mon Sep 17 00:00:00 2001
From: Jacopo Zacchigna <59306950+Jac-Zac@users.noreply.github.com>
Date: Wed, 19 Jun 2024 12:43:36 +0200
Subject: [PATCH] Starting to add Groq as a provider (#131)

---------

Co-authored-by: Marcus Schiesser <marcus.schiesser@googlemail.com>
---
 .changeset/pink-terms-cheer.md                |   5 +
 helpers/env-variables.ts                      |   9 ++
 helpers/providers/groq.ts                     | 103 ++++++++++++++++++
 helpers/providers/index.ts                    |   5 +
 helpers/types.ts                              |   2 +-
 .../src/controllers/engine/settings.ts        |  27 +++++
 .../types/streaming/fastapi/app/settings.py   |  38 +++++--
 .../nextjs/app/api/chat/engine/settings.ts    |  25 +++++
 8 files changed, 206 insertions(+), 8 deletions(-)
 create mode 100644 .changeset/pink-terms-cheer.md
 create mode 100644 helpers/providers/groq.ts

diff --git a/.changeset/pink-terms-cheer.md b/.changeset/pink-terms-cheer.md
new file mode 100644
index 00000000..fead10c7
--- /dev/null
+++ b/.changeset/pink-terms-cheer.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Add Groq as a model provider
diff --git a/helpers/env-variables.ts b/helpers/env-variables.ts
index 2965381f..a1458ac3 100644
--- a/helpers/env-variables.ts
+++ b/helpers/env-variables.ts
@@ -215,6 +215,15 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
           },
         ]
       : []),
+    ...(modelConfig.provider === "groq"
+      ? [
+          {
+            name: "GROQ_API_KEY",
+            description: "The Groq API key to use.",
+            value: modelConfig.apiKey,
+          },
+        ]
+      : []),
     ...(modelConfig.provider === "gemini"
       ? [
           {
diff --git a/helpers/providers/groq.ts b/helpers/providers/groq.ts
new file mode 100644
index 00000000..e3289ae0
--- /dev/null
+++ b/helpers/providers/groq.ts
@@ -0,0 +1,103 @@
+import ciInfo from "ci-info";
+import prompts from "prompts";
+import { ModelConfigParams } from ".";
+import { questionHandlers, toChoice } from "../../questions";
+
+const MODELS = [
+  "llama3-8b",
+  "llama3-70b",
+  "mixtral-8x7b",
+];
+const DEFAULT_MODEL = MODELS[0];
+
+// Use huggingface embedding models for now as Groq doesn't support embedding models
+enum HuggingFaceEmbeddingModelType {
+  XENOVA_ALL_MINILM_L6_V2 = "all-MiniLM-L6-v2",
+  XENOVA_ALL_MPNET_BASE_V2 = "all-mpnet-base-v2",
+}
+type ModelData = {
+  dimensions: number;
+};
+const EMBEDDING_MODELS: Record<HuggingFaceEmbeddingModelType, ModelData> = {
+  [HuggingFaceEmbeddingModelType.XENOVA_ALL_MINILM_L6_V2]: {
+    dimensions: 384,
+  },
+  [HuggingFaceEmbeddingModelType.XENOVA_ALL_MPNET_BASE_V2]: {
+    dimensions: 768,
+  },
+};
+const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
+const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;
+
+type GroqQuestionsParams = {
+  apiKey?: string;
+  askModels: boolean;
+};
+
+export async function askGroqQuestions({
+  askModels,
+  apiKey,
+}: GroqQuestionsParams): Promise<ModelConfigParams> {
+  const config: ModelConfigParams = {
+    apiKey,
+    model: DEFAULT_MODEL,
+    embeddingModel: DEFAULT_EMBEDDING_MODEL,
+    dimensions: DEFAULT_DIMENSIONS,
+    isConfigured(): boolean {
+      if (config.apiKey) {
+        return true;
+      }
+      if (process.env["GROQ_API_KEY"]) {
+        return true;
+      }
+      return false;
+    },
+  };
+
+  if (!config.apiKey) {
+    const { key } = await prompts(
+      {
+        type: "text",
+        name: "key",
+        message:
+          "Please provide your Groq API key (or leave blank to use GROQ_API_KEY env variable):",
+      },
+      questionHandlers,
+    );
+    config.apiKey = key || process.env.GROQ_API_KEY;
+  }
+
+  // use default model values in CI or if user should not be asked
+  const useDefaults = ciInfo.isCI || !askModels;
+  if (!useDefaults) {
+    const { model } = await prompts(
+      {
+        type: "select",
+        name: "model",
+        message: "Which LLM model would you like to use?",
+        choices: MODELS.map(toChoice),
+        initial: 0,
+      },
+      questionHandlers,
+    );
+    config.model = model;
+
+    const { embeddingModel } = await prompts(
+      {
+        type: "select",
+        name: "embeddingModel",
+        message: "Which embedding model would you like to use?",
+        choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
+        initial: 0,
+      },
+      questionHandlers,
+    );
+    config.embeddingModel = embeddingModel;
+    config.dimensions =
+      EMBEDDING_MODELS[
+        embeddingModel as HuggingFaceEmbeddingModelType
+      ].dimensions;
+  }
+
+  return config;
+}
diff --git a/helpers/providers/index.ts b/helpers/providers/index.ts
index 62202c0f..9e4b6147 100644
--- a/helpers/providers/index.ts
+++ b/helpers/providers/index.ts
@@ -3,6 +3,7 @@ import prompts from "prompts";
 import { questionHandlers } from "../../questions";
 import { ModelConfig, ModelProvider } from "../types";
 import { askAnthropicQuestions } from "./anthropic";
+import { askGroqQuestions } from "./groq";
 import { askGeminiQuestions } from "./gemini";
 import { askOllamaQuestions } from "./ollama";
 import { askOpenAIQuestions } from "./openai";
@@ -32,6 +33,7 @@ export async function askModelConfig({
       title: "OpenAI",
       value: "openai",
     },
+    { title: "Groq", value: "groq" },
     { title: "Ollama", value: "ollama" },
     { title: "Anthropic", value: "anthropic" },
     { title: "Gemini", value: "gemini" },
@@ -48,6 +50,9 @@ export async function askModelConfig({
     case "ollama":
       modelConfig = await askOllamaQuestions({ askModels });
       break;
+    case "groq":
+      modelConfig = await askGroqQuestions({ askModels });
+      break;
     case "anthropic":
      modelConfig = await askAnthropicQuestions({ askModels });
       break;
diff --git a/helpers/types.ts b/helpers/types.ts
index b26e7088..5a7bd484 100644
--- a/helpers/types.ts
+++ b/helpers/types.ts
@@ -1,7 +1,7 @@
 import { PackageManager } from "../helpers/get-pkg-manager";
 import { Tool } from "./tools";
 
-export type ModelProvider = "openai" | "ollama" | "anthropic" | "gemini";
+export type ModelProvider = "openai" | "groq" | "ollama" | "anthropic" | "gemini";
 export type ModelConfig = {
   provider: ModelProvider;
   apiKey?: string;
diff --git a/templates/types/streaming/express/src/controllers/engine/settings.ts b/templates/types/streaming/express/src/controllers/engine/settings.ts
index d2ccf190..3dd52c04 100644
--- a/templates/types/streaming/express/src/controllers/engine/settings.ts
+++ b/templates/types/streaming/express/src/controllers/engine/settings.ts
@@ -5,12 +5,14 @@ import {
   Gemini,
   GeminiEmbedding,
   OpenAI,
+  Groq,
   OpenAIEmbedding,
   Settings,
 } from "llamaindex";
 import { HuggingFaceEmbedding } from "llamaindex/embeddings/HuggingFaceEmbedding";
 import { OllamaEmbedding } from "llamaindex/embeddings/OllamaEmbedding";
 import { ALL_AVAILABLE_ANTHROPIC_MODELS } from "llamaindex/llm/anthropic";
+import { ALL_AVAILABLE_GROQ_MODELS } from "llamaindex/llm/groq";
 import { Ollama } from "llamaindex/llm/ollama";
 
 const CHUNK_SIZE = 512;
@@ -28,6 +30,9 @@ export const initSettings = async () => {
     case "ollama":
       initOllama();
       break;
+    case "groq":
+      initGroq();
+      break;
     case "anthropic":
       initAnthropic();
       break;
@@ -85,6 +90,28 @@ function initAnthropic() {
   });
 }
 
+
+function initGroq() {
+  const embedModelMap: Record<string, string> = {
+    "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
+    "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
+  };
+
+  const modelMap: Record<string, string> = {
+    "llama3-8b": "llama3-8b-8192",
+    "llama3-70b": "llama3-70b-8192",
+    "mixtral-8x7b": "mixtral-8x7b-32768",
+  }
+
+  Settings.llm = new Groq({
+    model: modelMap[process.env.MODEL!],
+  });
+
+  Settings.embedModel = new HuggingFaceEmbedding({
+    modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
+  });
+}
+
 function initGemini() {
   Settings.llm = new Gemini({
     model: process.env.MODEL as GEMINI_MODEL,
diff --git a/templates/types/streaming/fastapi/app/settings.py b/templates/types/streaming/fastapi/app/settings.py
index 87c591af..c37c13b7 100644
--- a/templates/types/streaming/fastapi/app/settings.py
+++ b/templates/types/streaming/fastapi/app/settings.py
@@ -1,5 +1,6 @@
 import os
 from typing import Dict
+
 from llama_index.core.settings import Settings
 
 
@@ -8,6 +9,8 @@ def init_settings():
     match model_provider:
         case "openai":
             init_openai()
+        case "groq":
+            init_groq()
         case "ollama":
             init_ollama()
         case "anthropic":
@@ -23,8 +26,8 @@ def init_settings():
 
 
 def init_ollama():
-    from llama_index.llms.ollama.base import Ollama, DEFAULT_REQUEST_TIMEOUT
     from llama_index.embeddings.ollama import OllamaEmbedding
+    from llama_index.llms.ollama.base import DEFAULT_REQUEST_TIMEOUT, Ollama
 
     base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
     request_timeout = float(
@@ -40,9 +43,9 @@ def init_ollama():
 
 
 def init_openai():
-    from llama_index.llms.openai import OpenAI
-    from llama_index.embeddings.openai import OpenAIEmbedding
     from llama_index.core.constants import DEFAULT_TEMPERATURE
+    from llama_index.embeddings.openai import OpenAIEmbedding
+    from llama_index.llms.openai import OpenAI
 
     max_tokens = os.getenv("LLM_MAX_TOKENS")
     config = {
@@ -61,9 +64,9 @@ def init_openai():
 
 
 def init_azure_openai():
-    from llama_index.llms.azure_openai import AzureOpenAI
-    from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
     from llama_index.core.constants import DEFAULT_TEMPERATURE
+    from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
+    from llama_index.llms.azure_openai import AzureOpenAI
 
     llm_deployment = os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT")
     embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
@@ -88,9 +91,30 @@ def init_azure_openai():
     Settings.embed_model = AzureOpenAIEmbedding(**embedding_config)
 
 
+def init_groq():
+    from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+    from llama_index.llms.groq import Groq
+
+    model_map: Dict[str, str] = {
+        "llama3-8b": "llama3-8b-8192",
+        "llama3-70b": "llama3-70b-8192",
+        "mixtral-8x7b": "mixtral-8x7b-32768",
+    }
+
+    embed_model_map: Dict[str, str] = {
+        "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
+        "all-mpnet-base-v2": "sentence-transformers/all-mpnet-base-v2",
+    }
+
+    Settings.llm = Groq(model=model_map[os.getenv("MODEL")])
+    Settings.embed_model = HuggingFaceEmbedding(
+        model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
+    )
+
+
 def init_anthropic():
-    from llama_index.llms.anthropic import Anthropic
     from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+    from llama_index.llms.anthropic import Anthropic
 
     model_map: Dict[str, str] = {
         "claude-3-opus": "claude-3-opus-20240229",
@@ -112,8 +136,8 @@ def init_anthropic():
 
 
 def init_gemini():
-    from llama_index.llms.gemini import Gemini
     from llama_index.embeddings.gemini import GeminiEmbedding
+    from llama_index.llms.gemini import Gemini
 
     model_name = f"models/{os.getenv('MODEL')}"
     embed_model_name = f"models/{os.getenv('EMBEDDING_MODEL')}"
diff --git a/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts b/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
index f8bfd7be..b2fa4d3a 100644
--- a/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
+++ b/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
@@ -4,6 +4,7 @@ import {
   GEMINI_MODEL,
   Gemini,
   GeminiEmbedding,
+  Groq,
   OpenAI,
   OpenAIEmbedding,
   Settings,
@@ -28,6 +29,9 @@ export const initSettings = async () => {
     case "ollama":
       initOllama();
       break;
+    case "groq":
+      initGroq();
+      break;
     case "anthropic":
       initAnthropic();
       break;
@@ -71,6 +75,27 @@ function initOllama() {
   });
 }
 
+function initGroq() {
+  const embedModelMap: Record<string, string> = {
+    "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
+    "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
+  };
+
+  const modelMap: Record<string, string> = {
+    "llama3-8b": "llama3-8b-8192",
+    "llama3-70b": "llama3-70b-8192",
+    "mixtral-8x7b": "mixtral-8x7b-32768",
+  }
+
+  Settings.llm = new Groq({
+    model: modelMap[process.env.MODEL!],
+  });
+
+  Settings.embedModel = new HuggingFaceEmbedding({
+    modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
+  });
+}
+
 function initAnthropic() {
   const embedModelMap: Record<string, string> = {
     "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
--
GitLab
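For reference, the sketch below pulls together how a backend generated with the Groq provider ends up wired at runtime. It only reuses pieces visible in the patch (the llamaindex Groq and HuggingFaceEmbedding classes, the MODEL, EMBEDDING_MODEL and GROQ_API_KEY variables, and the name maps from initGroq()); the standalone-script framing and the sample values are illustrative assumptions, not part of the change.

// Illustrative sketch only - mirrors the initGroq() helpers added in the templates above.
// Assumes GROQ_API_KEY is set in the environment (the patch adds it to the generated .env);
// the fallback model names below are sample values, not defaults taken from the patch.
import { Groq, Settings } from "llamaindex";
import { HuggingFaceEmbedding } from "llamaindex/embeddings/HuggingFaceEmbedding";

// The CLI prompts collect short model names; the templates map them to full
// Groq model IDs and Xenova embedding checkpoints, exactly as in the patch.
const modelMap: Record<string, string> = {
  "llama3-8b": "llama3-8b-8192",
  "llama3-70b": "llama3-70b-8192",
  "mixtral-8x7b": "mixtral-8x7b-32768",
};
const embedModelMap: Record<string, string> = {
  "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
  "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
};

// e.g. MODEL=llama3-70b and EMBEDDING_MODEL=all-MiniLM-L6-v2 in the generated .env
const model = process.env.MODEL ?? "llama3-70b";
const embeddingModel = process.env.EMBEDDING_MODEL ?? "all-MiniLM-L6-v2";

// Groq serves the LLM; embeddings run locally through HuggingFace, since Groq
// doesn't offer embedding models (see the comment in helpers/providers/groq.ts).
Settings.llm = new Groq({ model: modelMap[model] });
Settings.embedModel = new HuggingFaceEmbedding({
  modelType: embedModelMap[embeddingModel],
});

Keeping the short names in the .env and mapping them to provider-specific IDs inside the templates mirrors how the existing Anthropic provider is wired, so a generated project keeps the same configuration shape regardless of provider.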