Commit 152796db authored by Smart2332

Added dimension param to Azure Encoder

parent a8544d8f
AzureOpenAILLM

import os
from typing import List, Optional

import openai

from semantic_router.llms import BaseLLM
from semantic_router.schema import Message
from semantic_router.utils.defaults import EncoderDefault
from semantic_router.utils.logger import logger


class AzureOpenAILLM(BaseLLM):
    client: Optional[openai.AzureOpenAI]
    temperature: Optional[float]
    max_tokens: Optional[int]

    def __init__(
        self,
        name: Optional[str] = None,
        openai_api_key: Optional[str] = None,
        azure_endpoint: Optional[str] = None,
        temperature: float = 0.01,
        max_tokens: int = 200,
        api_version="2023-07-01-preview",
    ):
        if name is None:
            name = EncoderDefault.AZURE.value["language_model"]
        super().__init__(name=name)
        api_key = openai_api_key or os.getenv("AZURE_OPENAI_API_KEY")
        if api_key is None:
            raise ValueError("AzureOpenAI API key cannot be 'None'.")
        azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
        if azure_endpoint is None:
            raise ValueError("Azure endpoint cannot be 'None'.")
        try:
            self.client = openai.AzureOpenAI(
                api_key=api_key, azure_endpoint=azure_endpoint, api_version=api_version
            )
        except Exception as e:
            raise ValueError(f"AzureOpenAI API client failed to initialize. Error: {e}")
        self.temperature = temperature
        self.max_tokens = max_tokens

    def __call__(self, messages: List[Message]) -> str:
        if self.client is None:
            raise ValueError("AzureOpenAI client is not initialized.")
        try:
            completion = self.client.chat.completions.create(
                model=self.name,
                messages=[m.to_openai() for m in messages],
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )

            output = completion.choices[0].message.content

            if not output:
                raise Exception("No output generated")
            return output
        except Exception as e:
            logger.error(f"LLM error: {e}")
            raise Exception(f"LLM error: {e}") from e


AzureOpenAIEncoder

import os
from time import sleep
from typing import List, Optional, Union

import openai
from openai import OpenAIError
from openai._types import NotGiven
from openai.types import CreateEmbeddingResponse

from semantic_router.encoders import BaseEncoder
from semantic_router.utils.defaults import EncoderDefault
from semantic_router.utils.logger import logger


class AzureOpenAIEncoder(BaseEncoder):
    client: Optional[openai.AzureOpenAI] = None
    dimensions: Union[int, NotGiven] = NotGiven()
    type: str = "azure"
    api_key: Optional[str] = None
    deployment_name: Optional[str] = None
    azure_endpoint: Optional[str] = None
    api_version: Optional[str] = None
    model: Optional[str] = None

    def __init__(
        self,
        api_key: Optional[str] = None,
        deployment_name: Optional[str] = None,
        azure_endpoint: Optional[str] = None,
        api_version: Optional[str] = None,
        model: Optional[str] = None,  # TODO we should change to `name` JB
        score_threshold: float = 0.82,
        dimensions: Union[int, NotGiven] = NotGiven(),
    ):
        name = deployment_name
        if name is None:
            name = EncoderDefault.AZURE.value["embedding_model"]
        super().__init__(name=name, score_threshold=score_threshold)
        self.api_key = api_key
        self.deployment_name = deployment_name
        self.azure_endpoint = azure_endpoint
        self.api_version = api_version
        self.model = model
        # set dimensions to support the openai embed-3 dimensions param
        self.dimensions = dimensions
        if self.api_key is None:
            self.api_key = os.getenv("AZURE_OPENAI_API_KEY")
            if self.api_key is None:
                raise ValueError("No Azure OpenAI API key provided.")
        if self.deployment_name is None:
            self.deployment_name = EncoderDefault.AZURE.value["deployment_name"]
        # deployment_name may still be None, but it is optional in the API
        if self.azure_endpoint is None:
            self.azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
            if self.azure_endpoint is None:
                raise ValueError("No Azure OpenAI endpoint provided.")
        if self.api_version is None:
            self.api_version = os.getenv("AZURE_OPENAI_API_VERSION")
            if self.api_version is None:
                raise ValueError("No Azure OpenAI API version provided.")
        if self.model is None:
            self.model = os.getenv("AZURE_OPENAI_MODEL")
            if self.model is None:
                raise ValueError("No Azure OpenAI model provided.")
        assert (
            self.api_key is not None
            and self.azure_endpoint is not None
            and self.api_version is not None
            and self.model is not None
        )

        try:
            self.client = openai.AzureOpenAI(
                azure_deployment=(
                    str(self.deployment_name) if self.deployment_name else None
                ),
                api_key=str(self.api_key),
                azure_endpoint=str(self.azure_endpoint),
                api_version=str(self.api_version),
                # _strict_response_validation=True,
            )
        except Exception as e:
            raise ValueError(
                f"OpenAI API client failed to initialize. Error: {e}"
            ) from e

    def __call__(self, docs: List[str]) -> List[List[float]]:
        if self.client is None:
            raise ValueError("OpenAI client is not initialized.")
        embeds = None
        error_message = ""

        # Exponential backoff
        for j in range(3):
            try:
                embeds = self.client.embeddings.create(
                    input=docs,
                    model=str(self.model),
                    dimensions=self.dimensions,
                )
                if embeds.data:
                    break
            except OpenAIError as e:
                # print full traceback
                import traceback

                traceback.print_exc()
                sleep(2**j)
                error_message = str(e)
                logger.warning(f"Retrying in {2**j} seconds...")
            except Exception as e:
                logger.error(f"Azure OpenAI API call failed. Error: {error_message}")
                raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e

        if (
            not embeds
            or not isinstance(embeds, CreateEmbeddingResponse)
            or not embeds.data
        ):
            raise ValueError(f"No embeddings returned. Error: {error_message}")

        embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
        return embeddings
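
For context, a minimal usage sketch of the new dimensions parameter. This is an illustration, not part of the commit: the import path is assumed from the BaseEncoder import above, and the key, endpoint, API version, deployment, and model values are placeholders. Leaving dimensions at its NotGiven() default lets the openai client omit the field from the request; passing an int asks a text-embedding-3 deployment for shorter vectors.

from semantic_router.encoders import AzureOpenAIEncoder  # assumed import path

# All connection values below are placeholders; the `dimensions` argument is
# forwarded to client.embeddings.create() as shown in the encoder above.
encoder = AzureOpenAIEncoder(
    api_key="<AZURE_OPENAI_API_KEY>",
    azure_endpoint="https://<your-resource>.openai.azure.com",
    api_version="2024-02-01",
    deployment_name="<embedding-deployment>",
    model="text-embedding-3-small",
    dimensions=256,  # request 256-dim vectors; omit to keep the model default
)

vectors = encoder(["some document text"])
print(len(vectors[0]))  # 256 when the deployed model supports `dimensions`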