From 02ed277dd0ee4239ce044996739fc9093a5caf51 Mon Sep 17 00:00:00 2001
From: Jacopo Zacchigna <59306950+Jac-Zac@users.noreply.github.com>
Date: Wed, 19 Jun 2024 12:43:36 +0200
Subject: [PATCH] Add Groq as a model provider (#131)

---------
Co-authored-by: Marcus Schiesser <marcus.schiesser@googlemail.com>
---
 .changeset/pink-terms-cheer.md                |   5 +
 helpers/env-variables.ts                      |   9 ++
 helpers/providers/groq.ts                     | 103 ++++++++++++++++++
 helpers/providers/index.ts                    |   5 +
 helpers/types.ts                              |   2 +-
 .../src/controllers/engine/settings.ts        |  27 +++++
 .../types/streaming/fastapi/app/settings.py   |  38 +++++--
 .../nextjs/app/api/chat/engine/settings.ts    |  25 +++++
 8 files changed, 206 insertions(+), 8 deletions(-)
 create mode 100644 .changeset/pink-terms-cheer.md
 create mode 100644 helpers/providers/groq.ts

diff --git a/.changeset/pink-terms-cheer.md b/.changeset/pink-terms-cheer.md
new file mode 100644
index 00000000..fead10c7
--- /dev/null
+++ b/.changeset/pink-terms-cheer.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Add Groq as a model provider
diff --git a/helpers/env-variables.ts b/helpers/env-variables.ts
index 2965381f..a1458ac3 100644
--- a/helpers/env-variables.ts
+++ b/helpers/env-variables.ts
@@ -215,6 +215,15 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
           },
         ]
       : []),
+    ...(modelConfig.provider === "groq"
+      ? [
+          {
+            name: "GROQ_API_KEY",
+            description: "The Groq API key to use.",
+            value: modelConfig.apiKey,
+          },
+        ]
+      : []),
     ...(modelConfig.provider === "gemini"
       ? [
           {
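
For reference, the new `GROQ_API_KEY` entry follows the same conditional-spread pattern as the other providers: when the selected provider is `groq`, one env var is appended to the generated `.env`. A minimal, self-contained sketch of that pattern (the `EnvVar` shape and surrounding values are simplified assumptions, not the real helpers):

```typescript
// Simplified sketch of the conditional-spread pattern used in getModelEnvs above.
// EnvVar and the surrounding values are assumptions for illustration only.
type EnvVar = { name: string; description?: string; value?: string };

function groqEnvVars(provider: string, apiKey?: string): EnvVar[] {
  return [
    ...(provider === "groq"
      ? [
          {
            name: "GROQ_API_KEY",
            description: "The Groq API key to use.",
            value: apiKey,
          },
        ]
      : []),
  ];
}

// groqEnvVars("groq", "<key>") -> [{ name: "GROQ_API_KEY", ... }]
// groqEnvVars("openai")        -> []
```
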
diff --git a/helpers/providers/groq.ts b/helpers/providers/groq.ts
new file mode 100644
index 00000000..e3289ae0
--- /dev/null
+++ b/helpers/providers/groq.ts
@@ -0,0 +1,103 @@
+import ciInfo from "ci-info";
+import prompts from "prompts";
+import { ModelConfigParams } from ".";
+import { questionHandlers, toChoice } from "../../questions";
+
+const MODELS = [
+  "llama3-8b",
+  "llama3-70b",
+  "mixtral-8x7b",
+];
+const DEFAULT_MODEL = MODELS[0];
+
+// Use HuggingFace embedding models for now, as Groq does not provide an embedding API
+enum HuggingFaceEmbeddingModelType {
+  XENOVA_ALL_MINILM_L6_V2 = "all-MiniLM-L6-v2",
+  XENOVA_ALL_MPNET_BASE_V2 = "all-mpnet-base-v2",
+}
+type ModelData = {
+  dimensions: number;
+};
+const EMBEDDING_MODELS: Record<HuggingFaceEmbeddingModelType, ModelData> = {
+  [HuggingFaceEmbeddingModelType.XENOVA_ALL_MINILM_L6_V2]: {
+    dimensions: 384,
+  },
+  [HuggingFaceEmbeddingModelType.XENOVA_ALL_MPNET_BASE_V2]: {
+    dimensions: 768,
+  },
+};
+const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
+const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;
+
+type GroqQuestionsParams = {
+  apiKey?: string;
+  askModels: boolean;
+};
+
+export async function askGroqQuestions({
+  askModels,
+  apiKey,
+}: GroqQuestionsParams): Promise<ModelConfigParams> {
+  const config: ModelConfigParams = {
+    apiKey,
+    model: DEFAULT_MODEL,
+    embeddingModel: DEFAULT_EMBEDDING_MODEL,
+    dimensions: DEFAULT_DIMENSIONS,
+    isConfigured(): boolean {
+      if (config.apiKey) {
+        return true;
+      }
+      if (process.env["GROQ_API_KEY"]) {
+        return true;
+      }
+      return false;
+    },
+  };
+
+  if (!config.apiKey) {
+    const { key } = await prompts(
+      {
+        type: "text",
+        name: "key",
+        message:
+          "Please provide your Groq API key (or leave blank to use GROQ_API_KEY env variable):",
+      },
+      questionHandlers,
+    );
+    config.apiKey = key || process.env.GROQ_API_KEY;
+  }
+
+  // use default model values in CI or if user should not be asked
+  const useDefaults = ciInfo.isCI || !askModels;
+  if (!useDefaults) {
+    const { model } = await prompts(
+      {
+        type: "select",
+        name: "model",
+        message: "Which LLM model would you like to use?",
+        choices: MODELS.map(toChoice),
+        initial: 0,
+      },
+      questionHandlers,
+    );
+    config.model = model;
+
+    const { embeddingModel } = await prompts(
+      {
+        type: "select",
+        name: "embeddingModel",
+        message: "Which embedding model would you like to use?",
+        choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
+        initial: 0,
+      },
+      questionHandlers,
+    );
+    config.embeddingModel = embeddingModel;
+    config.dimensions =
+      EMBEDDING_MODELS[
+        embeddingModel as HuggingFaceEmbeddingModelType
+      ].dimensions;
+  }
+
+  return config;
+}
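
A hypothetical, standalone usage sketch of the new prompt flow (the real call site is `askModelConfig` in helpers/providers/index.ts; the values in the comments are the defaults defined above):

```typescript
// Hypothetical usage sketch; illustration only, not part of the patch.
import { askGroqQuestions } from "./groq";

async function demo() {
  // Prompts for an API key unless one is passed in or GROQ_API_KEY is set;
  // with askModels=false (or in CI) the model defaults above are kept.
  const config = await askGroqQuestions({ askModels: false });
  //   config.model          === "llama3-8b"
  //   config.embeddingModel === "all-MiniLM-L6-v2"  (HuggingFace, as Groq has no embedding API)
  //   config.dimensions     === 384
  console.log(config.model, config.embeddingModel, config.dimensions);
}

demo().catch(console.error);
```
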
diff --git a/helpers/providers/index.ts b/helpers/providers/index.ts
index 62202c0f..9e4b6147 100644
--- a/helpers/providers/index.ts
+++ b/helpers/providers/index.ts
@@ -3,6 +3,7 @@ import prompts from "prompts";
 import { questionHandlers } from "../../questions";
 import { ModelConfig, ModelProvider } from "../types";
 import { askAnthropicQuestions } from "./anthropic";
 import { askGeminiQuestions } from "./gemini";
+import { askGroqQuestions } from "./groq";
 import { askOllamaQuestions } from "./ollama";
 import { askOpenAIQuestions } from "./openai";
@@ -32,6 +33,7 @@ export async function askModelConfig({
             title: "OpenAI",
             value: "openai",
           },
+          { title: "Groq", value: "groq" },
           { title: "Ollama", value: "ollama" },
           { title: "Anthropic", value: "anthropic" },
           { title: "Gemini", value: "gemini" },
@@ -48,6 +50,9 @@ export async function askModelConfig({
     case "ollama":
       modelConfig = await askOllamaQuestions({ askModels });
       break;
+    case "groq":
+      modelConfig = await askGroqQuestions({ askModels });
+      break;
     case "anthropic":
       modelConfig = await askAnthropicQuestions({ askModels });
       break;
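
The wiring above mirrors the existing providers: one entry in the provider select prompt, one case in the dispatch switch. A condensed sketch of that dispatch (simplified; not the full `askModelConfig`):

```typescript
// Condensed sketch of the provider dispatch added above; simplified for illustration.
import { askGroqQuestions } from "./groq";

async function configureProvider(provider: string, askModels: boolean) {
  switch (provider) {
    case "groq":
      // Delegates to the new prompt flow in helpers/providers/groq.ts.
      return askGroqQuestions({ askModels });
    // ...openai, ollama, anthropic, and gemini are handled the same way...
    default:
      throw new Error(`Unknown model provider: ${provider}`);
  }
}
```
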
diff --git a/helpers/types.ts b/helpers/types.ts
index b26e7088..5a7bd484 100644
--- a/helpers/types.ts
+++ b/helpers/types.ts
@@ -1,7 +1,7 @@
 import { PackageManager } from "../helpers/get-pkg-manager";
 import { Tool } from "./tools";
 
-export type ModelProvider = "openai" | "ollama" | "anthropic" | "gemini";
+export type ModelProvider = "openai" | "groq" | "ollama" | "anthropic" | "gemini";
 export type ModelConfig = {
   provider: ModelProvider;
   apiKey?: string;
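
With the widened union, a Groq selection can be represented like any other provider. An illustrative value, assuming `ModelConfig` also carries the fields returned by `askGroqQuestions` (the exact shape lives in helpers/types.ts):

```typescript
// Illustrative only; assumes ModelConfig includes the ModelConfigParams fields
// (model, embeddingModel, dimensions, isConfigured) alongside provider and apiKey.
import { ModelConfig } from "./helpers/types"; // import path is illustrative

const exampleGroqConfig: ModelConfig = {
  provider: "groq",
  apiKey: process.env.GROQ_API_KEY,
  model: "llama3-70b",
  embeddingModel: "all-mpnet-base-v2",
  dimensions: 768,
  isConfigured: () => Boolean(process.env.GROQ_API_KEY),
};
```
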
diff --git a/templates/types/streaming/express/src/controllers/engine/settings.ts b/templates/types/streaming/express/src/controllers/engine/settings.ts
index d2ccf190..3dd52c04 100644
--- a/templates/types/streaming/express/src/controllers/engine/settings.ts
+++ b/templates/types/streaming/express/src/controllers/engine/settings.ts
@@ -5,12 +5,14 @@ import {
   Gemini,
   GeminiEmbedding,
+  Groq,
   OpenAI,
   OpenAIEmbedding,
   Settings,
 } from "llamaindex";
 import { HuggingFaceEmbedding } from "llamaindex/embeddings/HuggingFaceEmbedding";
 import { OllamaEmbedding } from "llamaindex/embeddings/OllamaEmbedding";
 import { ALL_AVAILABLE_ANTHROPIC_MODELS } from "llamaindex/llm/anthropic";
+import { ALL_AVAILABLE_GROQ_MODELS } from "llamaindex/llm/groq";
 import { Ollama } from "llamaindex/llm/ollama";
 
 const CHUNK_SIZE = 512;
@@ -28,6 +30,9 @@ export const initSettings = async () => {
     case "ollama":
       initOllama();
       break;
+    case "groq":
+      initGroq();
+      break;
     case "anthropic":
       initAnthropic();
       break;
@@ -85,6 +90,28 @@ function initAnthropic() {
   });
 }
 
+
+function initGroq() {
+  const embedModelMap: Record<string, string> = {
+    "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
+    "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
+  };
+
+  const modelMap: Record<string, string> = {
+    "llama3-8b": "llama3-8b-8192",
+    "llama3-70b": "llama3-70b-8192",
+    "mixtral-8x7b": "mixtral-8x7b-32768",
+  };
+
+  Settings.llm = new Groq({
+    model: modelMap[process.env.MODEL!],
+  });
+
+  Settings.embedModel = new HuggingFaceEmbedding({
+    modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
+  });
+}
+
 function initGemini() {
   Settings.llm = new Gemini({
     model: process.env.MODEL as GEMINI_MODEL,
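
Since `MODEL` and `EMBEDDING_MODEL` store the short names chosen by the CLI, `initGroq` maps them to Groq's full model IDs and to the `Xenova/...` builds of the HuggingFace embedding models. A standalone sketch of that resolution (illustration only; the real mapping is in `initGroq` above):

```typescript
// Standalone sketch of the env-to-model resolution performed by initGroq above.
const modelMap: Record<string, string> = {
  "llama3-8b": "llama3-8b-8192",
  "llama3-70b": "llama3-70b-8192",
  "mixtral-8x7b": "mixtral-8x7b-32768",
};

// e.g. MODEL=llama3-70b in .env resolves to the Groq model id "llama3-70b-8192"
const shortName = process.env.MODEL ?? "llama3-8b"; // fallback is an assumption for this sketch
const groqModelId = modelMap[shortName];
console.log(`Groq model id: ${groqModelId}`);
```
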
diff --git a/templates/types/streaming/fastapi/app/settings.py b/templates/types/streaming/fastapi/app/settings.py
index 87c591af..c37c13b7 100644
--- a/templates/types/streaming/fastapi/app/settings.py
+++ b/templates/types/streaming/fastapi/app/settings.py
@@ -1,5 +1,6 @@
 import os
 from typing import Dict
+
 from llama_index.core.settings import Settings
 
 
@@ -8,6 +9,8 @@ def init_settings():
     match model_provider:
         case "openai":
             init_openai()
+        case "groq":
+            init_groq()
         case "ollama":
             init_ollama()
         case "anthropic":
@@ -23,8 +26,8 @@ def init_settings():
 
 
 def init_ollama():
-    from llama_index.llms.ollama.base import Ollama, DEFAULT_REQUEST_TIMEOUT
     from llama_index.embeddings.ollama import OllamaEmbedding
+    from llama_index.llms.ollama.base import DEFAULT_REQUEST_TIMEOUT, Ollama
 
     base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
     request_timeout = float(
@@ -40,9 +43,9 @@ def init_ollama():
 
 
 def init_openai():
-    from llama_index.llms.openai import OpenAI
-    from llama_index.embeddings.openai import OpenAIEmbedding
     from llama_index.core.constants import DEFAULT_TEMPERATURE
+    from llama_index.embeddings.openai import OpenAIEmbedding
+    from llama_index.llms.openai import OpenAI
 
     max_tokens = os.getenv("LLM_MAX_TOKENS")
     config = {
@@ -61,9 +64,9 @@ def init_openai():
 
 
 def init_azure_openai():
-    from llama_index.llms.azure_openai import AzureOpenAI
-    from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
     from llama_index.core.constants import DEFAULT_TEMPERATURE
+    from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
+    from llama_index.llms.azure_openai import AzureOpenAI
 
     llm_deployment = os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT")
     embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
@@ -88,9 +91,30 @@ def init_azure_openai():
     Settings.embed_model = AzureOpenAIEmbedding(**embedding_config)
 
 
+def init_groq():
+    from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+    from llama_index.llms.groq import Groq
+
+    model_map: Dict[str, str] = {
+        "llama3-8b": "llama3-8b-8192",
+        "llama3-70b": "llama3-70b-8192",
+        "mixtral-8x7b": "mixtral-8x7b-32768",
+    }
+
+    embed_model_map: Dict[str, str] = {
+        "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
+        "all-mpnet-base-v2": "sentence-transformers/all-mpnet-base-v2",
+    }
+
+    Settings.llm = Groq(model=model_map[os.getenv("MODEL")])
+    Settings.embed_model = HuggingFaceEmbedding(
+        model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
+    )
+
+
 def init_anthropic():
-    from llama_index.llms.anthropic import Anthropic
     from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+    from llama_index.llms.anthropic import Anthropic
 
     model_map: Dict[str, str] = {
         "claude-3-opus": "claude-3-opus-20240229",
@@ -112,8 +136,8 @@ def init_anthropic():
 
 
 def init_gemini():
-    from llama_index.llms.gemini import Gemini
     from llama_index.embeddings.gemini import GeminiEmbedding
+    from llama_index.llms.gemini import Gemini
 
     model_name = f"models/{os.getenv('MODEL')}"
     embed_model_name = f"models/{os.getenv('EMBEDDING_MODEL')}"
diff --git a/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts b/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
index f8bfd7be..b2fa4d3a 100644
--- a/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
+++ b/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
@@ -4,6 +4,7 @@ import {
   GEMINI_MODEL,
   Gemini,
   GeminiEmbedding,
+  Groq,
   OpenAI,
   OpenAIEmbedding,
   Settings,
@@ -28,6 +29,9 @@ export const initSettings = async () => {
     case "ollama":
       initOllama();
       break;
+    case "groq":
+      initGroq();
+      break;
     case "anthropic":
       initAnthropic();
       break;
@@ -71,6 +75,27 @@ function initOllama() {
   });
 }
 
+function initGroq() {
+  const embedModelMap: Record<string, string> = {
+    "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
+    "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
+  };
+
+  const modelMap: Record<string, string> = {
+    "llama3-8b": "llama3-8b-8192",
+    "llama3-70b": "llama3-70b-8192",
+    "mixtral-8x7b": "mixtral-8x7b-32768",
+  };
+
+  Settings.llm = new Groq({
+    model: modelMap[process.env.MODEL!],
+  });
+
+  Settings.embedModel = new HuggingFaceEmbedding({
+    modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
+  });
+}
+
 function initAnthropic() {
   const embedModelMap: Record<string, string> = {
     "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
-- 
GitLab