-
James McKeown authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
zure.py 3.99 KiB
import os
from time import sleep
import openai
from openai import OpenAIError
from openai.types import CreateEmbeddingResponse
from semantic_router.encoders import BaseEncoder
from semantic_router.utils.logger import logger
class AzureOpenAIEncoder(BaseEncoder):
    """Encoder that obtains embeddings from an Azure OpenAI deployment.

    Every connection setting can be passed explicitly or, when omitted
    (``None``), is read from an environment variable:

    * ``api_key``          -> ``AZURE_OPENAI_API_KEY``
    * ``deployment_name``  -> ``AZURE_OPENAI_DEPLOYMENT_NAME``
      (falls back to ``"text-embedding-ada-002"``)
    * ``azure_endpoint``   -> ``AZURE_OPENAI_ENDPOINT``
    * ``api_version``      -> ``AZURE_OPENAI_API_VERSION``
    * ``model``            -> ``AZURE_OPENAI_MODEL``
    """

    client: openai.AzureOpenAI | None = None
    type: str = "azure"
    api_key: str | None = None
    deployment_name: str | None = None
    azure_endpoint: str | None = None
    api_version: str | None = None
    model: str | None = None

    def __init__(
        self,
        api_key: str | None = None,
        deployment_name: str | None = None,
        azure_endpoint: str | None = None,
        api_version: str | None = None,
        model: str | None = None,
    ):
        """Resolve the connection settings and build the Azure OpenAI client.

        Args:
            api_key: Azure OpenAI API key.
            deployment_name: Name of the Azure deployment; also used as the
                encoder ``name`` passed to the base class.
            azure_endpoint: Azure OpenAI endpoint.
            api_version: Azure OpenAI API version.
            model: Model identifier sent with each embeddings request.

        Raises:
            ValueError: If a required setting is given neither as an argument
                nor through its environment variable, or if the client cannot
                be constructed.
        """
        # The deployment name always resolves to a value thanks to the
        # built-in default, so it needs no "missing" check.
        if deployment_name is None:
            deployment_name = os.getenv(
                "AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002"
            )
        super().__init__(name=deployment_name)

        if api_key is None:
            api_key = os.getenv("AZURE_OPENAI_API_KEY")
        if api_key is None:
            raise ValueError("No Azure OpenAI API key provided.")

        if azure_endpoint is None:
            azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
        if azure_endpoint is None:
            raise ValueError("No Azure OpenAI endpoint provided.")

        if api_version is None:
            api_version = os.getenv("AZURE_OPENAI_API_VERSION")
        if api_version is None:
            raise ValueError("No Azure OpenAI API version provided.")

        if model is None:
            model = os.getenv("AZURE_OPENAI_MODEL")
        if model is None:
            raise ValueError("No Azure OpenAI model provided.")

        self.api_key = api_key
        self.deployment_name = deployment_name
        self.azure_endpoint = azure_endpoint
        self.api_version = api_version
        self.model = model

        try:
            # Use the *resolved* settings: passing the raw constructor
            # arguments would hand the literal string "None" to the client
            # whenever a value came from the environment.
            self.client = openai.AzureOpenAI(
                azure_deployment=deployment_name,
                api_key=api_key,
                azure_endpoint=azure_endpoint,
                api_version=api_version,
                _strict_response_validation=True,
            )
        except Exception as e:
            raise ValueError(
                f"OpenAI API client failed to initialize. Error: {e}"
            ) from e

    def __call__(self, docs: list[str]) -> list[list[float]]:
        """Embed ``docs`` and return one embedding vector per document.

        The request is attempted up to three times. After an ``OpenAIError``
        the call waits 1 s, then 2 s, before the next attempt.

        Args:
            docs: Texts to embed.

        Returns:
            The ``embedding`` of every item in the response's ``data``.

        Raises:
            ValueError: If the client is not initialized, if a non-OpenAI
                exception occurs during the request, or if no attempt yields
                a response containing embeddings.
        """
        if self.client is None:
            raise ValueError("OpenAI client is not initialized.")

        max_attempts = 3
        embeds = None
        error_message = ""

        # Exponential backoff
        for attempt in range(max_attempts):
            try:
                embeds = self.client.embeddings.create(
                    input=docs, model=str(self.model)
                )
                if embeds.data:
                    break
            except OpenAIError as e:
                error_message = str(e)
                # Only wait when another attempt will actually follow.
                if attempt < max_attempts - 1:
                    delay = 2**attempt
                    logger.warning(f"Retrying in {delay} seconds...")
                    sleep(delay)
            except Exception as e:
                # Log the exception that was just caught, not a stale message
                # left over from an earlier retry.
                logger.error(f"Azure OpenAI API call failed. Error: {e}")
                raise ValueError(
                    f"Azure OpenAI API call failed. Error: {e}"
                ) from e

        if (
            not embeds
            or not isinstance(embeds, CreateEmbeddingResponse)
            or not embeds.data
        ):
            raise ValueError(f"No embeddings returned. Error: {error_message}")

        return [item.embedding for item in embeds.data]