From 5fe2d519d2f384edf2b888be2498ff6eefddc611 Mon Sep 17 00:00:00 2001
From: Huu Le <39040748+leehuwuj@users.noreply.github.com>
Date: Tue, 4 Jun 2024 16:14:21 +0700
Subject: [PATCH] chore: Add Azure OpenAI model provider python (#110)

---
 .../types/streaming/fastapi/app/settings.py   | 51 +++++++++++++++----
 1 file changed, 41 insertions(+), 10 deletions(-)

diff --git a/templates/types/streaming/fastapi/app/settings.py b/templates/types/streaming/fastapi/app/settings.py
index 90cc0918..0b45f4df 100644
--- a/templates/types/streaming/fastapi/app/settings.py
+++ b/templates/types/streaming/fastapi/app/settings.py
@@ -5,16 +5,19 @@ from llama_index.core.settings import Settings
 
 def init_settings():
     model_provider = os.getenv("MODEL_PROVIDER")
-    if model_provider == "openai":
-        init_openai()
-    elif model_provider == "ollama":
-        init_ollama()
-    elif model_provider == "anthropic":
-        init_anthropic()
-    elif model_provider == "gemini":
-        init_gemini()
-    else:
-        raise ValueError(f"Invalid model provider: {model_provider}")
+    match model_provider:
+        case "openai":
+            init_openai()
+        case "ollama":
+            init_ollama()
+        case "anthropic":
+            init_anthropic()
+        case "gemini":
+            init_gemini()
+        case "azure-openai":
+            init_azure_openai()
+        case _:
+            raise ValueError(f"Invalid model provider: {model_provider}")
     Settings.chunk_size = int(os.getenv("CHUNK_SIZE", "1024"))
     Settings.chunk_overlap = int(os.getenv("CHUNK_OVERLAP", "20"))
 
@@ -52,6 +55,34 @@ def init_openai():
     Settings.embed_model = OpenAIEmbedding(**config)
 
 
+def init_azure_openai():
+    from llama_index.llms.azure_openai import AzureOpenAI
+    from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
+    from llama_index.core.constants import DEFAULT_TEMPERATURE
+
+    llm_deployment = os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT")
+    embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
+    max_tokens = os.getenv("LLM_MAX_TOKENS")
+    api_key = os.getenv("AZURE_OPENAI_API_KEY")
+    llm_config = {
+        "api_key": api_key,
+        "deployment_name": llm_deployment,
+        "model": os.getenv("MODEL"),
+        "temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
+        "max_tokens": int(max_tokens) if max_tokens is not None else None,
+    }
+    Settings.llm = AzureOpenAI(**llm_config)
+
+    dimensions = os.getenv("EMBEDDING_DIM")
+    embedding_config = {
+        "api_key": api_key,
+        "deployment_name": embedding_deployment,
+        "model": os.getenv("EMBEDDING_MODEL"),
+        "dimensions": int(dimensions) if dimensions is not None else None,
+    }
+    Settings.embed_model = AzureOpenAIEmbedding(**embedding_config)
+
+
 def init_anthropic():
     from llama_index.llms.anthropic import Anthropic
     from llama_index.embeddings.huggingface import HuggingFaceEmbedding
-- 
GitLab