From cb16cf629d4388acb779337d92a39de30e2aee05 Mon Sep 17 00:00:00 2001 From: Martim Santos <72747170+martimfasantos@users.noreply.github.com> Date: Mon, 3 Mar 2025 20:53:49 +0000 Subject: [PATCH] Add base_url to AzureOpenAI (#17996) --- .../llama_index/llms/azure_openai/base.py | 13 ++++++++++++- .../llama-index-llms-azure-openai/pyproject.toml | 2 +- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/llama-index-integrations/llms/llama-index-llms-azure-openai/llama_index/llms/azure_openai/base.py b/llama-index-integrations/llms/llama-index-llms-azure-openai/llama_index/llms/azure_openai/base.py index efef001ca..332a13eba 100644 --- a/llama-index-integrations/llms/llama-index-llms-azure-openai/llama_index/llms/azure_openai/base.py +++ b/llama-index-integrations/llms/llama-index-llms-azure-openai/llama_index/llms/azure_openai/base.py @@ -11,6 +11,7 @@ from llama_index.llms.azure_openai.utils import ( resolve_from_aliases, ) from llama_index.llms.openai import OpenAI +from llama_index.llms.openai.utils import DEFAULT_OPENAI_API_BASE from openai import AsyncAzureOpenAI from openai import AzureOpenAI as SyncAzureOpenAI from openai.lib.azure import AzureADTokenProvider @@ -92,10 +93,13 @@ class AzureOpenAI(OpenAI): use_azure_ad: bool = Field( description="Indicates if Microsoft Entra ID (former Azure AD) is used for token authentication" ) - azure_ad_token_provider: Optional[AzureADTokenProvider] = Field( default=None, description="Callback function to provide Azure Entra ID token." ) + api_base: Optional[str] = Field( + default=None, + description="The Azure Base URL to use. Useful for proxies on top of Azure OpenAI.", + ) _azure_ad_token: Any = PrivateAttr(default=None) _client: SyncAzureOpenAI = PrivateAttr() @@ -113,6 +117,7 @@ class AzureOpenAI(OpenAI): reuse_client: bool = True, api_key: Optional[str] = None, api_version: Optional[str] = None, + api_base: Optional[str] = None, # azure specific azure_endpoint: Optional[str] = None, azure_deployment: Optional[str] = None, @@ -157,6 +162,7 @@ class AzureOpenAI(OpenAI): api_key=api_key, azure_endpoint=azure_endpoint, azure_deployment=azure_deployment, + api_base=api_base, azure_ad_token_provider=azure_ad_token_provider, use_azure_ad=use_azure_ad, api_version=api_version, @@ -171,6 +177,10 @@ class AzureOpenAI(OpenAI): **kwargs, ) + # reset api_base to None if it is the default + if self.api_base == DEFAULT_OPENAI_API_BASE: + self.api_base = None + @model_validator(mode="before") def validate_env(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Validate necessary credentials are set.""" @@ -235,6 +245,7 @@ class AzureOpenAI(OpenAI): "timeout": self.timeout, "azure_endpoint": self.azure_endpoint, "azure_deployment": self.azure_deployment, + "base_url": self.api_base, "azure_ad_token_provider": self.azure_ad_token_provider, "api_version": self.api_version, "default_headers": self.default_headers, diff --git a/llama-index-integrations/llms/llama-index-llms-azure-openai/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-azure-openai/pyproject.toml index dd5ea6cfb..931c554be 100644 --- a/llama-index-integrations/llms/llama-index-llms-azure-openai/pyproject.toml +++ b/llama-index-integrations/llms/llama-index-llms-azure-openai/pyproject.toml @@ -29,7 +29,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-llms-azure-openai" readme = "README.md" -version = "0.3.0" +version = "0.3.1" [tool.poetry.dependencies] python = ">=3.9,<4.0" -- GitLab