-
James Briggs authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
zure.py 3.09 KiB
import os
from typing import List, Optional
import openai
from pydantic import PrivateAttr
from semantic_router.llms import BaseLLM
from semantic_router.schema import Message
from semantic_router.utils.defaults import EncoderDefault
from semantic_router.utils.logger import logger
class AzureOpenAILLM(BaseLLM):
    """LLM for Azure OpenAI.

    Requires an Azure OpenAI API key and an Azure OpenAI endpoint. Each can be
    passed to the constructor or read from the ``AZURE_OPENAI_API_KEY`` /
    ``AZURE_OPENAI_ENDPOINT`` environment variables.
    """

    # Azure OpenAI client created in __init__; a pydantic private attribute,
    # so it is not part of the model's public fields.
    _client: Optional[openai.AzureOpenAI] = PrivateAttr(default=None)

    def __init__(
        self,
        name: Optional[str] = None,
        openai_api_key: Optional[str] = None,
        azure_endpoint: Optional[str] = None,
        temperature: float = 0.01,
        max_tokens: int = 200,
        api_version: str = "2023-07-01-preview",
    ):
        """Initialize the AzureOpenAILLM.

        :param name: The name of the Azure OpenAI model to use. Defaults to
            ``EncoderDefault.AZURE.value["language_model"]``.
        :type name: Optional[str]
        :param openai_api_key: The Azure OpenAI API key. Falls back to the
            ``AZURE_OPENAI_API_KEY`` environment variable.
        :type openai_api_key: Optional[str]
        :param azure_endpoint: The Azure OpenAI endpoint. Falls back to the
            ``AZURE_OPENAI_ENDPOINT`` environment variable.
        :type azure_endpoint: Optional[str]
        :param temperature: The temperature of the LLM.
        :type temperature: float
        :param max_tokens: The maximum number of tokens to generate.
        :type max_tokens: int
        :param api_version: The API version to use.
        :type api_version: str
        :raises ValueError: If no API key or endpoint is available (missing or
            empty), or if the Azure OpenAI client fails to initialize.
        """
        if name is None:
            name = EncoderDefault.AZURE.value["language_model"]
        super().__init__(name=name)

        api_key = openai_api_key or os.getenv("AZURE_OPENAI_API_KEY")
        # Reject empty strings too (e.g. an exported-but-blank env var), not
        # just None; an empty key would only fail later, at request time.
        if not api_key:
            raise ValueError("AzureOpenAI API key cannot be 'None'.")

        azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
        if not azure_endpoint:
            raise ValueError("Azure endpoint cannot be 'None'.")

        try:
            self._client = openai.AzureOpenAI(
                api_key=api_key, azure_endpoint=azure_endpoint, api_version=api_version
            )
        except Exception as e:
            # Chain the original error so its traceback is preserved.
            raise ValueError(
                f"AzureOpenAI API client failed to initialize. Error: {e}"
            ) from e

        # NOTE(review): these assignments rely on BaseLLM declaring
        # `temperature` and `max_tokens` fields (pydantic model) — confirm.
        self.temperature = temperature
        self.max_tokens = max_tokens

    def __call__(self, messages: List[Message]) -> str:
        """Call the AzureOpenAILLM.

        :param messages: The messages to pass to the AzureOpenAILLM; each is
            converted with ``Message.to_openai()``.
        :type messages: List[Message]
        :return: The content of the first choice of the chat completion.
        :rtype: str
        :raises ValueError: If the client has not been initialized.
        :raises Exception: If the request fails or the completion has no
            content; the message is prefixed with ``"LLM error: "`` and the
            original error is chained as the cause.
        """
        if self._client is None:
            raise ValueError("AzureOpenAI client is not initialized.")
        try:
            completion = self._client.chat.completions.create(
                model=self.name,
                messages=[m.to_openai() for m in messages],
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )
            output = completion.choices[0].message.content
            # An empty/None content is treated as a failure, and is wrapped
            # by the handler below like any other error.
            if not output:
                raise Exception("No output generated")
            return output
        except Exception as e:
            logger.error(f"LLM error: {e}")
            raise Exception(f"LLM error: {e}") from e