# openai.py — OpenAI embedding encoder for semantic-router.
from asyncio import sleep as asleep
import os
from time import sleep
from typing import Any, List, Optional, Union
from pydantic.v1 import PrivateAttr
import openai
from openai import OpenAIError
from openai._types import NotGiven
from openai.types import CreateEmbeddingResponse
import tiktoken
from semantic_router.encoders import BaseEncoder
from semantic_router.schema import EncoderInfo
from semantic_router.utils.defaults import EncoderDefault
from semantic_router.utils.logger import logger
# Known OpenAI embedding models mapped to their token (context) limits.
model_configs = {
    model_name: EncoderInfo(name=model_name, token_limit=8192)
    for model_name in (
        "text-embedding-ada-002",
        "text-embedding-3-small",
        "text-embedding-3-large",
    )
}
class OpenAIEncoder(BaseEncoder):
    """Dense embedding encoder backed by the OpenAI embeddings API.

    Holds both a synchronous and an asynchronous OpenAI client, retries
    failed API calls with exponential backoff, and can truncate documents
    that exceed the model's token limit before encoding.
    """

    client: Optional[openai.Client]
    async_client: Optional[openai.AsyncClient]
    # Optional output dimensionality (supported by text-embedding-3-* models);
    # NotGiven() means "let the API use the model default".
    dimensions: Union[int, NotGiven] = NotGiven()
    token_limit: int = 8192  # default value, should be replaced by config
    _token_encoder: Any = PrivateAttr()
    type: str = "openai"

    def __init__(
        self,
        name: Optional[str] = None,
        openai_base_url: Optional[str] = None,
        openai_api_key: Optional[str] = None,
        openai_org_id: Optional[str] = None,
        score_threshold: float = 0.82,
        dimensions: Union[int, NotGiven] = NotGiven(),
    ):
        """Initialize sync and async OpenAI clients and the token encoder.

        :param name: Embedding model name; defaults to the library default
            (``EncoderDefault.OPENAI``).
        :param openai_base_url: API base URL; falls back to ``OPENAI_BASE_URL``.
        :param openai_api_key: API key; falls back to ``OPENAI_API_KEY``.
        :param openai_org_id: Organization id; falls back to ``OPENAI_ORG_ID``.
        :param score_threshold: Similarity threshold used by routing layers.
        :param dimensions: Optional embedding dimensionality (embed-3 models).
        :raises ValueError: If no API key is available or client creation fails.
        """
        if name is None:
            name = EncoderDefault.OPENAI.value["embedding_model"]
        super().__init__(
            name=name,
            score_threshold=score_threshold,
        )
        api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
        base_url = openai_base_url or os.getenv("OPENAI_BASE_URL")
        openai_org_id = openai_org_id or os.getenv("OPENAI_ORG_ID")
        if api_key is None:
            raise ValueError("OpenAI API key cannot be 'None'.")
        try:
            self.client = openai.Client(
                base_url=base_url, api_key=api_key, organization=openai_org_id
            )
            self.async_client = openai.AsyncClient(
                base_url=base_url, api_key=api_key, organization=openai_org_id
            )
        except Exception as e:
            raise ValueError(
                f"OpenAI API client failed to initialize. Error: {e}"
            ) from e
        # set dimensions to support openai embed 3 dimensions param
        self.dimensions = dimensions
        # if model name is known, set token limit
        if name in model_configs:
            self.token_limit = model_configs[name].token_limit
        # get token encoder; fall back for model names tiktoken does not
        # recognize (e.g. custom deployments or models newer than the
        # installed tiktoken) — cl100k_base is the encoding used by all
        # current OpenAI embedding models.
        try:
            self._token_encoder = tiktoken.encoding_for_model(name)
        except KeyError:
            logger.warning(
                f"Model '{name}' not recognized by tiktoken; "
                "falling back to cl100k_base encoding."
            )
            self._token_encoder = tiktoken.get_encoding("cl100k_base")

    def __call__(self, docs: List[str], truncate: bool = True) -> List[List[float]]:
        """Encode a list of text documents into embeddings using OpenAI API.

        :param docs: List of text documents to encode.
        :param truncate: Whether to truncate the documents to token limit. If
            False and a document exceeds the token limit, an error will be
            raised.
        :return: List of embeddings for each document.
        :raises ValueError: If the client is not initialized, a non-OpenAI
            error occurs, or no embeddings are returned after all retries.
        """
        if self.client is None:
            raise ValueError("OpenAI client is not initialized.")
        embeds = None
        error_message = ""
        if truncate:
            # check if any document exceeds token limit and truncate if so
            docs = [self._truncate(doc) for doc in docs]
        # Exponential backoff: wait 2, 4, ..., 64 seconds between attempts.
        for j in range(1, 7):
            try:
                embeds = self.client.embeddings.create(
                    input=docs,
                    model=self.name,
                    dimensions=self.dimensions,
                )
                if embeds.data:
                    break
            except OpenAIError as e:
                # Record the error and log BEFORE sleeping so the retry
                # delay is announced up front rather than after it elapsed.
                error_message = str(e)
                logger.warning(f"Retrying in {2**j} seconds...")
                sleep(2**j)
            except Exception as e:
                logger.error(f"OpenAI API call failed. Error: {error_message}")
                raise ValueError(f"OpenAI API call failed. Error: {e}") from e
        if (
            not embeds
            or not isinstance(embeds, CreateEmbeddingResponse)
            or not embeds.data
        ):
            logger.info(f"Returned embeddings: {embeds}")
            raise ValueError(f"No embeddings returned. Error: {error_message}")
        embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
        return embeddings

    def _truncate(self, text: str) -> str:
        """Truncate ``text`` to ``self.token_limit`` tokens if it is longer.

        :param text: The document text to (potentially) truncate.
        :return: The original text, or a decoded prefix within the limit.
        """
        # we use encode_ordinary as faster equivalent to encode(text, disallowed_special=())
        tokens = self._token_encoder.encode_ordinary(text)
        if len(tokens) > self.token_limit:
            logger.warning(
                f"Document exceeds token limit: {len(tokens)} > {self.token_limit}"
                "\nTruncating document..."
            )
            text = self._token_encoder.decode(tokens[: self.token_limit - 1])
            logger.info(f"Trunc length: {len(self._token_encoder.encode(text))}")
            return text
        return text

    async def acall(self, docs: List[str], truncate: bool = True) -> List[List[float]]:
        """Async variant of :meth:`__call__` using the async OpenAI client.

        :param docs: List of text documents to encode.
        :param truncate: Whether to truncate the documents to token limit. If
            False and a document exceeds the token limit, an error will be
            raised.
        :return: List of embeddings for each document.
        :raises ValueError: If the async client is not initialized, a
            non-OpenAI error occurs, or no embeddings are returned after all
            retries.
        """
        if self.async_client is None:
            raise ValueError("OpenAI async client is not initialized.")
        embeds = None
        error_message = ""
        if truncate:
            # check if any document exceeds token limit and truncate if so
            docs = [self._truncate(doc) for doc in docs]
        # Exponential backoff: wait 2, 4, ..., 64 seconds between attempts.
        for j in range(1, 7):
            try:
                embeds = await self.async_client.embeddings.create(
                    input=docs,
                    model=self.name,
                    dimensions=self.dimensions,
                )
                if embeds.data:
                    break
            except OpenAIError as e:
                # Record the error and log BEFORE sleeping so the retry
                # delay is announced up front rather than after it elapsed.
                error_message = str(e)
                logger.warning(f"Retrying in {2**j} seconds...")
                await asleep(2**j)
            except Exception as e:
                logger.error(f"OpenAI API call failed. Error: {error_message}")
                raise ValueError(f"OpenAI API call failed. Error: {e}") from e
        if (
            not embeds
            or not isinstance(embeds, CreateEmbeddingResponse)
            or not embeds.data
        ):
            logger.info(f"Returned embeddings: {embeds}")
            raise ValueError(f"No embeddings returned. Error: {error_message}")
        embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
        return embeddings