Skip to content
Snippets Groups Projects
Commit 742a5221 authored by Taylor
Browse files

Renamed zure to azure_openai

Added non-API key authentication options
Added custom HTTP client options
Updated docstring
Updated embeddings creation to use deployment name (as Azure OpenAI takes deployment name, not model name)
parent 4d38b803
No related branches found
No related tags found
No related merge requests found
import os import os
from asyncio import sleep as asleep from asyncio import sleep as asleep
from time import sleep from time import sleep
from typing import List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
import httpx
import openai import openai
from openai import OpenAIError from openai import OpenAIError
from openai._types import NotGiven from openai._types import NotGiven
...@@ -24,100 +25,135 @@ class AzureOpenAIEncoder(DenseEncoder): ...@@ -24,100 +25,135 @@ class AzureOpenAIEncoder(DenseEncoder):
async_client: Optional[openai.AsyncAzureOpenAI] = None async_client: Optional[openai.AsyncAzureOpenAI] = None
dimensions: Union[int, NotGiven] = NotGiven() dimensions: Union[int, NotGiven] = NotGiven()
type: str = "azure" type: str = "azure"
api_key: Optional[str] = None deployment_name: str | None = None
deployment_name: Optional[str] = None
azure_endpoint: Optional[str] = None
api_version: Optional[str] = None
model: Optional[str] = None
max_retries: int = 3 max_retries: int = 3
def __init__( def __init__(
self, self,
api_key: Optional[str] = None, name: Optional[str] = None,
deployment_name: Optional[str] = None, azure_endpoint: str | None = None,
azure_endpoint: Optional[str] = None, api_version: str | None = None,
api_version: Optional[str] = None, api_key: str | None = None,
model: Optional[str] = None, # TODO we should change to `name` JB azure_ad_token: str | None = None,
azure_ad_token_provider: Callable[[], str] | None = None,
http_client_options: Optional[Dict[str, Any]] = None,
deployment_name: str = EncoderDefault.AZURE.value["deployment_name"],
score_threshold: float = 0.82, score_threshold: float = 0.82,
dimensions: Union[int, NotGiven] = NotGiven(), dimensions: Union[int, NotGiven] = NotGiven(),
max_retries: int = 3, max_retries: int = 3,
): ):
"""Initialize the AzureOpenAIEncoder. """Initialize the AzureOpenAIEncoder.
:param api_key: The API key for the Azure OpenAI API.
:type api_key: str
:param deployment_name: The name of the deployment to use.
:type deployment_name: str
:param azure_endpoint: The endpoint for the Azure OpenAI API. :param azure_endpoint: The endpoint for the Azure OpenAI API.
:type azure_endpoint: str Example: ``https://accountname.openai.azure.com``
:type azure_endpoint: str, optional
:param api_version: The version of the API to use. :param api_version: The version of the API to use.
:type api_version: str Example: ``"2025-02-01-preview"``
:param model: The model to use. :type api_version: str, optional
:type model: str
:param score_threshold: The score threshold for the embeddings. :param api_key: The API key for the Azure OpenAI API.
:type score_threshold: float :type api_key: str, optional
:param dimensions: The dimensions of the embeddings.
:type dimensions: int :param azure_ad_token: The Azure AD/Entra ID token for authentication.
:param max_retries: The maximum number of retries for the API call. https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
:type max_retries: int :type azure_ad_token: str, optional
:param azure_ad_token_provider: A callable function that returns an Azure AD/Entra ID token.
:type azure_ad_token_provider: Callable[[], str], optional
:param http_client_options: Dictionary of options to configure httpx client
Example:
{
"proxies": "http://proxy.server:8080",
"timeout": 20.0,
"headers": {"Authorization": "Bearer xyz"}
}
:type http_client_options: Dict[str, Any], optional
:param deployment_name: The name of the model deployment to use.
:type deployment_name: str, optional
:param score_threshold: The score threshold for filtering embeddings.
Default is ``0.82``.
:type score_threshold: float, optional
:param dimensions: The number of dimensions for the embeddings. If not given, it defaults to the model's default setting.
:type dimensions: int, optional
:param max_retries: The maximum number of retries for API calls in case of failures.
Default is ``3``.
:type max_retries: int, optional
""" """
name = deployment_name
if name is None: if name is None:
name = EncoderDefault.AZURE.value["embedding_model"] name = deployment_name
if name is None:
name = EncoderDefault.AZURE.value["embedding_model"]
super().__init__(name=name, score_threshold=score_threshold) super().__init__(name=name, score_threshold=score_threshold)
self.api_key = api_key
azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
if not azure_endpoint:
raise ValueError("No Azure OpenAI endpoint provided.")
api_version = api_version or os.getenv("AZURE_OPENAI_API_VERSION")
if not api_version:
raise ValueError("No Azure OpenAI API version provided.")
if not (
azure_ad_token
or azure_ad_token_provider
or api_key
or os.getenv("AZURE_OPENAI_API_KEY")
):
raise ValueError(
"No authentication method provided. Please provide either `azure_ad_token`, "
"`azure_ad_token_provider`, or `api_key`."
)
# Only check API Key if no AD token or provider is used
if not azure_ad_token and not azure_ad_token_provider:
api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
if not api_key:
raise ValueError("No Azure OpenAI API key provided.")
self.deployment_name = deployment_name self.deployment_name = deployment_name
self.azure_endpoint = azure_endpoint
self.api_version = api_version
self.model = model
# set dimensions to support openai embed 3 dimensions param # set dimensions to support openai embed 3 dimensions param
self.dimensions = dimensions self.dimensions = dimensions
if self.api_key is None:
self.api_key = os.getenv("AZURE_OPENAI_API_KEY")
if self.api_key is None:
raise ValueError("No Azure OpenAI API key provided.")
if max_retries is not None: if max_retries is not None:
self.max_retries = max_retries self.max_retries = max_retries
if self.deployment_name is None:
self.deployment_name = EncoderDefault.AZURE.value["deployment_name"] # Only create HTTP clients if options are provided
# deployment_name may still be None, but it is optional in the API sync_http_client = (
if self.azure_endpoint is None: httpx.Client(**http_client_options) if http_client_options else None
self.azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") )
if self.azure_endpoint is None: async_http_client = (
raise ValueError("No Azure OpenAI endpoint provided.") httpx.AsyncClient(**http_client_options) if http_client_options else None
if self.api_version is None:
self.api_version = os.getenv("AZURE_OPENAI_API_VERSION")
if self.api_version is None:
raise ValueError("No Azure OpenAI API version provided.")
if self.model is None:
self.model = os.getenv("AZURE_OPENAI_MODEL")
if self.model is None:
raise ValueError("No Azure OpenAI model provided.")
assert (
self.api_key is not None
and self.azure_endpoint is not None
and self.api_version is not None
and self.model is not None
) )
assert azure_endpoint is not None and self.deployment_name is not None
try: try:
self.client = openai.AzureOpenAI( self.client = openai.AzureOpenAI(
azure_deployment=( azure_endpoint=azure_endpoint,
str(self.deployment_name) if self.deployment_name else None api_version=api_version,
), api_key=api_key,
api_key=str(self.api_key), azure_ad_token=azure_ad_token,
azure_endpoint=str(self.azure_endpoint), azure_ad_token_provider=azure_ad_token_provider,
api_version=str(self.api_version), http_client=sync_http_client,
) )
self.async_client = openai.AsyncAzureOpenAI( self.async_client = openai.AsyncAzureOpenAI(
azure_deployment=( azure_endpoint=azure_endpoint,
str(self.deployment_name) if self.deployment_name else None api_version=api_version,
), api_key=api_key,
api_key=str(self.api_key), azure_ad_token=azure_ad_token,
azure_endpoint=str(self.azure_endpoint), azure_ad_token_provider=azure_ad_token_provider,
api_version=str(self.api_version), http_client=async_http_client,
) )
except Exception as e: except Exception as e:
logger.error("OpenAI API client failed to initialize. Error: %s", e)
raise ValueError( raise ValueError(
f"OpenAI API client failed to initialize. Error: {e}" f"OpenAI API client failed to initialize. Error: {e}"
) from e ) from e
...@@ -139,7 +175,7 @@ class AzureOpenAIEncoder(DenseEncoder): ...@@ -139,7 +175,7 @@ class AzureOpenAIEncoder(DenseEncoder):
try: try:
embeds = self.client.embeddings.create( embeds = self.client.embeddings.create(
input=docs, input=docs,
model=str(self.model), model=str(self.deployment_name),
dimensions=self.dimensions, dimensions=self.dimensions,
) )
if embeds.data: if embeds.data:
...@@ -149,12 +185,12 @@ class AzureOpenAIEncoder(DenseEncoder): ...@@ -149,12 +185,12 @@ class AzureOpenAIEncoder(DenseEncoder):
if self.max_retries != 0 and j < self.max_retries: if self.max_retries != 0 and j < self.max_retries:
sleep(2**j) sleep(2**j)
logger.warning( logger.warning(
f"Retrying in {2**j} seconds due to OpenAIError: {e}" "Retrying in %d seconds due to OpenAIError: %s", 2**j, e
) )
else: else:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Azure OpenAI API call failed. Error: {e}") logger.error("Azure OpenAI API call failed. Error: %s", e)
raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e
if ( if (
...@@ -183,23 +219,22 @@ class AzureOpenAIEncoder(DenseEncoder): ...@@ -183,23 +219,22 @@ class AzureOpenAIEncoder(DenseEncoder):
try: try:
embeds = await self.async_client.embeddings.create( embeds = await self.async_client.embeddings.create(
input=docs, input=docs,
model=str(self.model), model=str(self.deployment_name),
dimensions=self.dimensions, dimensions=self.dimensions,
) )
if embeds.data: if embeds.data:
break break
except OpenAIError as e: except OpenAIError as e:
logger.error("Exception occurred", exc_info=True) logger.error("Exception occurred", exc_info=True)
if self.max_retries != 0 and j < self.max_retries: if self.max_retries != 0 and j < self.max_retries:
await asleep(2**j) await asleep(2**j)
logger.warning( logger.warning(
f"Retrying in {2**j} seconds due to OpenAIError: {e}" "Retrying in %d seconds due to OpenAIError: %s", 2**j, e
) )
else: else:
raise raise
except Exception as e: except Exception as e:
logger.error(f"Azure OpenAI API call failed. Error: {e}") logger.error("Azure OpenAI API call failed. Error: %s", e)
raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e
if ( if (
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment