From cb16cf629d4388acb779337d92a39de30e2aee05 Mon Sep 17 00:00:00 2001
From: Martim Santos <72747170+martimfasantos@users.noreply.github.com>
Date: Mon, 3 Mar 2025 20:53:49 +0000
Subject: [PATCH] Add base_url to AzureOpenAI (#17996)

---
 .../llama_index/llms/azure_openai/base.py           | 13 ++++++++++++-
 .../llama-index-llms-azure-openai/pyproject.toml    |  2 +-
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/llama-index-integrations/llms/llama-index-llms-azure-openai/llama_index/llms/azure_openai/base.py b/llama-index-integrations/llms/llama-index-llms-azure-openai/llama_index/llms/azure_openai/base.py
index efef001ca..332a13eba 100644
--- a/llama-index-integrations/llms/llama-index-llms-azure-openai/llama_index/llms/azure_openai/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-azure-openai/llama_index/llms/azure_openai/base.py
@@ -11,6 +11,7 @@ from llama_index.llms.azure_openai.utils import (
     resolve_from_aliases,
 )
 from llama_index.llms.openai import OpenAI
+from llama_index.llms.openai.utils import DEFAULT_OPENAI_API_BASE
 from openai import AsyncAzureOpenAI
 from openai import AzureOpenAI as SyncAzureOpenAI
 from openai.lib.azure import AzureADTokenProvider
@@ -92,10 +93,13 @@ class AzureOpenAI(OpenAI):
     use_azure_ad: bool = Field(
         description="Indicates if Microsoft Entra ID (former Azure AD) is used for token authentication"
     )
-
     azure_ad_token_provider: Optional[AzureADTokenProvider] = Field(
         default=None, description="Callback function to provide Azure Entra ID token."
     )
+    api_base: Optional[str] = Field(
+        default=None,
+        description="The Azure Base URL to use. Useful for proxies on top of Azure OpenAI.",
+    )
 
     _azure_ad_token: Any = PrivateAttr(default=None)
     _client: SyncAzureOpenAI = PrivateAttr()
@@ -113,6 +117,7 @@ class AzureOpenAI(OpenAI):
         reuse_client: bool = True,
         api_key: Optional[str] = None,
         api_version: Optional[str] = None,
+        api_base: Optional[str] = None,
         # azure specific
         azure_endpoint: Optional[str] = None,
         azure_deployment: Optional[str] = None,
@@ -157,6 +162,7 @@ class AzureOpenAI(OpenAI):
             api_key=api_key,
             azure_endpoint=azure_endpoint,
             azure_deployment=azure_deployment,
+            api_base=api_base,
             azure_ad_token_provider=azure_ad_token_provider,
             use_azure_ad=use_azure_ad,
             api_version=api_version,
@@ -171,6 +177,10 @@ class AzureOpenAI(OpenAI):
             **kwargs,
         )
 
+        # reset api_base to None if it is the default
+        if self.api_base == DEFAULT_OPENAI_API_BASE:
+            self.api_base = None
+
     @model_validator(mode="before")
     def validate_env(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Validate necessary credentials are set."""
@@ -235,6 +245,7 @@ class AzureOpenAI(OpenAI):
             "timeout": self.timeout,
             "azure_endpoint": self.azure_endpoint,
             "azure_deployment": self.azure_deployment,
+            "base_url": self.api_base,
             "azure_ad_token_provider": self.azure_ad_token_provider,
             "api_version": self.api_version,
             "default_headers": self.default_headers,
diff --git a/llama-index-integrations/llms/llama-index-llms-azure-openai/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-azure-openai/pyproject.toml
index dd5ea6cfb..931c554be 100644
--- a/llama-index-integrations/llms/llama-index-llms-azure-openai/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-azure-openai/pyproject.toml
@@ -29,7 +29,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-azure-openai"
 readme = "README.md"
-version = "0.3.0"
+version = "0.3.1"
 
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
-- 
GitLab