From 5fe2d519d2f384edf2b888be2498ff6eefddc611 Mon Sep 17 00:00:00 2001
From: Huu Le <39040748+leehuwuj@users.noreply.github.com>
Date: Tue, 4 Jun 2024 16:14:21 +0700
Subject: [PATCH] chore: Add Azure OpenAI model provider python (#110)

---
 .../types/streaming/fastapi/app/settings.py | 51 +++++++++++++++----
 1 file changed, 41 insertions(+), 10 deletions(-)

diff --git a/templates/types/streaming/fastapi/app/settings.py b/templates/types/streaming/fastapi/app/settings.py
index 90cc0918..0b45f4df 100644
--- a/templates/types/streaming/fastapi/app/settings.py
+++ b/templates/types/streaming/fastapi/app/settings.py
@@ -5,16 +5,19 @@ from llama_index.core.settings import Settings
 
 def init_settings():
     model_provider = os.getenv("MODEL_PROVIDER")
-    if model_provider == "openai":
-        init_openai()
-    elif model_provider == "ollama":
-        init_ollama()
-    elif model_provider == "anthropic":
-        init_anthropic()
-    elif model_provider == "gemini":
-        init_gemini()
-    else:
-        raise ValueError(f"Invalid model provider: {model_provider}")
+    match model_provider:
+        case "openai":
+            init_openai()
+        case "ollama":
+            init_ollama()
+        case "anthropic":
+            init_anthropic()
+        case "gemini":
+            init_gemini()
+        case "azure-openai":
+            init_azure_openai()
+        case _:
+            raise ValueError(f"Invalid model provider: {model_provider}")
     Settings.chunk_size = int(os.getenv("CHUNK_SIZE", "1024"))
     Settings.chunk_overlap = int(os.getenv("CHUNK_OVERLAP", "20"))
 
@@ -52,6 +55,34 @@ def init_openai():
     Settings.embed_model = OpenAIEmbedding(**config)
 
 
+def init_azure_openai():
+    from llama_index.llms.azure_openai import AzureOpenAI
+    from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
+    from llama_index.core.constants import DEFAULT_TEMPERATURE
+
+    llm_deployment = os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT")
+    embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")
+    max_tokens = os.getenv("LLM_MAX_TOKENS")
+    api_key = os.getenv("AZURE_OPENAI_API_KEY")
+    llm_config = {
+        "api_key": api_key,
+        "deployment_name": llm_deployment,
+        "model": os.getenv("MODEL"),
+        "temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
+        "max_tokens": int(max_tokens) if max_tokens is not None else None,
+    }
+    Settings.llm = AzureOpenAI(**llm_config)
+
+    dimensions = os.getenv("EMBEDDING_DIM")
+    embedding_config = {
+        "api_key": api_key,
+        "deployment_name": embedding_deployment,
+        "model": os.getenv("EMBEDDING_MODEL"),
+        "dimensions": int(dimensions) if dimensions is not None else None,
+    }
+    Settings.embed_model = AzureOpenAIEmbedding(**embedding_config)
+
+
 def init_anthropic():
     from llama_index.llms.anthropic import Anthropic
     from llama_index.embeddings.huggingface import HuggingFaceEmbedding
-- 
GitLab
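
Note (not part of the patch): a minimal usage sketch of the new provider path. The environment variable names are taken from init_azure_openai() above; the app.settings import path and all placeholder values are assumptions, and the Azure endpoint and API version are presumably resolved from the environment by the underlying client (e.g. AZURE_OPENAI_ENDPOINT, OPENAI_API_VERSION), since the patch does not set them explicitly.

    # Usage sketch only -- values and the import path are illustrative assumptions.
    import os

    os.environ["MODEL_PROVIDER"] = "azure-openai"
    os.environ["AZURE_OPENAI_API_KEY"] = "<your-azure-openai-key>"
    os.environ["AZURE_OPENAI_LLM_DEPLOYMENT"] = "<llm-deployment-name>"
    os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"] = "<embedding-deployment-name>"
    os.environ["MODEL"] = "gpt-4o"                            # model behind the LLM deployment (assumed)
    os.environ["EMBEDDING_MODEL"] = "text-embedding-3-small"  # model behind the embedding deployment (assumed)
    os.environ["EMBEDDING_DIM"] = "1536"                      # optional; parsed with int() when set

    from app.settings import init_settings  # module path assumed from the patched template

    init_settings()  # the new "azure-openai" match arm dispatches to init_azure_openai()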