Commit fd9fb42a authored by Thuc Pham, committed by GitHub

feat: add azure model provider (#184)

parent 92798f73
---
"create-llama": patch
---
Add azure model provider
@@ -274,6 +274,33 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
       },
     ]
   : []),
+  ...(modelConfig.provider === "azure-openai"
+    ? [
+        {
+          name: "AZURE_OPENAI_KEY",
+          description: "The Azure OpenAI key to use.",
+          value: modelConfig.apiKey,
+        },
+        {
+          name: "AZURE_OPENAI_ENDPOINT",
+          description: "The Azure OpenAI endpoint to use.",
+        },
+        {
+          name: "AZURE_OPENAI_API_VERSION",
+          description: "The Azure OpenAI API version to use.",
+        },
+        {
+          name: "AZURE_OPENAI_LLM_DEPLOYMENT",
+          description:
+            "The Azure OpenAI deployment to use for the LLM.",
+        },
+        {
+          name: "AZURE_OPENAI_EMBEDDING_DEPLOYMENT",
+          description:
+            "The Azure OpenAI deployment to use for embeddings.",
+        },
+      ]
+    : []),
   ...(modelConfig.provider === "t-systems"
     ? [
         {
...
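For reference, a filled-in .env built from these variables might look like the following sketch; the endpoint, API version, and deployment names are placeholders, not values from this commit:

AZURE_OPENAI_KEY=<your-azure-openai-key>
AZURE_OPENAI_ENDPOINT=https://<your-resource>.openai.azure.com
AZURE_OPENAI_API_VERSION=2024-02-01
AZURE_OPENAI_LLM_DEPLOYMENT=<your-llm-deployment-name>
AZURE_OPENAI_EMBEDDING_DEPLOYMENT=<your-embedding-deployment-name>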
import ciInfo from "ci-info";
import prompts from "prompts";
import { ModelConfigParams, ModelConfigQuestionsParams } from ".";
import { questionHandlers } from "../../questions";
const ALL_AZURE_OPENAI_CHAT_MODELS: Record<string, { openAIModel: string }> = {
"gpt-35-turbo": { openAIModel: "gpt-3.5-turbo" },
"gpt-35-turbo-16k": {
openAIModel: "gpt-3.5-turbo-16k",
},
"gpt-4o": { openAIModel: "gpt-4o" },
"gpt-4": { openAIModel: "gpt-4" },
"gpt-4-32k": { openAIModel: "gpt-4-32k" },
"gpt-4-turbo": {
openAIModel: "gpt-4-turbo",
},
"gpt-4-turbo-2024-04-09": {
openAIModel: "gpt-4-turbo",
},
"gpt-4-vision-preview": {
openAIModel: "gpt-4-vision-preview",
},
"gpt-4-1106-preview": {
openAIModel: "gpt-4-1106-preview",
},
"gpt-4o-2024-05-13": {
openAIModel: "gpt-4o-2024-05-13",
},
};
const ALL_AZURE_OPENAI_EMBEDDING_MODELS: Record<
string,
{
dimensions: number;
openAIModel: string;
}
> = {
"text-embedding-ada-002": {
dimensions: 1536,
openAIModel: "text-embedding-ada-002",
},
"text-embedding-3-small": {
dimensions: 1536,
openAIModel: "text-embedding-3-small",
},
"text-embedding-3-large": {
dimensions: 3072,
openAIModel: "text-embedding-3-large",
},
};
const DEFAULT_MODEL = "gpt-4o";
const DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large";
export async function askAzureQuestions({
openAiKey,
askModels,
}: ModelConfigQuestionsParams): Promise<ModelConfigParams> {
const config: ModelConfigParams = {
apiKey: openAiKey,
model: DEFAULT_MODEL,
embeddingModel: DEFAULT_EMBEDDING_MODEL,
dimensions: getDimensions(DEFAULT_EMBEDDING_MODEL),
isConfigured(): boolean {
      // the Azure provider can never be fully configured here: the endpoint and deployment names must be set via env variables
return false;
},
};
if (!config.apiKey) {
const { key } = await prompts(
{
type: "text",
name: "key",
message: askModels
? "Please provide your Azure OpenAI API key (or leave blank to use AZURE_OPENAI_KEY env variable):"
: "Please provide your Azure OpenAI API key (leave blank to skip):",
validate: (value: string) => {
if (askModels && !value) {
if (process.env.AZURE_OPENAI_KEY) {
return true;
}
return "AZURE_OPENAI_KEY env variable is not set - key is required";
}
return true;
},
},
questionHandlers,
);
config.apiKey = key || process.env.AZURE_OPENAI_KEY;
}
// use default model values in CI or if user should not be asked
const useDefaults = ciInfo.isCI || !askModels;
if (!useDefaults) {
const { model } = await prompts(
{
type: "select",
name: "model",
message: "Which LLM model would you like to use?",
choices: getAvailableModelChoices(),
initial: 0,
},
questionHandlers,
);
config.model = model;
const { embeddingModel } = await prompts(
{
type: "select",
name: "embeddingModel",
message: "Which embedding model would you like to use?",
choices: getAvailableEmbeddingModelChoices(),
initial: 0,
},
questionHandlers,
);
config.embeddingModel = embeddingModel;
config.dimensions = getDimensions(embeddingModel);
}
return config;
}
function getAvailableModelChoices() {
return Object.keys(ALL_AZURE_OPENAI_CHAT_MODELS).map((key) => ({
title: key,
value: key,
}));
}
function getAvailableEmbeddingModelChoices() {
return Object.keys(ALL_AZURE_OPENAI_EMBEDDING_MODELS).map((key) => ({
title: key,
value: key,
}));
}
function getDimensions(modelName: string) {
return ALL_AZURE_OPENAI_EMBEDDING_MODELS[modelName].dimensions;
}
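A minimal usage sketch of this prompt flow; the caller below is hypothetical and not part of this commit (the import path assumes a sibling module):

import { askAzureQuestions } from "./azure"; // path is an assumption

async function main() {
  // Interactively configure Azure OpenAI; with no key passed in,
  // the prompt (or the AZURE_OPENAI_KEY env variable) supplies it.
  const azureModelConfig = await askAzureQuestions({
    openAiKey: undefined,
    askModels: true, // also ask which LLM and embedding models to use
  });
  console.log(azureModelConfig.model, azureModelConfig.embeddingModel);
}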
@@ -3,6 +3,7 @@ import prompts from "prompts";
 import { questionHandlers } from "../../questions";
 import { ModelConfig, ModelProvider, TemplateFramework } from "../types";
 import { askAnthropicQuestions } from "./anthropic";
+import { askAzureQuestions } from "./azure";
 import { askGeminiQuestions } from "./gemini";
 import { askGroqQuestions } from "./groq";
 import { askLLMHubQuestions } from "./llmhub";
@@ -34,6 +35,7 @@ export async function askModelConfig({
     { title: "Anthropic", value: "anthropic" },
     { title: "Gemini", value: "gemini" },
     { title: "Mistral", value: "mistral" },
+    { title: "AzureOpenAI", value: "azure-openai" },
   ];
   if (framework === "fastapi") {
@@ -69,6 +71,9 @@ export async function askModelConfig({
     case "mistral":
       modelConfig = await askMistralQuestions({ askModels });
       break;
+    case "azure-openai":
+      modelConfig = await askAzureQuestions({ askModels });
+      break;
     case "t-systems":
       modelConfig = await askLLMHubQuestions({ askModels });
       break;
...
@@ -193,6 +193,16 @@ const getAdditionalDependencies = (
         version: "0.1.4",
       });
       break;
+    case "azure-openai":
+      dependencies.push({
+        name: "llama-index-llms-azure-openai",
+        version: "0.1.10",
+      });
+      dependencies.push({
+        name: "llama-index-embeddings-azure-openai",
+        version: "0.1.11",
+      });
+      break;
     case "t-systems":
       dependencies.push({
         name: "llama-index-agent-openai",
...
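Outside the generator, the equivalent manual install for a Python project would be roughly (pip syntax, shown for illustration):

pip install llama-index-llms-azure-openai==0.1.10 llama-index-embeddings-azure-openai==0.1.11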
@@ -8,6 +8,7 @@ export type ModelProvider =
   | "anthropic"
   | "gemini"
   | "mistral"
+  | "azure-openai"
   | "t-systems";
 export type ModelConfig = {
   provider: ModelProvider;
...
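Since ModelProvider is a closed union of string literals, adding the value here is what lets the new switch cases elsewhere compile; a trivial sketch (isAzure is a hypothetical helper, not in the repo):

import { ModelConfig, ModelProvider } from "../types";

// Type-checks only after "azure-openai" joins the union.
const provider: ModelProvider = "azure-openai";

function isAzure(config: ModelConfig): boolean {
  return config.provider === "azure-openai";
}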
@@ -78,24 +78,30 @@ def init_azure_openai():
     llm_deployment = os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT")
     embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
     max_tokens = os.getenv("LLM_MAX_TOKENS")
-    api_key = os.getenv("AZURE_OPENAI_API_KEY")
-    llm_config = {
-        "api_key": api_key,
-        "deployment_name": llm_deployment,
-        "model": os.getenv("MODEL"),
-        "temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
-        "max_tokens": int(max_tokens) if max_tokens is not None else None,
-    }
-    Settings.llm = AzureOpenAI(**llm_config)
+    temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)
     dimensions = os.getenv("EMBEDDING_DIM")
-    embedding_config = {
-        "api_key": api_key,
-        "deployment_name": embedding_deployment,
-        "model": os.getenv("EMBEDDING_MODEL"),
-        "dimensions": int(dimensions) if dimensions is not None else None,
-    }
-    Settings.embed_model = AzureOpenAIEmbedding(**embedding_config)
+
+    azure_config = {
+        "api_key": os.getenv("AZURE_OPENAI_KEY"),
+        "azure_endpoint": os.getenv("AZURE_OPENAI_ENDPOINT"),
+        "api_version": os.getenv("AZURE_OPENAI_API_VERSION")
+        or os.getenv("OPENAI_API_VERSION"),
+    }
+
+    Settings.llm = AzureOpenAI(
+        model=os.getenv("MODEL"),
+        max_tokens=int(max_tokens) if max_tokens is not None else None,
+        temperature=float(temperature),
+        deployment_name=llm_deployment,
+        **azure_config,
+    )
+    Settings.embed_model = AzureOpenAIEmbedding(
+        model=os.getenv("EMBEDDING_MODEL"),
+        dimensions=int(dimensions) if dimensions is not None else None,
+        deployment_name=embedding_deployment,
+        **azure_config,
+    )

 def init_fastembed():
@@ -108,7 +114,7 @@ def init_fastembed():
         # Small and multilingual
         "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
         # Large and multilingual
         "paraphrase-multilingual-mpnet-base-v2": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",  # noqa: E501
     }
     # This will download the model automatically if it is not already downloaded
@@ -116,6 +122,7 @@ def init_fastembed():
         model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
     )
+
 def init_groq():
     from llama_index.llms.groq import Groq
@@ -125,7 +132,6 @@ def init_groq():
         "mixtral-8x7b": "mixtral-8x7b-32768",
     }
     Settings.llm = Groq(model=model_map[os.getenv("MODEL")])
-
     # Groq does not provide embeddings, so we use FastEmbed instead
     init_fastembed()
...
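A minimal sketch of driving this initializer purely from the environment; the values and the import path are assumptions for illustration, not part of this commit:

import os

# Hypothetical values; a generated project reads these from .env
os.environ["AZURE_OPENAI_KEY"] = "<key>"
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://<resource>.openai.azure.com"
os.environ["AZURE_OPENAI_API_VERSION"] = "2024-02-01"
os.environ["AZURE_OPENAI_LLM_DEPLOYMENT"] = "my-gpt-4o"
os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"] = "my-embeddings"
os.environ["MODEL"] = "gpt-4o"
os.environ["EMBEDDING_MODEL"] = "text-embedding-3-large"
os.environ["EMBEDDING_DIM"] = "3072"

from app.settings import init_azure_openai  # import path is an assumption

init_azure_openai()  # populates Settings.llm and Settings.embed_model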
@@ -45,6 +45,9 @@ export const initSettings = async () => {
     case "mistral":
       initMistralAI();
       break;
+    case "azure-openai":
+      initAzureOpenAI();
+      break;
     default:
       initOpenAI();
       break;
@@ -68,6 +71,53 @@ function initOpenAI() {
   });
 }

+function initAzureOpenAI() {
+  // Map Azure OpenAI model names to OpenAI model names (only for TS)
+  const AZURE_OPENAI_MODEL_MAP: Record<string, string> = {
+    "gpt-35-turbo": "gpt-3.5-turbo",
+    "gpt-35-turbo-16k": "gpt-3.5-turbo-16k",
+    "gpt-4o": "gpt-4o",
+    "gpt-4": "gpt-4",
+    "gpt-4-32k": "gpt-4-32k",
+    "gpt-4-turbo": "gpt-4-turbo",
+    "gpt-4-turbo-2024-04-09": "gpt-4-turbo",
+    "gpt-4-vision-preview": "gpt-4-vision-preview",
+    "gpt-4-1106-preview": "gpt-4-1106-preview",
+    "gpt-4o-2024-05-13": "gpt-4o-2024-05-13",
+  };
+
+  const azureConfig = {
+    apiKey: process.env.AZURE_OPENAI_KEY,
+    endpoint: process.env.AZURE_OPENAI_ENDPOINT,
+    apiVersion:
+      process.env.AZURE_OPENAI_API_VERSION || process.env.OPENAI_API_VERSION,
+  };
+
+  Settings.llm = new OpenAI({
+    model:
+      AZURE_OPENAI_MODEL_MAP[process.env.MODEL ?? "gpt-35-turbo"] ??
+      "gpt-3.5-turbo",
+    maxTokens: process.env.LLM_MAX_TOKENS
+      ? Number(process.env.LLM_MAX_TOKENS)
+      : undefined,
+    azure: {
+      ...azureConfig,
+      deployment: process.env.AZURE_OPENAI_LLM_DEPLOYMENT,
+    },
+  });
+
+  Settings.embedModel = new OpenAIEmbedding({
+    model: process.env.EMBEDDING_MODEL,
+    dimensions: process.env.EMBEDDING_DIM
+      ? parseInt(process.env.EMBEDDING_DIM)
+      : undefined,
+    azure: {
+      ...azureConfig,
+      deployment: process.env.AZURE_OPENAI_EMBEDDING_DEPLOYMENT,
+    },
+  });
+}
+
 function initOllama() {
   const config = {
     host: process.env.OLLAMA_BASE_URL ?? "http://127.0.0.1:11434",
...
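A quick worked trace of the model fallback chain in initAzureOpenAI above (the MODEL values are hypothetical inputs):

// MODEL="gpt-4-turbo-2024-04-09" -> map hit  -> OpenAI model "gpt-4-turbo"
// MODEL unset                    -> defaults to "gpt-35-turbo" -> "gpt-3.5-turbo"
// MODEL="my-custom-deployment"   -> map miss -> falls back to "gpt-3.5-turbo"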
@@ -45,6 +45,9 @@ export const initSettings = async () => {
     case "mistral":
       initMistralAI();
       break;
+    case "azure-openai":
+      initAzureOpenAI();
+      break;
     default:
       initOpenAI();
       break;
@@ -55,7 +58,7 @@ export const initSettings = async () => {
 function initOpenAI() {
   Settings.llm = new OpenAI({
-    model: process.env.MODEL ?? "gpt-3.5-turbo",
+    model: process.env.MODEL ?? "gpt-4o-mini",
     maxTokens: process.env.LLM_MAX_TOKENS
       ? Number(process.env.LLM_MAX_TOKENS)
       : undefined,
@@ -68,6 +71,53 @@ function initOpenAI() {
   });
 }

+function initAzureOpenAI() {
+  // Map Azure OpenAI model names to OpenAI model names (only for TS)
+  const AZURE_OPENAI_MODEL_MAP: Record<string, string> = {
+    "gpt-35-turbo": "gpt-3.5-turbo",
+    "gpt-35-turbo-16k": "gpt-3.5-turbo-16k",
+    "gpt-4o": "gpt-4o",
+    "gpt-4": "gpt-4",
+    "gpt-4-32k": "gpt-4-32k",
+    "gpt-4-turbo": "gpt-4-turbo",
+    "gpt-4-turbo-2024-04-09": "gpt-4-turbo",
+    "gpt-4-vision-preview": "gpt-4-vision-preview",
+    "gpt-4-1106-preview": "gpt-4-1106-preview",
+    "gpt-4o-2024-05-13": "gpt-4o-2024-05-13",
+  };
+
+  const azureConfig = {
+    apiKey: process.env.AZURE_OPENAI_KEY,
+    endpoint: process.env.AZURE_OPENAI_ENDPOINT,
+    apiVersion:
+      process.env.AZURE_OPENAI_API_VERSION || process.env.OPENAI_API_VERSION,
+  };
+
+  Settings.llm = new OpenAI({
+    model:
+      AZURE_OPENAI_MODEL_MAP[process.env.MODEL ?? "gpt-35-turbo"] ??
+      "gpt-3.5-turbo",
+    maxTokens: process.env.LLM_MAX_TOKENS
+      ? Number(process.env.LLM_MAX_TOKENS)
+      : undefined,
+    azure: {
+      ...azureConfig,
+      deployment: process.env.AZURE_OPENAI_LLM_DEPLOYMENT,
+    },
+  });
+
+  Settings.embedModel = new OpenAIEmbedding({
+    model: process.env.EMBEDDING_MODEL,
+    dimensions: process.env.EMBEDDING_DIM
+      ? parseInt(process.env.EMBEDDING_DIM)
+      : undefined,
+    azure: {
+      ...azureConfig,
+      deployment: process.env.AZURE_OPENAI_EMBEDDING_DEPLOYMENT,
+    },
+  });
+}
+
 function initOllama() {
   const config = {
     host: process.env.OLLAMA_BASE_URL ?? "http://127.0.0.1:11434",
...