From 742a52214bd233965da91e7e831db6e149bf9348 Mon Sep 17 00:00:00 2001
From: Taylor <TaylorN15@users.noreply.github.com>
Date: Tue, 11 Mar 2025 02:25:12 +0000
Subject: [PATCH] Renamed zure to azure_openai; added non-API key
 authentication options; added custom HTTP client options; updated docstring;
 updated embeddings creation to use deployment name (as Azure OpenAI takes
 deployment name, not model name)

---
 .../encoders/{zure.py => azure_openai.py}     | 183 +++++++++++-------
 1 file changed, 109 insertions(+), 74 deletions(-)
 rename semantic_router/encoders/{zure.py => azure_openai.py} (51%)

diff --git a/semantic_router/encoders/zure.py b/semantic_router/encoders/azure_openai.py
similarity index 51%
rename from semantic_router/encoders/zure.py
rename to semantic_router/encoders/azure_openai.py
index faab1c90..2ca45ada 100644
--- a/semantic_router/encoders/zure.py
+++ b/semantic_router/encoders/azure_openai.py
@@ -1,8 +1,9 @@
 import os
 from asyncio import sleep as asleep
 from time import sleep
-from typing import List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
+import httpx
 import openai
 from openai import OpenAIError
 from openai._types import NotGiven
@@ -24,100 +25,135 @@ class AzureOpenAIEncoder(DenseEncoder):
     async_client: Optional[openai.AsyncAzureOpenAI] = None
     dimensions: Union[int, NotGiven] = NotGiven()
     type: str = "azure"
-    api_key: Optional[str] = None
-    deployment_name: Optional[str] = None
-    azure_endpoint: Optional[str] = None
-    api_version: Optional[str] = None
-    model: Optional[str] = None
+    deployment_name: str | None = None
     max_retries: int = 3
 
     def __init__(
         self,
-        api_key: Optional[str] = None,
-        deployment_name: Optional[str] = None,
-        azure_endpoint: Optional[str] = None,
-        api_version: Optional[str] = None,
-        model: Optional[str] = None,  # TODO we should change to `name` JB
+        name: Optional[str] = None,
+        azure_endpoint: str | None = None,
+        api_version: str | None = None,
+        api_key: str | None = None,
+        azure_ad_token: str | None = None,
+        azure_ad_token_provider: Callable[[], str] | None = None,
+        http_client_options: Optional[Dict[str, Any]] = None,
+        deployment_name: str = EncoderDefault.AZURE.value["deployment_name"],
         score_threshold: float = 0.82,
         dimensions: Union[int, NotGiven] = NotGiven(),
         max_retries: int = 3,
     ):
         """Initialize the AzureOpenAIEncoder.
 
-        :param api_key: The API key for the Azure OpenAI API.
-        :type api_key: str
-        :param deployment_name: The name of the deployment to use.
-        :type deployment_name: str
         :param azure_endpoint: The endpoint for the Azure OpenAI API.
-        :type azure_endpoint: str
+            Example: ``https://accountname.openai.azure.com``
+        :type azure_endpoint: str, optional
+
         :param api_version: The version of the API to use.
-        :type api_version: str
-        :param model: The model to use.
-        :type model: str
-        :param score_threshold: The score threshold for the embeddings.
-        :type score_threshold: float
-        :param dimensions: The dimensions of the embeddings.
-        :type dimensions: int
-        :param max_retries: The maximum number of retries for the API call.
-        :type max_retries: int
+            Example: ``"2025-02-01-preview"``
+        :type api_version: str, optional
+
+        :param api_key: The API key for the Azure OpenAI API.
+        :type api_key: str, optional
+
+        :param azure_ad_token: The Azure AD/Entra ID token for authentication.
+            https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
+        :type azure_ad_token: str, optional
+
+        :param azure_ad_token_provider: A callable function that returns an Azure AD/Entra ID token.
+        :type azure_ad_token_provider: Callable[[], str], optional
+
+        :param http_client_options: Dictionary of options to configure httpx client
+            Example:
+                {
+                    "proxies": "http://proxy.server:8080",
+                    "timeout": 20.0,
+                    "headers": {"Authorization": "Bearer xyz"}
+                }
+        :type http_client_options: Dict[str, Any], optional
+
+        :param deployment_name: The name of the model deployment to use.
+        :type deployment_name: str, optional
+
+        :param score_threshold: The score threshold for filtering embeddings.
+            Default is ``0.82``.
+        :type score_threshold: float, optional
+
+        :param dimensions: The number of dimensions for the embeddings. If not given, it defaults to the model's default setting.
+        :type dimensions: int, optional
+
+        :param max_retries: The maximum number of retries for API calls in case of failures.
+            Default is ``3``.
+        :type max_retries: int, optional
         """
-        name = deployment_name
         if name is None:
-            name = EncoderDefault.AZURE.value["embedding_model"]
+            name = deployment_name
+            if name is None:
+                name = EncoderDefault.AZURE.value["embedding_model"]
         super().__init__(name=name, score_threshold=score_threshold)
-        self.api_key = api_key
+
+        azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
+        if not azure_endpoint:
+            raise ValueError("No Azure OpenAI endpoint provided.")
+
+        api_version = api_version or os.getenv("AZURE_OPENAI_API_VERSION")
+        if not api_version:
+            raise ValueError("No Azure OpenAI API version provided.")
+
+        if not (
+            azure_ad_token
+            or azure_ad_token_provider
+            or api_key
+            or os.getenv("AZURE_OPENAI_API_KEY")
+        ):
+            raise ValueError(
+                "No authentication method provided. Please provide either `azure_ad_token`, "
+                "`azure_ad_token_provider`, or `api_key`."
+            )
+
+        # Only check API Key if no AD token or provider is used
+        if not azure_ad_token and not azure_ad_token_provider:
+            api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
+            if not api_key:
+                raise ValueError("No Azure OpenAI API key provided.")
+
         self.deployment_name = deployment_name
-        self.azure_endpoint = azure_endpoint
-        self.api_version = api_version
-        self.model = model
+
         # set dimensions to support openai embed 3 dimensions param
         self.dimensions = dimensions
-        if self.api_key is None:
-            self.api_key = os.getenv("AZURE_OPENAI_API_KEY")
-            if self.api_key is None:
-                raise ValueError("No Azure OpenAI API key provided.")
+
         if max_retries is not None:
             self.max_retries = max_retries
-        if self.deployment_name is None:
-            self.deployment_name = EncoderDefault.AZURE.value["deployment_name"]
-        # deployment_name may still be None, but it is optional in the API
-        if self.azure_endpoint is None:
-            self.azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
-            if self.azure_endpoint is None:
-                raise ValueError("No Azure OpenAI endpoint provided.")
-        if self.api_version is None:
-            self.api_version = os.getenv("AZURE_OPENAI_API_VERSION")
-            if self.api_version is None:
-                raise ValueError("No Azure OpenAI API version provided.")
-        if self.model is None:
-            self.model = os.getenv("AZURE_OPENAI_MODEL")
-            if self.model is None:
-                raise ValueError("No Azure OpenAI model provided.")
-        assert (
-            self.api_key is not None
-            and self.azure_endpoint is not None
-            and self.api_version is not None
-            and self.model is not None
+
+        # Only create HTTP clients if options are provided
+        sync_http_client = (
+            httpx.Client(**http_client_options) if http_client_options else None
+        )
+        async_http_client = (
+            httpx.AsyncClient(**http_client_options) if http_client_options else None
         )
 
+        assert azure_endpoint is not None and self.deployment_name is not None
+
         try:
             self.client = openai.AzureOpenAI(
-                azure_deployment=(
-                    str(self.deployment_name) if self.deployment_name else None
-                ),
-                api_key=str(self.api_key),
-                azure_endpoint=str(self.azure_endpoint),
-                api_version=str(self.api_version),
+                azure_endpoint=azure_endpoint,
+                api_version=api_version,
+                api_key=api_key,
+                azure_ad_token=azure_ad_token,
+                azure_ad_token_provider=azure_ad_token_provider,
+                http_client=sync_http_client,
             )
             self.async_client = openai.AsyncAzureOpenAI(
-                azure_deployment=(
-                    str(self.deployment_name) if self.deployment_name else None
-                ),
-                api_key=str(self.api_key),
-                azure_endpoint=str(self.azure_endpoint),
-                api_version=str(self.api_version),
+                azure_endpoint=azure_endpoint,
+                api_version=api_version,
+                api_key=api_key,
+                azure_ad_token=azure_ad_token,
+                azure_ad_token_provider=azure_ad_token_provider,
+                http_client=async_http_client,
             )
+
         except Exception as e:
+            logger.error("OpenAI API client failed to initialize. Error: %s", e)
             raise ValueError(
                 f"OpenAI API client failed to initialize. Error: {e}"
             ) from e
@@ -139,7 +175,7 @@ class AzureOpenAIEncoder(DenseEncoder):
             try:
                 embeds = self.client.embeddings.create(
                     input=docs,
-                    model=str(self.model),
+                    model=str(self.deployment_name),
                     dimensions=self.dimensions,
                 )
                 if embeds.data:
@@ -149,12 +185,12 @@ class AzureOpenAIEncoder(DenseEncoder):
                 if self.max_retries != 0 and j < self.max_retries:
                     sleep(2**j)
                     logger.warning(
-                        f"Retrying in {2**j} seconds due to OpenAIError: {e}"
+                        "Retrying in %d seconds due to OpenAIError: %s", 2**j, e
                     )
                 else:
                     raise
             except Exception as e:
-                logger.error(f"Azure OpenAI API call failed. Error: {e}")
+                logger.error("Azure OpenAI API call failed. Error: %s", e)
                 raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e
 
         if (
@@ -183,23 +219,22 @@ class AzureOpenAIEncoder(DenseEncoder):
             try:
                 embeds = await self.async_client.embeddings.create(
                     input=docs,
-                    model=str(self.model),
+                    model=str(self.deployment_name),
                     dimensions=self.dimensions,
                 )
                 if embeds.data:
                     break
-
             except OpenAIError as e:
                 logger.error("Exception occurred", exc_info=True)
                 if self.max_retries != 0 and j < self.max_retries:
                     await asleep(2**j)
                     logger.warning(
-                        f"Retrying in {2**j} seconds due to OpenAIError: {e}"
+                        "Retrying in %d seconds due to OpenAIError: %s", 2**j, e
                     )
                 else:
                     raise
             except Exception as e:
-                logger.error(f"Azure OpenAI API call failed. Error: {e}")
+                logger.error("Azure OpenAI API call failed. Error: %s", e)
                 raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e
 
         if (
-- 
GitLab