Unverified commit d1026ea7 authored by Thuc Pham, committed by GitHub

feat: support mistral as llm and embedding (#155)

parent 791ca7c9
---
"create-llama": patch
---
support Mistral as llm and embedding
@@ -265,6 +265,15 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
          },
        ]
      : []),
    ...(modelConfig.provider === "mistral"
      ? [
          {
            name: "MISTRAL_API_KEY",
            description: "The Mistral API key to use.",
            value: modelConfig.apiKey,
          },
        ]
      : []),
    ...(modelConfig.provider === "t-systems"
      ? [
          {
...
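For context, the hunk above relies on TypeScript's conditional-spread idiom: spreading an empty array contributes nothing, so each provider's env vars are included only when that provider is selected. A minimal self-contained sketch (the `EnvVar` fields and the `MISTRAL_API_KEY` entry mirror the diff; the standalone function is illustrative):

```ts
type EnvVar = { name: string; description: string; value?: string };

// Each provider contributes its env vars only when selected; spreading
// [] for non-matching providers adds nothing to the result.
const mistralEnvs = (provider: string, apiKey?: string): EnvVar[] => [
  ...(provider === "mistral"
    ? [
        {
          name: "MISTRAL_API_KEY",
          description: "The Mistral API key to use.",
          value: apiKey,
        },
      ]
    : []),
];

console.log(mistralEnvs("mistral", "key").length); // 1
console.log(mistralEnvs("openai").length); // 0 (the empty spread adds nothing)
```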
@@ -6,6 +6,7 @@ import { askAnthropicQuestions } from "./anthropic";
import { askGeminiQuestions } from "./gemini";
import { askGroqQuestions } from "./groq";
import { askLLMHubQuestions } from "./llmhub";
import { askMistralQuestions } from "./mistral";
import { askOllamaQuestions } from "./ollama";
import { askOpenAIQuestions } from "./openai";
@@ -32,6 +33,7 @@ export async function askModelConfig({
    { title: "Ollama", value: "ollama" },
    { title: "Anthropic", value: "anthropic" },
    { title: "Gemini", value: "gemini" },
    { title: "Mistral", value: "mistral" },
  ];
  if (framework === "fastapi") {
@@ -64,6 +66,9 @@ export async function askModelConfig({
    case "gemini":
      modelConfig = await askGeminiQuestions({ askModels });
      break;
    case "mistral":
      modelConfig = await askMistralQuestions({ askModels });
      break;
    case "t-systems":
      modelConfig = await askLLMHubQuestions({ askModels });
      break;
...
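Wiring a new provider into the CLI touches two spots in this file: the choices array behind the provider select, and the switch that dispatches to the provider-specific question flow. A condensed sketch of that pattern (only `askMistralQuestions` and its import path are from the diff; the surrounding wiring is simplified):

```ts
import prompts from "prompts";
import { askMistralQuestions } from "./mistral";

async function pickProviderConfig(askModels: boolean) {
  // Choices list feeding the select, as in the diff above
  const { provider } = await prompts({
    type: "select",
    name: "provider",
    message: "Which model provider would you like to use?",
    choices: [{ title: "Mistral", value: "mistral" }], // plus the other providers
  });
  // Dispatch to the provider-specific question flow
  switch (provider) {
    case "mistral":
      return askMistralQuestions({ askModels });
    default:
      throw new Error(`unhandled provider: ${provider}`);
  }
}
```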
import ciInfo from "ci-info";
import prompts from "prompts";
import { ModelConfigParams } from ".";
import { questionHandlers, toChoice } from "../../questions";

const MODELS = ["mistral-tiny", "mistral-small", "mistral-medium"];

type ModelData = {
  dimensions: number;
};

const EMBEDDING_MODELS: Record<string, ModelData> = {
  "mistral-embed": { dimensions: 1024 },
};

const DEFAULT_MODEL = MODELS[0];
const DEFAULT_EMBEDDING_MODEL = Object.keys(EMBEDDING_MODELS)[0];
const DEFAULT_DIMENSIONS = Object.values(EMBEDDING_MODELS)[0].dimensions;

type MistralQuestionsParams = {
  apiKey?: string;
  askModels: boolean;
};

export async function askMistralQuestions({
  askModels,
  apiKey,
}: MistralQuestionsParams): Promise<ModelConfigParams> {
  const config: ModelConfigParams = {
    apiKey,
    model: DEFAULT_MODEL,
    embeddingModel: DEFAULT_EMBEDDING_MODEL,
    dimensions: DEFAULT_DIMENSIONS,
    isConfigured(): boolean {
      if (config.apiKey) {
        return true;
      }
      if (process.env["MISTRAL_API_KEY"]) {
        return true;
      }
      return false;
    },
  };

  if (!config.apiKey) {
    const { key } = await prompts(
      {
        type: "text",
        name: "key",
        message:
          "Please provide your Mistral API key (or leave blank to use MISTRAL_API_KEY env variable):",
      },
      questionHandlers,
    );
    config.apiKey = key || process.env.MISTRAL_API_KEY;
  }

  // use default model values in CI or if user should not be asked
  const useDefaults = ciInfo.isCI || !askModels;
  if (!useDefaults) {
    const { model } = await prompts(
      {
        type: "select",
        name: "model",
        message: "Which LLM model would you like to use?",
        choices: MODELS.map(toChoice),
        initial: 0,
      },
      questionHandlers,
    );
    config.model = model;

    const { embeddingModel } = await prompts(
      {
        type: "select",
        name: "embeddingModel",
        message: "Which embedding model would you like to use?",
        choices: Object.keys(EMBEDDING_MODELS).map(toChoice),
        initial: 0,
      },
      questionHandlers,
    );
    config.embeddingModel = embeddingModel;
    config.dimensions = EMBEDDING_MODELS[embeddingModel].dimensions;
  }

  return config;
}
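A quick usage sketch for the new module (hypothetical driver code; in the CLI it is only invoked through askModelConfig):

```ts
import { askMistralQuestions } from "./mistral";

async function main() {
  // askModels=false skips the model/embedding selects and keeps the
  // defaults; the API key prompt still runs when no key is passed in.
  const config = await askMistralQuestions({ askModels: false });
  console.log(config.model); // "mistral-tiny"
  console.log(config.embeddingModel); // "mistral-embed"
  console.log(config.dimensions); // 1024
  console.log(config.isConfigured()); // true once a key is set or MISTRAL_API_KEY exists
}

main();
```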
@@ -173,6 +173,16 @@ const getAdditionalDependencies = (
        version: "0.1.6",
      });
      break;
    case "mistral":
      dependencies.push({
        name: "llama-index-llms-mistralai",
        version: "0.1.17",
      });
      dependencies.push({
        name: "llama-index-embeddings-mistralai",
        version: "0.1.4",
      });
      break;
    case "t-systems":
      dependencies.push({
        name: "llama-index-agent-openai",
...
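For the FastAPI template, the generator pins the two LlamaIndex integration packages above into the project's Python dependencies. The shape of those entries (the `Dependency` type name is an assumption; names and versions are taken from the diff):

```ts
type Dependency = { name: string; version: string }; // assumed shape

const mistralPythonDeps: Dependency[] = [
  { name: "llama-index-llms-mistralai", version: "0.1.17" },
  { name: "llama-index-embeddings-mistralai", version: "0.1.4" },
];
```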
@@ -7,6 +7,7 @@ export type ModelProvider =
  | "ollama"
  | "anthropic"
  | "gemini"
  | "mistral"
  | "t-systems";

export type ModelConfig = {
  provider: ModelProvider;
...
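Because ModelProvider is a string-literal union, an unhandled member can be surfaced at compile time with a never-typed exhaustiveness check. A sketch (the union is abbreviated to the members visible in the hunk; `providerLabel` is illustrative):

```ts
type ModelProvider =
  | "ollama"
  | "anthropic"
  | "gemini"
  | "mistral"
  | "t-systems";

function providerLabel(provider: ModelProvider): string {
  switch (provider) {
    case "mistral":
      return "Mistral";
    case "ollama":
      return "Ollama";
    case "anthropic":
      return "Anthropic";
    case "gemini":
      return "Gemini";
    case "t-systems":
      return "T-Systems LLMHub";
    default: {
      // If a member is added to ModelProvider but not handled above,
      // this assignment no longer type-checks.
      const exhaustive: never = provider;
      return exhaustive;
    }
  }
}
```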
@@ -17,6 +17,8 @@ def init_settings():
            init_anthropic()
        case "gemini":
            init_gemini()
        case "mistral":
            init_mistral()
        case "azure-openai":
            init_azure_openai()
        case "t-systems":
@@ -149,3 +151,11 @@ def init_gemini():
    Settings.llm = Gemini(model=model_name)
    Settings.embed_model = GeminiEmbedding(model_name=embed_model_name)


def init_mistral():
    from llama_index.embeddings.mistralai import MistralAIEmbedding
    from llama_index.llms.mistralai import MistralAI

    Settings.llm = MistralAI(model=os.getenv("MODEL"))
    Settings.embed_model = MistralAIEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
import {
  ALL_AVAILABLE_MISTRAL_MODELS,
  Anthropic,
  GEMINI_EMBEDDING_MODEL,
  GEMINI_MODEL,
  Gemini,
  GeminiEmbedding,
  Groq,
  MistralAI,
  MistralAIEmbedding,
  MistralAIEmbeddingModelType,
  OpenAI,
  OpenAIEmbedding,
  Settings,
@@ -38,6 +42,9 @@ export const initSettings = async () => {
    case "gemini":
      initGemini();
      break;
    case "mistral":
      initMistralAI();
      break;
    default:
      initOpenAI();
      break;
@@ -65,7 +72,6 @@ function initOllama() {
  const config = {
    host: process.env.OLLAMA_BASE_URL ?? "http://127.0.0.1:11434",
  };
  Settings.llm = new Ollama({
    model: process.env.MODEL ?? "",
    config,
@@ -76,19 +82,6 @@ function initOllama() {
  });
}

function initAnthropic() {
  const embedModelMap: Record<string, string> = {
    "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
    "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
  };
  Settings.llm = new Anthropic({
    model: process.env.MODEL as keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS,
  });
  Settings.embedModel = new HuggingFaceEmbedding({
    modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
  });
}

function initGroq() {
  const embedModelMap: Record<string, string> = {
    "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
@@ -110,6 +103,19 @@ function initGroq() {
  });
}

function initAnthropic() {
  const embedModelMap: Record<string, string> = {
    "all-MiniLM-L6-v2": "Xenova/all-MiniLM-L6-v2",
    "all-mpnet-base-v2": "Xenova/all-mpnet-base-v2",
  };
  Settings.llm = new Anthropic({
    model: process.env.MODEL as keyof typeof ALL_AVAILABLE_ANTHROPIC_MODELS,
  });
  Settings.embedModel = new HuggingFaceEmbedding({
    modelType: embedModelMap[process.env.EMBEDDING_MODEL!],
  });
}

function initGemini() {
  Settings.llm = new Gemini({
    model: process.env.MODEL as GEMINI_MODEL,
@@ -118,3 +124,12 @@ function initGemini() {
    model: process.env.EMBEDDING_MODEL as GEMINI_EMBEDDING_MODEL,
  });
}

function initMistralAI() {
  Settings.llm = new MistralAI({
    model: process.env.MODEL as keyof typeof ALL_AVAILABLE_MISTRAL_MODELS,
  });
  Settings.embedModel = new MistralAIEmbedding({
    model: process.env.EMBEDDING_MODEL as MistralAIEmbeddingModelType,
  });
}
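To exercise the new branch end to end, a hypothetical smoke test for the generated TypeScript backend. MODEL, EMBEDDING_MODEL, and MISTRAL_API_KEY appear in the diff; MODEL_PROVIDER is an assumed name for the variable the surrounding switch reads:

```ts
process.env.MODEL_PROVIDER = "mistral"; // assumed env var name
process.env.MODEL = "mistral-small";
process.env.EMBEDDING_MODEL = "mistral-embed";
process.env.MISTRAL_API_KEY = "<your key>"; // read by the MistralAI client

await initSettings(); // the exported function from this file
// Settings.llm should now be a MistralAI instance and Settings.embedModel
// a MistralAIEmbedding, both configured from the env above.
```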
import {
  ALL_AVAILABLE_MISTRAL_MODELS,
  Anthropic,
  GEMINI_EMBEDDING_MODEL,
  GEMINI_MODEL,
  Gemini,
  GeminiEmbedding,
  Groq,
  MistralAI,
  MistralAIEmbedding,
  MistralAIEmbeddingModelType,
  OpenAI,
  OpenAIEmbedding,
  Settings,
@@ -38,6 +42,9 @@ export const initSettings = async () => {
    case "gemini":
      initGemini();
      break;
    case "mistral":
      initMistralAI();
      break;
    default:
      initOpenAI();
      break;
@@ -117,3 +124,12 @@ function initGemini() {
    model: process.env.EMBEDDING_MODEL as GEMINI_EMBEDDING_MODEL,
  });
}

function initMistralAI() {
  Settings.llm = new MistralAI({
    model: process.env.MODEL as keyof typeof ALL_AVAILABLE_MISTRAL_MODELS,
  });
  Settings.embedModel = new MistralAIEmbedding({
    model: process.env.EMBEDDING_MODEL as MistralAIEmbeddingModelType,
  });
}