Skip to content
Snippets Groups Projects
Unverified Commit cb16cf62 authored by Martim Santos's avatar Martim Santos Committed by GitHub
Browse files

Add base_url to AzureOpenAI (#17996)

parent e8f73a4b
No related branches found
No related tags found
No related merge requests found
...@@ -11,6 +11,7 @@ from llama_index.llms.azure_openai.utils import ( ...@@ -11,6 +11,7 @@ from llama_index.llms.azure_openai.utils import (
resolve_from_aliases, resolve_from_aliases,
) )
from llama_index.llms.openai import OpenAI from llama_index.llms.openai import OpenAI
from llama_index.llms.openai.utils import DEFAULT_OPENAI_API_BASE
from openai import AsyncAzureOpenAI from openai import AsyncAzureOpenAI
from openai import AzureOpenAI as SyncAzureOpenAI from openai import AzureOpenAI as SyncAzureOpenAI
from openai.lib.azure import AzureADTokenProvider from openai.lib.azure import AzureADTokenProvider
...@@ -92,10 +93,13 @@ class AzureOpenAI(OpenAI): ...@@ -92,10 +93,13 @@ class AzureOpenAI(OpenAI):
use_azure_ad: bool = Field( use_azure_ad: bool = Field(
description="Indicates if Microsoft Entra ID (former Azure AD) is used for token authentication" description="Indicates if Microsoft Entra ID (former Azure AD) is used for token authentication"
) )
azure_ad_token_provider: Optional[AzureADTokenProvider] = Field( azure_ad_token_provider: Optional[AzureADTokenProvider] = Field(
default=None, description="Callback function to provide Azure Entra ID token." default=None, description="Callback function to provide Azure Entra ID token."
) )
api_base: Optional[str] = Field(
default=None,
description="The Azure Base URL to use. Useful for proxies on top of Azure OpenAI.",
)
_azure_ad_token: Any = PrivateAttr(default=None) _azure_ad_token: Any = PrivateAttr(default=None)
_client: SyncAzureOpenAI = PrivateAttr() _client: SyncAzureOpenAI = PrivateAttr()
...@@ -113,6 +117,7 @@ class AzureOpenAI(OpenAI): ...@@ -113,6 +117,7 @@ class AzureOpenAI(OpenAI):
reuse_client: bool = True, reuse_client: bool = True,
api_key: Optional[str] = None, api_key: Optional[str] = None,
api_version: Optional[str] = None, api_version: Optional[str] = None,
api_base: Optional[str] = None,
# azure specific # azure specific
azure_endpoint: Optional[str] = None, azure_endpoint: Optional[str] = None,
azure_deployment: Optional[str] = None, azure_deployment: Optional[str] = None,
...@@ -157,6 +162,7 @@ class AzureOpenAI(OpenAI): ...@@ -157,6 +162,7 @@ class AzureOpenAI(OpenAI):
api_key=api_key, api_key=api_key,
azure_endpoint=azure_endpoint, azure_endpoint=azure_endpoint,
azure_deployment=azure_deployment, azure_deployment=azure_deployment,
api_base=api_base,
azure_ad_token_provider=azure_ad_token_provider, azure_ad_token_provider=azure_ad_token_provider,
use_azure_ad=use_azure_ad, use_azure_ad=use_azure_ad,
api_version=api_version, api_version=api_version,
...@@ -171,6 +177,10 @@ class AzureOpenAI(OpenAI): ...@@ -171,6 +177,10 @@ class AzureOpenAI(OpenAI):
**kwargs, **kwargs,
) )
# reset api_base to None if it is the default
if self.api_base == DEFAULT_OPENAI_API_BASE:
self.api_base = None
@model_validator(mode="before") @model_validator(mode="before")
def validate_env(cls, values: Dict[str, Any]) -> Dict[str, Any]: def validate_env(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate necessary credentials are set.""" """Validate necessary credentials are set."""
...@@ -235,6 +245,7 @@ class AzureOpenAI(OpenAI): ...@@ -235,6 +245,7 @@ class AzureOpenAI(OpenAI):
"timeout": self.timeout, "timeout": self.timeout,
"azure_endpoint": self.azure_endpoint, "azure_endpoint": self.azure_endpoint,
"azure_deployment": self.azure_deployment, "azure_deployment": self.azure_deployment,
"base_url": self.api_base,
"azure_ad_token_provider": self.azure_ad_token_provider, "azure_ad_token_provider": self.azure_ad_token_provider,
"api_version": self.api_version, "api_version": self.api_version,
"default_headers": self.default_headers, "default_headers": self.default_headers,
......
...@@ -29,7 +29,7 @@ exclude = ["**/BUILD"] ...@@ -29,7 +29,7 @@ exclude = ["**/BUILD"]
license = "MIT" license = "MIT"
name = "llama-index-llms-azure-openai" name = "llama-index-llms-azure-openai"
readme = "README.md" readme = "README.md"
version = "0.3.0" version = "0.3.1"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.9,<4.0" python = ">=3.9,<4.0"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment