From d1026ea7842ff419fbc6c581ad6b3dc0e31aa8f8 Mon Sep 17 00:00:00 2001
From: Thuc Pham <51660321+thucpn@users.noreply.github.com>
Date: Fri, 5 Jul 2024 16:58:10 +0700
Subject: [PATCH] feat: support mistral as llm and embedding (#155)

---
 .changeset/chilly-worms-speak.md            |  5 ++
 helpers/env-variables.ts                    |  9 ++
 helpers/providers/index.ts                  |  5 ++
 helpers/providers/mistral.ts                | 86 +++++++++++++++++++
 helpers/python.ts                           | 10 +++
 helpers/types.ts                            |  1 +
 .../components/settings/python/settings.py  | 10 +++
 .../src/controllers/engine/settings.ts      | 43 +++++++---
 .../nextjs/app/api/chat/engine/settings.ts  | 16 ++++
 9 files changed, 171 insertions(+), 14 deletions(-)
 create mode 100644 .changeset/chilly-worms-speak.md
 create mode 100644 helpers/providers/mistral.ts

diff --git a/.changeset/chilly-worms-speak.md b/.changeset/chilly-worms-speak.md
new file mode 100644
index 00000000..45dd1eb5
--- /dev/null
+++ b/.changeset/chilly-worms-speak.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+support Mistral as llm and embedding
diff --git a/helpers/env-variables.ts b/helpers/env-variables.ts
index b6f7af86..7b10e23d 100644
--- a/helpers/env-variables.ts
+++ b/helpers/env-variables.ts
@@ -265,6 +265,15 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
           },
         ]
       : []),
+    ...(modelConfig.provider === "mistral"
+      ? [
+          {
+            name: "MISTRAL_API_KEY",
+            description: "The Mistral API key to use.",
+            value: modelConfig.apiKey,
+          },
+        ]
+      : []),
     ...(modelConfig.provider === "t-systems"
       ? [
           {
diff --git a/helpers/providers/index.ts b/helpers/providers/index.ts
index 22e4de3e..b660248c 100644
--- a/helpers/providers/index.ts
+++ b/helpers/providers/index.ts
@@ -6,6 +6,7 @@ import { askAnthropicQuestions } from "./anthropic";
 import { askGeminiQuestions } from "./gemini";
 import { askGroqQuestions } from "./groq";
 import { askLLMHubQuestions } from "./llmhub";
+import { askMistralQuestions } from "./mistral";
 import { askOllamaQuestions } from "./ollama";
 import { askOpenAIQuestions } from "./openai";
 
@@ -32,6 +33,7 @@ export async function askModelConfig({
     { title: "Ollama", value: "ollama" },
     { title: "Anthropic", value: "anthropic" },
     { title: "Gemini", value: "gemini" },
+    { title: "Mistral", value: "mistral" },
   ];
 
   if (framework === "fastapi") {
@@ -64,6 +66,9 @@ export async function askModelConfig({
     case "gemini":
       modelConfig = await askGeminiQuestions({ askModels });
       break;
+    case "mistral":
+      modelConfig = await askMistralQuestions({ askModels });
+      break;
     case "t-systems":
       modelConfig = await askLLMHubQuestions({ askModels });
       break;
diff --git a/helpers/providers/mistral.ts b/helpers/providers/mistral.ts
new file mode 100644
index 00000000..b892b748
--- /dev/null
+++ b/helpers/providers/mistral.ts
@@ -0,0 +1,86 @@
+import ciInfo from "ci-info";
+import prompts from "prompts";
+import { ModelConfigParams } from ".";
+import { questionHandlers, toChoice } from "../../questions";
+
+const MODELS = ["mistral-tiny", "mistral-small", "mistral-medium"];
+type ModelData = {
+  dimensions: number;
+};
+const EMBEDDING_MODELS: Record<string, ModelData> = {
+  "mistral-embed": { dimensions: 1024 },
+};
+
+const DEFAULT_MODEL = MODELS[0];
+const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
+const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;
+
+type MistralQuestionsParams = {
+  apiKey?: string;
+  askModels: boolean;
+};
+
+export async function askMistralQuestions({
+  askModels,
+  apiKey,
+}: MistralQuestionsParams): Promise<ModelConfigParams> {
+  const config: ModelConfigParams = {
+    apiKey,
+    model: DEFAULT_MODEL,
+    embeddingModel: DEFAULT_EMBEDDING_MODEL,
+    dimensions: DEFAULT_DIMENSIONS,
+    isConfigured(): boolean {
+      if (config.apiKey) {
+        return true;
+      }
+      if (process.env["MISTRAL_API_KEY"]) {
+        return true;
+      }
+      return false;
+    },
+  };
+
+  if (!config.apiKey) {
+    const { key } = await prompts(
+      {
+        type: "text",
+        name: "key",
+        message:
+          "Please provide your Mistral API key (or leave blank to use MISTRAL_API_KEY env variable):",
+      },
+      questionHandlers,
+    );
+    config.apiKey = key || process.env.MISTRAL_API_KEY;
+  }
+
+  // use default model values in CI or if user should not be asked
+  const useDefaults = ciInfo.isCI || !askModels;
+  if (!useDefaults) {
+    const { model } = await prompts(
+      {
+        type: "select",
+        name: "model",
+        message: "Which LLM model would you like to use?",
+        choices: MODELS.map(toChoice),
+        initial: 0,
+      },
+      questionHandlers,
+    );
+    config.model = model;
+
+    const { embeddingModel } = await prompts(
+      {
+        type: "select",
+        name: "embeddingModel",
+        message: "Which embedding model would you like to use?",
+        choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
+        initial: 0,
+      },
+      questionHandlers,
+    );
+    config.embeddingModel = embeddingModel;
+    config.dimensions = EMBEDDING_MODELS[embeddingModel].dimensions;
+  }
+
+  return config;
+}
diff --git a/helpers/python.ts b/helpers/python.ts
index 04f6a11a..186017be 100644
--- a/helpers/python.ts
+++ b/helpers/python.ts
@@ -173,6 +173,16 @@ const getAdditionalDependencies = (
         version: "0.1.6",
       });
       break;
+    case "mistral":
+      dependencies.push({
+        name: "llama-index-llms-mistralai",
+        version: "0.1.17",
+      });
+      dependencies.push({
+        name: "llama-index-embeddings-mistralai",
+        version: "0.1.4",
+      });
+      break;
     case "t-systems":
       dependencies.push({
         name: "llama-index-agent-openai",
diff --git a/helpers/types.ts b/helpers/types.ts
index 12c62019..ee0ee853 100644
--- a/helpers/types.ts
+++ b/helpers/types.ts
@@ -7,6 +7,7 @@ export type ModelProvider =
   | "ollama"
   | "anthropic"
   | "gemini"
+  | "mistral"
   | "t-systems";
 export type ModelConfig = {
   provider: ModelProvider;
diff --git a/templates/components/settings/python/settings.py b/templates/components/settings/python/settings.py
index 2f7e4b3a..e0c974cc 100644
--- a/templates/components/settings/python/settings.py
+++ b/templates/components/settings/python/settings.py
@@ -17,6 +17,8 @@ def init_settings():
             init_anthropic()
         case "gemini":
             init_gemini()
+        case "mistral":
+            init_mistral()
         case "azure-openai":
             init_azure_openai()
         case "t-systems":
@@ -149,3 +151,11 @@ def init_gemini():
 
     Settings.llm = Gemini(model=model_name)
     Settings.embed_model = GeminiEmbedding(model_name=embed_model_name)
+
+
+def init_mistral():
+    from llama_index.embeddings.mistralai import MistralAIEmbedding
+    from llama_index.llms.mistralai import MistralAI
+
+    Settings.llm = MistralAI(model=os.getenv("MODEL"))
+    Settings.embed_model = MistralAIEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
diff --git a/templates/types/streaming/express/src/controllers/engine/settings.ts b/templates/types/streaming/express/src/controllers/engine/settings.ts
index afc4f6b8..98160a56 100644
--- a/templates/types/streaming/express/src/controllers/engine/settings.ts
+++ b/templates/types/streaming/express/src/controllers/engine/settings.ts
@@ -1,10 +1,14 @@
 import {
+  ALL_AVAILABLE_MISTRAL_MODELS,
   Anthropic,
   GEMINI_EMBEDDING_MODEL,
   GEMINI_MODEL,
   Gemini,
   GeminiEmbedding,
   Groq,
+  MistralAI,
+  MistralAIEmbedding,
+  MistralAIEmbeddingModelType,
   OpenAI,
   OpenAIEmbedding,
   Settings,
@@ -38,6 +42,9 @@ export const initSettings = async () => {
     case "gemini":
       initGemini();
       break;
+    case "mistral":
+      initMistralAI();
+      break;
     default:
       initOpenAI();
       break;
@@ -65,7 +72,6 @@ function initOllama() {
   const config = {
     host: process.env.OLLAMA_BASE_URL ?? "http://127.0.0.1:11434",
   };
-
   Settings.llm = new Ollama({
     model: process.env.MODEL ?? "",
     config,
@@ -76,19 +82,6 @@ function initOllama() {
   });
 }
 
-function initAnthropic() {
-  const embedModelMap: Record<string, string> = {
-    "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
-    "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
-  };
-  Settings.llm = new Anthropic({
-    model: process.env.MODEL as keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS,
-  });
-  Settings.embedModel = new HuggingFaceEmbedding({
-    modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
-  });
-}
-
 function initGroq() {
   const embedModelMap: Record<string, string> = {
     "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
     "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
@@ -110,6 +103,19 @@ function initGroq() {
   });
 }
 
+function initAnthropic() {
+  const embedModelMap: Record<string, string> = {
+    "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
+    "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
+  };
+  Settings.llm = new Anthropic({
+    model: process.env.MODEL as keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS,
+  });
+  Settings.embedModel = new HuggingFaceEmbedding({
+    modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
+  });
+}
+
 function initGemini() {
   Settings.llm = new Gemini({
     model: process.env.MODEL as GEMINI_MODEL,
@@ -118,3 +124,12 @@ function initGemini() {
     model: process.env.EMBEDDING_MODEL as GEMINI_EMBEDDING_MODEL,
   });
 }
+
+function initMistralAI() {
+  Settings.llm = new MistralAI({
+    model: process.env.MODEL as keyof typeof ALL_AVAILABLE_MISTRAL_MODELS,
+  });
+  Settings.embedModel = new MistralAIEmbedding({
+    model: process.env.EMBEDDING_MODEL as MistralAIEmbeddingModelType,
+  });
+}
diff --git a/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts b/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
index 38032389..98160a56 100644
--- a/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
+++ b/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
@@ -1,10 +1,14 @@
 import {
+  ALL_AVAILABLE_MISTRAL_MODELS,
   Anthropic,
   GEMINI_EMBEDDING_MODEL,
   GEMINI_MODEL,
   Gemini,
   GeminiEmbedding,
   Groq,
+  MistralAI,
+  MistralAIEmbedding,
+  MistralAIEmbeddingModelType,
   OpenAI,
   OpenAIEmbedding,
   Settings,
@@ -38,6 +42,9 @@ export const initSettings = async () => {
     case "gemini":
       initGemini();
       break;
+    case "mistral":
+      initMistralAI();
+      break;
     default:
       initOpenAI();
       break;
@@ -117,3 +124,12 @@ function initGemini() {
     model: process.env.EMBEDDING_MODEL as GEMINI_EMBEDDING_MODEL,
   });
 }
+
+function initMistralAI() {
+  Settings.llm = new MistralAI({
+    model: process.env.MODEL as keyof typeof ALL_AVAILABLE_MISTRAL_MODELS,
+  });
+  Settings.embedModel = new MistralAIEmbedding({
+    model: process.env.EMBEDDING_MODEL as MistralAIEmbeddingModelType,
+  });
+}
-- 
GitLab
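
Usage sketch (after the signature, so git am ignores it): with this patch applied, a
generated project drives the new Mistral code paths entirely through its .env file.
A minimal example with illustrative values only -- MODEL_PROVIDER is assumed to be
the variable the settings initializers above switch on, but its name is not shown in
this diff, so treat it as an assumption:

    # .env (illustrative values only)
    MODEL_PROVIDER=mistral          # assumed switch variable routing to init_mistral() / initMistralAI()
    MISTRAL_API_KEY=your-api-key    # entry generated by env-variables.ts; checked by isConfigured()
    MODEL=mistral-tiny              # DEFAULT_MODEL; one of mistral-tiny, mistral-small, mistral-medium
    EMBEDDING_MODEL=mistral-embed   # only key in EMBEDDING_MODELS (1024 dimensions)

With these set, init_mistral() (FastAPI) and initMistralAI() (Express/Next.js) build
the LLM and embedding clients from MODEL and EMBEDDING_MODEL, and the API key is
expected to be picked up from MISTRAL_API_KEY by the Mistral clients.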