From f69f4d17ac8c3dcacd7bf66c2f63aa0e9c9cf800 Mon Sep 17 00:00:00 2001
From: San Nguyen <vinhsannguyen91@gmail.com>
Date: Sun, 5 May 2024 22:45:46 +0900
Subject: [PATCH] fix: azure openai api key not picking up env var (#13267)

---
 .../llama_index/embeddings/azure_openai/base.py      |  4 +++-
 .../pyproject.toml                                   |  2 +-
 .../tests/test_azure_openai.py                       |  2 +-
 .../llama_index/llms/azure_openai/base.py            | 12 +++++++++++-
 .../llama-index-llms-azure-openai/pyproject.toml     |  2 +-
 .../tests/test_azure_openai.py                       |  4 +++-
 6 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-azure-openai/llama_index/embeddings/azure_openai/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-azure-openai/llama_index/embeddings/azure_openai/base.py
index 54dee30a68..5461847597 100644
--- a/llama-index-integrations/embeddings/llama-index-embeddings-azure-openai/llama_index/embeddings/azure_openai/base.py
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-azure-openai/llama_index/embeddings/azure_openai/base.py
@@ -46,7 +46,7 @@ class AzureOpenAIEmbedding(OpenAIEmbedding):
         # azure specific
         azure_endpoint: Optional[str] = None,
         azure_deployment: Optional[str] = None,
-        azure_ad_token_provider: AzureADTokenProvider = None,
+        azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
         deployment_name: Optional[str] = None,
         max_retries: int = 10,
         reuse_client: bool = True,
@@ -60,6 +60,8 @@ class AzureOpenAIEmbedding(OpenAIEmbedding):
             "azure_endpoint", azure_endpoint, "AZURE_OPENAI_ENDPOINT", ""
         )
 
+        api_key = get_from_param_or_env("api_key", api_key, "AZURE_OPENAI_API_KEY")
+
         azure_deployment = resolve_from_aliases(
             azure_deployment,
             deployment_name,
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-azure-openai/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-azure-openai/pyproject.toml
index f734e8983a..4b631b838d 100644
--- a/llama-index-integrations/embeddings/llama-index-embeddings-azure-openai/pyproject.toml
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-azure-openai/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-embeddings-azure-openai"
 readme = "README.md"
-version = "0.1.8"
+version = "0.1.9"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-azure-openai/tests/test_azure_openai.py b/llama-index-integrations/embeddings/llama-index-embeddings-azure-openai/tests/test_azure_openai.py
index a4ffcbf450..38fc4bf79e 100644
--- a/llama-index-integrations/embeddings/llama-index-embeddings-azure-openai/tests/test_azure_openai.py
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-azure-openai/tests/test_azure_openai.py
@@ -11,7 +11,7 @@ def test_custom_http_client(azure_openai_mock: MagicMock) -> None:
     Should get passed on to the implementation from OpenAI.
     """
     custom_http_client = httpx.Client()
-    embedding = AzureOpenAIEmbedding(http_client=custom_http_client)
+    embedding = AzureOpenAIEmbedding(http_client=custom_http_client, api_key="mock")
     embedding._get_client()
     azure_openai_mock.assert_called()
     kwargs = azure_openai_mock.call_args.kwargs
diff --git a/llama-index-integrations/llms/llama-index-llms-azure-openai/llama_index/llms/azure_openai/base.py b/llama-index-integrations/llms/llama-index-llms-azure-openai/llama_index/llms/azure_openai/base.py
index 35ef8e9a63..481e80a348 100644
--- a/llama-index-integrations/llms/llama-index-llms-azure-openai/llama_index/llms/azure_openai/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-azure-openai/llama_index/llms/azure_openai/base.py
@@ -93,7 +93,7 @@ class AzureOpenAI(OpenAI):
         # azure specific
         azure_endpoint: Optional[str] = None,
         azure_deployment: Optional[str] = None,
-        azure_ad_token_provider: AzureADTokenProvider = None,
+        azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
         use_azure_ad: bool = False,
         callback_manager: Optional[CallbackManager] = None,
         # aliases for engine
@@ -186,6 +186,16 @@ class AzureOpenAI(OpenAI):
         if self.use_azure_ad:
             self._azure_ad_token = refresh_openai_azuread_token(self._azure_ad_token)
             self.api_key = self._azure_ad_token.token
+        else:
+            import os
+
+            self.api_key = self.api_key or os.getenv("AZURE_OPENAI_API_KEY")
+
+        if self.api_key is None:
+            raise ValueError(
+                "You must set an `api_key` parameter. "
+                "Alternatively, you can set the AZURE_OPENAI_API_KEY env var OR set `use_azure_ad=True`."
+            )
 
         return {
             "api_key": self.api_key,
diff --git a/llama-index-integrations/llms/llama-index-llms-azure-openai/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-azure-openai/pyproject.toml
index 3e2242f309..0924c6a43c 100644
--- a/llama-index-integrations/llms/llama-index-llms-azure-openai/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-azure-openai/pyproject.toml
@@ -29,7 +29,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-azure-openai"
 readme = "README.md"
-version = "0.1.6"
+version = "0.1.7"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
diff --git a/llama-index-integrations/llms/llama-index-llms-azure-openai/tests/test_azure_openai.py b/llama-index-integrations/llms/llama-index-llms-azure-openai/tests/test_azure_openai.py
index 7529bad7eb..9e66ee8f42 100644
--- a/llama-index-integrations/llms/llama-index-llms-azure-openai/tests/test_azure_openai.py
+++ b/llama-index-integrations/llms/llama-index-llms-azure-openai/tests/test_azure_openai.py
@@ -40,7 +40,9 @@ def test_custom_http_client(sync_azure_openai_mock: MagicMock) -> None:
     mock_instance = sync_azure_openai_mock.return_value
     # Valid mocked result required to not run into another error
     mock_instance.chat.completions.create.return_value = mock_chat_completion_v1()
-    azure_openai = AzureOpenAI(engine="foo bar", http_client=custom_http_client)
+    azure_openai = AzureOpenAI(
+        engine="foo bar", http_client=custom_http_client, api_key="mock"
+    )
     azure_openai.complete("test prompt")
     sync_azure_openai_mock.assert_called()
     kwargs = sync_azure_openai_mock.call_args.kwargs
-- 
GitLab