From d1026ea7842ff419fbc6c581ad6b3dc0e31aa8f8 Mon Sep 17 00:00:00 2001
From: Thuc Pham <51660321+thucpn@users.noreply.github.com>
Date: Fri, 5 Jul 2024 16:58:10 +0700
Subject: [PATCH] feat: support Mistral as LLM and embedding model (#155)
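
Mistral joins the provider list for both the LLM and the embedding model:
the CLI gains a "Mistral" choice, helpers/providers/mistral.ts prompts for
the API key, the LLM (mistral-tiny, mistral-small, mistral-medium) and the
embedding model (mistral-embed, 1024 dimensions), the Python template pulls
in llama-index-llms-mistralai and llama-index-embeddings-mistralai, and the
Python, Express and Next.js settings initialize MistralAI /
MistralAIEmbedding from the environment.

A backend generated with this provider is expected to read configuration
along these lines (a sketch with illustrative values; the actual .env is
assembled by helpers/env-variables.ts):

    MODEL=mistral-tiny
    EMBEDDING_MODEL=mistral-embed
    MISTRAL_API_KEY=<your Mistral API key>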

---
 .changeset/chilly-worms-speak.md              |  5 ++
 helpers/env-variables.ts                      |  9 ++
 helpers/providers/index.ts                    |  5 ++
 helpers/providers/mistral.ts                  | 86 +++++++++++++++++++
 helpers/python.ts                             | 10 +++
 helpers/types.ts                              |  1 +
 .../components/settings/python/settings.py    | 10 +++
 .../src/controllers/engine/settings.ts        | 43 +++++++---
 .../nextjs/app/api/chat/engine/settings.ts    | 16 ++++
 9 files changed, 171 insertions(+), 14 deletions(-)
 create mode 100644 .changeset/chilly-worms-speak.md
 create mode 100644 helpers/providers/mistral.ts

diff --git a/.changeset/chilly-worms-speak.md b/.changeset/chilly-worms-speak.md
new file mode 100644
index 00000000..45dd1eb5
--- /dev/null
+++ b/.changeset/chilly-worms-speak.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Support Mistral as LLM and embedding model
diff --git a/helpers/env-variables.ts b/helpers/env-variables.ts
index b6f7af86..7b10e23d 100644
--- a/helpers/env-variables.ts
+++ b/helpers/env-variables.ts
@@ -265,6 +265,15 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
           },
         ]
       : []),
+    ...(modelConfig.provider === "mistral"
+      ? [
+          {
+            name: "MISTRAL_API_KEY",
+            description: "The Mistral API key to use.",
+            value: modelConfig.apiKey,
+          },
+        ]
+      : []),
     ...(modelConfig.provider === "t-systems"
       ? [
           {
diff --git a/helpers/providers/index.ts b/helpers/providers/index.ts
index 22e4de3e..b660248c 100644
--- a/helpers/providers/index.ts
+++ b/helpers/providers/index.ts
@@ -6,6 +6,7 @@ import { askAnthropicQuestions } from "./anthropic";
 import { askGeminiQuestions } from "./gemini";
 import { askGroqQuestions } from "./groq";
 import { askLLMHubQuestions } from "./llmhub";
+import { askMistralQuestions } from "./mistral";
 import { askOllamaQuestions } from "./ollama";
 import { askOpenAIQuestions } from "./openai";
 
@@ -32,6 +33,7 @@ export async function askModelConfig({
       { title: "Ollama", value: "ollama" },
       { title: "Anthropic", value: "anthropic" },
       { title: "Gemini", value: "gemini" },
+      { title: "Mistral", value: "mistral" },
     ];
 
     if (framework === "fastapi") {
@@ -64,6 +66,9 @@ export async function askModelConfig({
     case "gemini":
       modelConfig = await askGeminiQuestions({ askModels });
       break;
+    case "mistral":
+      modelConfig = await askMistralQuestions({ askModels });
+      break;
     case "t-systems":
       modelConfig = await askLLMHubQuestions({ askModels });
       break;
diff --git a/helpers/providers/mistral.ts b/helpers/providers/mistral.ts
new file mode 100644
index 00000000..b892b748
--- /dev/null
+++ b/helpers/providers/mistral.ts
@@ -0,0 +1,86 @@
+import ciInfo from "ci-info";
+import prompts from "prompts";
+import { ModelConfigParams } from ".";
+import { questionHandlers, toChoice } from "../../questions";
+
+const MODELS = ["mistral-tiny", "mistral-small", "mistral-medium"];
+type ModelData = {
+  dimensions: number;
+};
+const EMBEDDING_MODELS: Record<string, ModelData> = {
+  "mistral-embed": { dimensions: 1024 },
+};
+
+const DEFAULT_MODEL = MODELS[0];
+const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
+const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;
+
+type MistralQuestionsParams = {
+  apiKey?: string;
+  askModels: boolean;
+};
+
+export async function askMistralQuestions({
+  askModels,
+  apiKey,
+}: MistralQuestionsParams): Promise<ModelConfigParams> {
+  const config: ModelConfigParams = {
+    apiKey,
+    model: DEFAULT_MODEL,
+    embeddingModel: DEFAULT_EMBEDDING_MODEL,
+    dimensions: DEFAULT_DIMENSIONS,
+    isConfigured(): boolean {
+      if (config.apiKey) {
+        return true;
+      }
+      if (process.env["MISTRAL_API_KEY"]) {
+        return true;
+      }
+      return false;
+    },
+  };
+
+  if (!config.apiKey) {
+    const { key } = await prompts(
+      {
+        type: "text",
+        name: "key",
+        message:
+          "Please provide your Mistral API key (or leave blank to use MISTRAL_API_KEY env variable):",
+      },
+      questionHandlers,
+    );
+    config.apiKey = key || process.env.MISTRAL_API_KEY;
+  }
+
+  // use default model values in CI or if user should not be asked
+  const useDefaults = ciInfo.isCI || !askModels;
+  if (!useDefaults) {
+    const { model } = await prompts(
+      {
+        type: "select",
+        name: "model",
+        message: "Which LLM model would you like to use?",
+        choices: MODELS.map(toChoice),
+        initial: 0,
+      },
+      questionHandlers,
+    );
+    config.model = model;
+
+    const { embeddingModel } = await prompts(
+      {
+        type: "select",
+        name: "embeddingModel",
+        message: "Which embedding model would you like to use?",
+        choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
+        initial: 0,
+      },
+      questionHandlers,
+    );
+    config.embeddingModel = embeddingModel;
+    config.dimensions = EMBEDDING_MODELS[embeddingModel].dimensions;
+  }
+
+  return config;
+}
diff --git a/helpers/python.ts b/helpers/python.ts
index 04f6a11a..186017be 100644
--- a/helpers/python.ts
+++ b/helpers/python.ts
@@ -173,6 +173,16 @@ const getAdditionalDependencies = (
         version: "0.1.6",
       });
       break;
+    case "mistral":
+      dependencies.push({
+        name: "llama-index-llms-mistralai",
+        version: "0.1.17",
+      });
+      dependencies.push({
+        name: "llama-index-embeddings-mistralai",
+        version: "0.1.4",
+      });
+      break;
     case "t-systems":
       dependencies.push({
         name: "llama-index-agent-openai",
diff --git a/helpers/types.ts b/helpers/types.ts
index 12c62019..ee0ee853 100644
--- a/helpers/types.ts
+++ b/helpers/types.ts
@@ -7,6 +7,7 @@ export type ModelProvider =
   | "ollama"
   | "anthropic"
   | "gemini"
+  | "mistral"
   | "t-systems";
 export type ModelConfig = {
   provider: ModelProvider;
diff --git a/templates/components/settings/python/settings.py b/templates/components/settings/python/settings.py
index 2f7e4b3a..e0c974cc 100644
--- a/templates/components/settings/python/settings.py
+++ b/templates/components/settings/python/settings.py
@@ -17,6 +17,8 @@ def init_settings():
             init_anthropic()
         case "gemini":
             init_gemini()
+        case "mistral":
+            init_mistral()
         case "azure-openai":
             init_azure_openai()
         case "t-systems":
@@ -149,3 +151,11 @@ def init_gemini():
 
     Settings.llm = Gemini(model=model_name)
     Settings.embed_model = GeminiEmbedding(model_name=embed_model_name)
+
+
+def init_mistral():
+    from llama_index.embeddings.mistralai import MistralAIEmbedding
+    from llama_index.llms.mistralai import MistralAI
+
+    Settings.llm = MistralAI(model=os.getenv("MODEL"))
+    Settings.embed_model = MistralAIEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
diff --git a/templates/types/streaming/express/src/controllers/engine/settings.ts b/templates/types/streaming/express/src/controllers/engine/settings.ts
index afc4f6b8..98160a56 100644
--- a/templates/types/streaming/express/src/controllers/engine/settings.ts
+++ b/templates/types/streaming/express/src/controllers/engine/settings.ts
@@ -1,10 +1,14 @@
 import {
+  ALL_AVAILABLE_MISTRAL_MODELS,
   Anthropic,
   GEMINI_EMBEDDING_MODEL,
   GEMINI_MODEL,
   Gemini,
   GeminiEmbedding,
   Groq,
+  MistralAI,
+  MistralAIEmbedding,
+  MistralAIEmbeddingModelType,
   OpenAI,
   OpenAIEmbedding,
   Settings,
@@ -38,6 +42,9 @@ export const initSettings = async () => {
     case "gemini":
       initGemini();
       break;
+    case "mistral":
+      initMistralAI();
+      break;
     default:
       initOpenAI();
       break;
@@ -65,7 +72,6 @@ function initOllama() {
   const config = {
     host: process.env.OLLAMA_BASE_URL ?? "http://127.0.0.1:11434",
   };
-
   Settings.llm = new Ollama({
     model: process.env.MODEL ?? "",
     config,
@@ -76,19 +82,6 @@ function initOllama() {
   });
 }
 
-function initAnthropic() {
-  const embedModelMap: Record<string, string> = {
-    "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
-    "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
-  };
-  Settings.llm = new Anthropic({
-    model: process.env.MODEL as keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS,
-  });
-  Settings.embedModel = new HuggingFaceEmbedding({
-    modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
-  });
-}
-
 function initGroq() {
   const embedModelMap: Record<string, string> = {
     "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
@@ -110,6 +103,19 @@ function initGroq() {
   });
 }
 
+function initAnthropic() {
+  const embedModelMap: Record<string, string> = {
+    "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
+    "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
+  };
+  Settings.llm = new Anthropic({
+    model: process.env.MODEL as keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS,
+  });
+  Settings.embedModel = new HuggingFaceEmbedding({
+    modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
+  });
+}
+
 function initGemini() {
   Settings.llm = new Gemini({
     model: process.env.MODEL as GEMINI_MODEL,
@@ -118,3 +124,12 @@ function initGemini() {
     model: process.env.EMBEDDING_MODEL as GEMINI_EMBEDDING_MODEL,
   });
 }
+
+function initMistralAI() {
+  Settings.llm = new MistralAI({
+    model: process.env.MODEL as keyof typeof ALL_AVAILABLE_MISTRAL_MODELS,
+  });
+  Settings.embedModel = new MistralAIEmbedding({
+    model: process.env.EMBEDDING_MODEL as MistralAIEmbeddingModelType,
+  });
+}
diff --git a/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts b/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
index 38032389..98160a56 100644
--- a/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
+++ b/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
@@ -1,10 +1,14 @@
 import {
+  ALL_AVAILABLE_MISTRAL_MODELS,
   Anthropic,
   GEMINI_EMBEDDING_MODEL,
   GEMINI_MODEL,
   Gemini,
   GeminiEmbedding,
   Groq,
+  MistralAI,
+  MistralAIEmbedding,
+  MistralAIEmbeddingModelType,
   OpenAI,
   OpenAIEmbedding,
   Settings,
@@ -38,6 +42,9 @@ export const initSettings = async () => {
     case "gemini":
       initGemini();
       break;
+    case "mistral":
+      initMistralAI();
+      break;
     default:
       initOpenAI();
       break;
@@ -117,3 +124,12 @@ function initGemini() {
     model: process.env.EMBEDDING_MODEL as GEMINI_EMBEDDING_MODEL,
   });
 }
+
+function initMistralAI() {
+  Settings.llm = new MistralAI({
+    model: process.env.MODEL as keyof typeof ALL_AVAILABLE_MISTRAL_MODELS,
+  });
+  Settings.embedModel = new MistralAIEmbedding({
+    model: process.env.EMBEDDING_MODEL as MistralAIEmbeddingModelType,
+  });
+}
-- 
GitLab