From fd9fb42ace8aeb8fdd4d525b3d4a0de35cb9abdf Mon Sep 17 00:00:00 2001
From: Thuc Pham <51660321+thucpn@users.noreply.github.com>
Date: Fri, 26 Jul 2024 17:32:14 +0700
Subject: [PATCH] feat: add azure model provider (#184)
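
Add "azure-openai" as a selectable model provider for both the Python
(FastAPI) and TypeScript (Express/Next.js) templates. A generated project
is then configured through environment variables; a minimal example using
the variables introduced in helpers/env-variables.ts (all values are
placeholders):

    AZURE_OPENAI_KEY=<your Azure OpenAI API key>
    AZURE_OPENAI_ENDPOINT=https://<resource-name>.openai.azure.com
    AZURE_OPENAI_API_VERSION=<API version, e.g. 2024-02-01>
    AZURE_OPENAI_LLM_DEPLOYMENT=<deployment name used for the LLM>
    AZURE_OPENAI_EMBEDDING_DEPLOYMENT=<deployment name used for embeddings>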

---
 .changeset/smooth-points-float.md             |   5 +
 helpers/env-variables.ts                      |  27 ++++
 helpers/providers/azure.ts                    | 141 ++++++++++++++++++
 helpers/providers/index.ts                    |   5 +
 helpers/python.ts                             |  10 ++
 helpers/types.ts                              |   1 +
 .../components/settings/python/settings.py    |  42 +++---
 .../src/controllers/engine/settings.ts        |  50 +++++++
 .../nextjs/app/api/chat/engine/settings.ts    |  52 ++++++-
 9 files changed, 314 insertions(+), 19 deletions(-)
 create mode 100644 .changeset/smooth-points-float.md
 create mode 100644 helpers/providers/azure.ts

diff --git a/.changeset/smooth-points-float.md b/.changeset/smooth-points-float.md
new file mode 100644
index 00000000..f0882f52
--- /dev/null
+++ b/.changeset/smooth-points-float.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Add Azure OpenAI model provider
diff --git a/helpers/env-variables.ts b/helpers/env-variables.ts
index 7b10e23d..ef232834 100644
--- a/helpers/env-variables.ts
+++ b/helpers/env-variables.ts
@@ -274,6 +274,33 @@ const getModelEnvs = (modelConfig: ModelConfig): EnvVar[] => {
           },
         ]
       : []),
+    ...(modelConfig.provider === "azure-openai"
+      ? [
+          {
+            name: "AZURE_OPENAI_KEY",
+            description: "The Azure OpenAI key to use.",
+            value: modelConfig.apiKey,
+          },
+          {
+            name: "AZURE_OPENAI_ENDPOINT",
+            description: "The Azure OpenAI endpoint to use.",
+          },
+          {
+            name: "AZURE_OPENAI_API_VERSION",
+            description: "The Azure OpenAI API version to use.",
+          },
+          {
+            name: "AZURE_OPENAI_LLM_DEPLOYMENT",
+            description:
+              "The Azure OpenAI deployment to use for LLM deployment.",
+          },
+          {
+            name: "AZURE_OPENAI_EMBEDDING_DEPLOYMENT",
+            description:
+              "The Azure OpenAI deployment to use for embedding deployment.",
+          },
+        ]
+      : []),
     ...(modelConfig.provider === "t-systems"
       ? [
           {
diff --git a/helpers/providers/azure.ts b/helpers/providers/azure.ts
new file mode 100644
index 00000000..b343d3b3
--- /dev/null
+++ b/helpers/providers/azure.ts
@@ -0,0 +1,141 @@
+import ciInfo from "ci-info";
+import prompts from "prompts";
+import { ModelConfigParams, ModelConfigQuestionsParams } from ".";
+import { questionHandlers } from "../../questions";
+
+const ALL_AZURE_OPENAI_CHAT_MODELS: Record<string, { openAIModel: string }> = {
+  "gpt-35-turbo": { openAIModel: "gpt-3.5-turbo" },
+  "gpt-35-turbo-16k": {
+    openAIModel: "gpt-3.5-turbo-16k",
+  },
+  "gpt-4o": { openAIModel: "gpt-4o" },
+  "gpt-4": { openAIModel: "gpt-4" },
+  "gpt-4-32k": { openAIModel: "gpt-4-32k" },
+  "gpt-4-turbo": {
+    openAIModel: "gpt-4-turbo",
+  },
+  "gpt-4-turbo-2024-04-09": {
+    openAIModel: "gpt-4-turbo",
+  },
+  "gpt-4-vision-preview": {
+    openAIModel: "gpt-4-vision-preview",
+  },
+  "gpt-4-1106-preview": {
+    openAIModel: "gpt-4-1106-preview",
+  },
+  "gpt-4o-2024-05-13": {
+    openAIModel: "gpt-4o-2024-05-13",
+  },
+};
+
+const ALL_AZURE_OPENAI_EMBEDDING_MODELS: Record<
+  string,
+  {
+    dimensions: number;
+    openAIModel: string;
+  }
+> = {
+  "text-embedding-ada-002": {
+    dimensions: 1536,
+    openAIModel: "text-embedding-ada-002",
+  },
+  "text-embedding-3-small": {
+    dimensions: 1536,
+    openAIModel: "text-embedding-3-small",
+  },
+  "text-embedding-3-large": {
+    dimensions: 3072,
+    openAIModel: "text-embedding-3-large",
+  },
+};
+
+const DEFAULT_MODEL = "gpt-4o";
+const DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large";
+
+export async function askAzureQuestions({
+  openAiKey,
+  askModels,
+}: ModelConfigQuestionsParams): Promise<ModelConfigParams> {
+  const config: ModelConfigParams = {
+    apiKey: openAiKey,
+    model: DEFAULT_MODEL,
+    embeddingModel: DEFAULT_EMBEDDING_MODEL,
+    dimensions: getDimensions(DEFAULT_EMBEDDING_MODEL),
+    isConfigured(): boolean {
+      // The Azure model provider can't be fully configured here, as the endpoint and deployment names have to be set via env variables
+      return false;
+    },
+  };
+
+  if (!config.apiKey) {
+    const { key } = await prompts(
+      {
+        type: "text",
+        name: "key",
+        message: askModels
+          ? "Please provide your Azure OpenAI API key (or leave blank to use AZURE_OPENAI_KEY env variable):"
+          : "Please provide your Azure OpenAI API key (leave blank to skip):",
+        validate: (value: string) => {
+          if (askModels && !value) {
+            if (process.env.AZURE_OPENAI_KEY) {
+              return true;
+            }
+            return "AZURE_OPENAI_KEY env variable is not set - key is required";
+          }
+          return true;
+        },
+      },
+      questionHandlers,
+    );
+    config.apiKey = key || process.env.AZURE_OPENAI_KEY;
+  }
+
+  // Use the default model values in CI or if the user should not be asked
+  const useDefaults = ciInfo.isCI || !askModels;
+  if (!useDefaults) {
+    const { model } = await prompts(
+      {
+        type: "select",
+        name: "model",
+        message: "Which LLM model would you like to use?",
+        choices: getAvailableModelChoices(),
+        initial: 0,
+      },
+      questionHandlers,
+    );
+    config.model = model;
+
+    const { embeddingModel } = await prompts(
+      {
+        type: "select",
+        name: "embeddingModel",
+        message: "Which embedding model would you like to use?",
+        choices: getAvailableEmbeddingModelChoices(),
+        initial: 0,
+      },
+      questionHandlers,
+    );
+    config.embeddingModel = embeddingModel;
+    config.dimensions = getDimensions(embeddingModel);
+  }
+
+  return config;
+}
+
+function getAvailableModelChoices() {
+  return Object.keys(ALL_AZURE_OPENAI_CHAT_MODELS).map((key) => ({
+    title: key,
+    value: key,
+  }));
+}
+
+function getAvailableEmbeddingModelChoices() {
+  return Object.keys(ALL_AZURE_OPENAI_EMBEDDING_MODELS).map((key) => ({
+    title: key,
+    value: key,
+  }));
+}
+
+function getDimensions(modelName: string) {
+  return ALL_AZURE_OPENAI_EMBEDDING_MODELS[modelName].dimensions;
+}
diff --git a/helpers/providers/index.ts b/helpers/providers/index.ts
index b660248c..c19efaa4 100644
--- a/helpers/providers/index.ts
+++ b/helpers/providers/index.ts
@@ -3,6 +3,7 @@ import prompts from "prompts";
 import { questionHandlers } from "../../questions";
 import { ModelConfig, ModelProvider, TemplateFramework } from "../types";
 import { askAnthropicQuestions } from "./anthropic";
+import { askAzureQuestions } from "./azure";
 import { askGeminiQuestions } from "./gemini";
 import { askGroqQuestions } from "./groq";
 import { askLLMHubQuestions } from "./llmhub";
@@ -34,6 +35,7 @@ export async function askModelConfig({
       { title: "Anthropic", value: "anthropic" },
       { title: "Gemini", value: "gemini" },
       { title: "Mistral", value: "mistral" },
+      { title: "AzureOpenAI", value: "azure-openai" },
     ];
 
     if (framework === "fastapi") {
@@ -69,6 +71,9 @@ export async function askModelConfig({
     case "mistral":
       modelConfig = await askMistralQuestions({ askModels });
       break;
+    case "azure-openai":
+      modelConfig = await askAzureQuestions({ askModels });
+      break;
     case "t-systems":
       modelConfig = await askLLMHubQuestions({ askModels });
       break;
diff --git a/helpers/python.ts b/helpers/python.ts
index c3f8037c..3092bdbd 100644
--- a/helpers/python.ts
+++ b/helpers/python.ts
@@ -193,6 +193,16 @@ const getAdditionalDependencies = (
         version: "0.1.4",
       });
       break;
+    case "azure-openai":
+      dependencies.push({
+        name: "llama-index-llms-azure-openai",
+        version: "0.1.10",
+      });
+      dependencies.push({
+        name: "llama-index-embeddings-azure-openai",
+        version: "0.1.11",
+      });
+      break;
     case "t-systems":
       dependencies.push({
         name: "llama-index-agent-openai",
diff --git a/helpers/types.ts b/helpers/types.ts
index aadd4e65..9dc9686d 100644
--- a/helpers/types.ts
+++ b/helpers/types.ts
@@ -8,6 +8,7 @@ export type ModelProvider =
   | "anthropic"
   | "gemini"
   | "mistral"
+  | "azure-openai"
   | "t-systems";
 export type ModelConfig = {
   provider: ModelProvider;
diff --git a/templates/components/settings/python/settings.py b/templates/components/settings/python/settings.py
index ce427645..4d50429c 100644
--- a/templates/components/settings/python/settings.py
+++ b/templates/components/settings/python/settings.py
@@ -78,24 +78,30 @@ def init_azure_openai():
     llm_deployment = os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT")
     embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
     max_tokens = os.getenv("LLM_MAX_TOKENS")
-    api_key = os.getenv("AZURE_OPENAI_API_KEY")
-    llm_config = {
-        "api_key": api_key,
-        "deployment_name": llm_deployment,
-        "model": os.getenv("MODEL"),
-        "temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
-        "max_tokens": int(max_tokens) if max_tokens is not None else None,
-    }
-    Settings.llm = AzureOpenAI(**llm_config)
-
+    temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)
     dimensions = os.getenv("EMBEDDING_DIM")
-    embedding_config = {
-        "api_key": api_key,
-        "deployment_name": embedding_deployment,
-        "model": os.getenv("EMBEDDING_MODEL"),
-        "dimensions": int(dimensions) if dimensions is not None else None,
+
+    azure_config = {
+        "api_key": os.getenv("AZURE_OPENAI_KEY"),
+        "azure_endpoint": os.getenv("AZURE_OPENAI_ENDPOINT"),
+        "api_version": os.getenv("AZURE_OPENAI_API_VERSION")
+        or os.getenv("OPENAI_API_VERSION"),
     }
-    Settings.embed_model = AzureOpenAIEmbedding(**embedding_config)
+
+    Settings.llm = AzureOpenAI(
+        model=os.getenv("MODEL"),
+        max_tokens=int(max_tokens) if max_tokens is not None else None,
+        temperature=float(temperature),
+        deployment_name=llm_deployment,
+        **azure_config,
+    )
+
+    Settings.embed_model = AzureOpenAIEmbedding(
+        model=os.getenv("EMBEDDING_MODEL"),
+        dimensions=int(dimensions) if dimensions is not None else None,
+        deployment_name=embedding_deployment,
+        **azure_config,
+    )
 
 
 def init_fastembed():
@@ -108,7 +114,7 @@ def init_fastembed():
         # Small and multilingual
         "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
         # Large and multilingual
-        "paraphrase-multilingual-mpnet-base-v2": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",   # noqa: E501
+        "paraphrase-multilingual-mpnet-base-v2": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",  # noqa: E501
     }
 
     # This will download the model automatically if it is not already downloaded
@@ -116,6 +122,7 @@ def init_fastembed():
         model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
     )
 
+
 def init_groq():
     from llama_index.llms.groq import Groq
 
@@ -125,7 +132,6 @@ def init_groq():
         "mixtral-8x7b": "mixtral-8x7b-32768",
     }
 
-
     Settings.llm = Groq(model=model_map[os.getenv("MODEL")])
     # Groq does not provide embeddings, so we use FastEmbed instead
     init_fastembed()
diff --git a/templates/types/streaming/express/src/controllers/engine/settings.ts b/templates/types/streaming/express/src/controllers/engine/settings.ts
index 2207552a..28761d26 100644
--- a/templates/types/streaming/express/src/controllers/engine/settings.ts
+++ b/templates/types/streaming/express/src/controllers/engine/settings.ts
@@ -45,6 +45,9 @@ export const initSettings = async () => {
     case "mistral":
       initMistralAI();
       break;
+    case "azure-openai":
+      initAzureOpenAI();
+      break;
     default:
       initOpenAI();
       break;
@@ -68,6 +71,53 @@ function initOpenAI() {
   });
 }
 
+function initAzureOpenAI() {
+  // Map Azure OpenAI model names to OpenAI model names (this mapping is only needed in the TS templates)
+  const AZURE_OPENAI_MODEL_MAP: Record<string, string> = {
+    "gpt-35-turbo": "gpt-3.5-turbo",
+    "gpt-35-turbo-16k": "gpt-3.5-turbo-16k",
+    "gpt-4o": "gpt-4o",
+    "gpt-4": "gpt-4",
+    "gpt-4-32k": "gpt-4-32k",
+    "gpt-4-turbo": "gpt-4-turbo",
+    "gpt-4-turbo-2024-04-09": "gpt-4-turbo",
+    "gpt-4-vision-preview": "gpt-4-vision-preview",
+    "gpt-4-1106-preview": "gpt-4-1106-preview",
+    "gpt-4o-2024-05-13": "gpt-4o-2024-05-13",
+  };
+
+  const azureConfig = {
+    apiKey: process.env.AZURE_OPENAI_KEY,
+    endpoint: process.env.AZURE_OPENAI_ENDPOINT,
+    apiVersion:
+      process.env.AZURE_OPENAI_API_VERSION || process.env.OPENAI_API_VERSION,
+  };
+
+  Settings.llm = new OpenAI({
+    model:
+      AZURE_OPENAI_MODEL_MAP[process.env.MODEL ?? "gpt-35-turbo"] ??
+      "gpt-3.5-turbo",
+    maxTokens: process.env.LLM_MAX_TOKENS
+      ? Number(process.env.LLM_MAX_TOKENS)
+      : undefined,
+    azure: {
+      ...azureConfig,
+      deployment: process.env.AZURE_OPENAI_LLM_DEPLOYMENT,
+    },
+  });
+
+  Settings.embedModel = new OpenAIEmbedding({
+    model: process.env.EMBEDDING_MODEL,
+    dimensions: process.env.EMBEDDING_DIM
+      ? parseInt(process.env.EMBEDDING_DIM)
+      : undefined,
+    azure: {
+      ...azureConfig,
+      deployment: process.env.AZURE_OPENAI_EMBEDDING_DEPLOYMENT,
+    },
+  });
+}
+
 function initOllama() {
   const config = {
     host: process.env.OLLAMA_BASE_URL ?? "http://127.0.0.1:11434",
diff --git a/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts b/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
index 98160a56..28761d26 100644
--- a/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
+++ b/templates/types/streaming/nextjs/app/api/chat/engine/settings.ts
@@ -45,6 +45,9 @@ export const initSettings = async () => {
     case "mistral":
       initMistralAI();
       break;
+    case "azure-openai":
+      initAzureOpenAI();
+      break;
     default:
       initOpenAI();
       break;
@@ -55,7 +58,7 @@ export const initSettings = async () => {
 
 function initOpenAI() {
   Settings.llm = new OpenAI({
-    model: process.env.MODEL ?? "gpt-3.5-turbo",
+    model: process.env.MODEL ?? "gpt-4o-mini",
     maxTokens: process.env.LLM_MAX_TOKENS
       ? Number(process.env.LLM_MAX_TOKENS)
       : undefined,
@@ -68,6 +71,53 @@ function initOpenAI() {
   });
 }
 
+function initAzureOpenAI() {
+  // Map Azure OpenAI model names to OpenAI model names (this mapping is only needed in the TS templates)
+  const AZURE_OPENAI_MODEL_MAP: Record<string, string> = {
+    "gpt-35-turbo": "gpt-3.5-turbo",
+    "gpt-35-turbo-16k": "gpt-3.5-turbo-16k",
+    "gpt-4o": "gpt-4o",
+    "gpt-4": "gpt-4",
+    "gpt-4-32k": "gpt-4-32k",
+    "gpt-4-turbo": "gpt-4-turbo",
+    "gpt-4-turbo-2024-04-09": "gpt-4-turbo",
+    "gpt-4-vision-preview": "gpt-4-vision-preview",
+    "gpt-4-1106-preview": "gpt-4-1106-preview",
+    "gpt-4o-2024-05-13": "gpt-4o-2024-05-13",
+  };
+
+  const azureConfig = {
+    apiKey: process.env.AZURE_OPENAI_KEY,
+    endpoint: process.env.AZURE_OPENAI_ENDPOINT,
+    apiVersion:
+      process.env.AZURE_OPENAI_API_VERSION || process.env.OPENAI_API_VERSION,
+  };
+
+  Settings.llm = new OpenAI({
+    model:
+      AZURE_OPENAI_MODEL_MAP[process.env.MODEL ?? "gpt-35-turbo"] ??
+      "gpt-3.5-turbo",
+    maxTokens: process.env.LLM_MAX_TOKENS
+      ? Number(process.env.LLM_MAX_TOKENS)
+      : undefined,
+    azure: {
+      ...azureConfig,
+      deployment: process.env.AZURE_OPENAI_LLM_DEPLOYMENT,
+    },
+  });
+
+  Settings.embedModel = new OpenAIEmbedding({
+    model: process.env.EMBEDDING_MODEL,
+    dimensions: process.env.EMBEDDING_DIM
+      ? parseInt(process.env.EMBEDDING_DIM)
+      : undefined,
+    azure: {
+      ...azureConfig,
+      deployment: process.env.AZURE_OPENAI_EMBEDDING_DEPLOYMENT,
+    },
+  });
+}
+
 function initOllama() {
   const config = {
     host: process.env.OLLAMA_BASE_URL ?? "http://127.0.0.1:11434",
-- 
GitLab