Unverified Commit 26ee2e72 authored by James Briggs

feat: add async to azure and oai encoders

parent 1916fecd
@@ -7,9 +7,13 @@ class BaseEncoder(BaseModel):
     name: str
     score_threshold: float
     type: str = Field(default="base")
+    is_async: bool = Field(default=False)

     class Config:
         arbitrary_types_allowed = True

     def __call__(self, docs: List[Any]) -> List[List[float]]:
         raise NotImplementedError("Subclasses must implement this method")
+
+    def acall(self, docs: List[Any]) -> List[List[float]]:
+        raise NotImplementedError("Subclasses must implement this method")
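With this change the base class declares both a synchronous entry point (`__call__`) and an asynchronous one (`acall`), and concrete encoders are expected to override whichever paths they support. A minimal sketch of a custom subclass, assuming `BaseEncoder` is importable from `semantic_router.encoders`; the `ConstantEncoder` below is hypothetical and only illustrates the interface:

from typing import Any, List

from semantic_router.encoders import BaseEncoder  # assumed public import path


class ConstantEncoder(BaseEncoder):
    """Hypothetical encoder that returns a fixed vector per document."""

    def __call__(self, docs: List[Any]) -> List[List[float]]:
        # synchronous path required by BaseEncoder
        return [[0.0, 1.0, 0.0] for _ in docs]

    async def acall(self, docs: List[Any]) -> List[List[float]]:
        # asynchronous path introduced by this commit; here it simply reuses the sync logic
        return self(docs)


encoder = ConstantEncoder(name="constant", score_threshold=0.5)

The OpenAI encoder diff below implements this pattern against the real OpenAI client.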
+from asyncio import sleep as asleep
 import os
 from time import sleep
 from typing import Any, List, Optional, Union
@@ -30,6 +31,7 @@ model_configs = {

 class OpenAIEncoder(BaseEncoder):
     client: Optional[openai.Client]
+    async_client: Optional[openai.AsyncClient]
     dimensions: Union[int, NotGiven] = NotGiven()
     token_limit: int = 8192  # default value, should be replaced by config
     _token_encoder: Any = PrivateAttr()
@@ -46,7 +48,10 @@ class OpenAIEncoder(BaseEncoder):
     ):
         if name is None:
             name = EncoderDefault.OPENAI.value["embedding_model"]
-        super().__init__(name=name, score_threshold=score_threshold)
+        super().__init__(
+            name=name,
+            score_threshold=score_threshold,
+        )
         api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
         base_url = openai_base_url or os.getenv("OPENAI_BASE_URL")
         openai_org_id = openai_org_id or os.getenv("OPENAI_ORG_ID")
@@ -56,6 +61,9 @@ class OpenAIEncoder(BaseEncoder):
             self.client = openai.Client(
                 base_url=base_url, api_key=api_key, organization=openai_org_id
             )
+            self.async_client = openai.AsyncClient(
+                base_url=base_url, api_key=api_key, organization=openai_org_id
+            )
         except Exception as e:
             raise ValueError(
                 f"OpenAI API client failed to initialize. Error: {e}"
@@ -126,3 +134,43 @@ class OpenAIEncoder(BaseEncoder):
             logger.info(f"Trunc length: {len(self._token_encoder.encode(text))}")
             return text
         return text
+
+    async def acall(self, docs: List[str], truncate: bool = True) -> List[List[float]]:
+        if self.async_client is None:
+            raise ValueError("OpenAI async client is not initialized.")
+        embeds = None
+        error_message = ""
+
+        if truncate:
+            # check if any document exceeds token limit and truncate if so
+            docs = [self._truncate(doc) for doc in docs]
+
+        # Exponential backoff
+        for j in range(1, 7):
+            try:
+                embeds = await self.async_client.embeddings.create(
+                    input=docs,
+                    model=self.name,
+                    dimensions=self.dimensions,
+                )
+                if embeds.data:
+                    break
+
+            except OpenAIError as e:
+                await asleep(2**j)
+                error_message = str(e)
+                logger.warning(f"Retrying in {2**j} seconds...")
+            except Exception as e:
+                logger.error(f"OpenAI API call failed. Error: {error_message}")
+                raise ValueError(f"OpenAI API call failed. Error: {e}") from e
+
+        if (
+            not embeds
+            or not isinstance(embeds, CreateEmbeddingResponse)
+            or not embeds.data
+        ):
+            logger.info(f"Returned embeddings: {embeds}")
+            raise ValueError(f"No embeddings returned. Error: {error_message}")
+
+        embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
+        return embeddings
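A minimal usage sketch of the new async path, assuming OPENAI_API_KEY is set in the environment and that the model name and documents below are placeholders; it is not part of the commit:

import asyncio

from semantic_router.encoders import OpenAIEncoder  # assumed public import path


async def main():
    encoder = OpenAIEncoder(name="text-embedding-3-small")  # example model name
    docs = ["semantic routing", "asynchronous embeddings"]
    # acall mirrors __call__ but awaits openai.AsyncClient under the hood
    embeddings = await encoder.acall(docs)
    print(len(embeddings), len(embeddings[0]))


asyncio.run(main())

The same pattern is applied to the Azure encoder in the diff that follows.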
+from asyncio import sleep as asleep
 import os
 from time import sleep
 from typing import List, Optional, Union
@@ -14,6 +15,7 @@ from semantic_router.utils.logger import logger

 class AzureOpenAIEncoder(BaseEncoder):
     client: Optional[openai.AzureOpenAI] = None
+    async_client: Optional[openai.AsyncAzureOpenAI] = None
     dimensions: Union[int, NotGiven] = NotGiven()
     type: str = "azure"
     api_key: Optional[str] = None
@@ -77,7 +79,14 @@ class AzureOpenAIEncoder(BaseEncoder):
                 api_key=str(self.api_key),
                 azure_endpoint=str(self.azure_endpoint),
                 api_version=str(self.api_version),
-                # _strict_response_validation=True,
+            )
+            self.async_client = openai.AsyncAzureOpenAI(
+                azure_deployment=(
+                    str(self.deployment_name) if self.deployment_name else None
+                ),
+                api_key=str(self.api_key),
+                azure_endpoint=str(self.azure_endpoint),
+                api_version=str(self.api_version),
             )
         except Exception as e:
             raise ValueError(
@@ -86,7 +95,7 @@ class AzureOpenAIEncoder(BaseEncoder):
     def __call__(self, docs: List[str]) -> List[List[float]]:
         if self.client is None:
-            raise ValueError("OpenAI client is not initialized.")
+            raise ValueError("Azure OpenAI client is not initialized.")
         embeds = None
         error_message = ""
@@ -121,3 +130,41 @@ class AzureOpenAIEncoder(BaseEncoder):
         embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
         return embeddings
+
+    async def acall(self, docs: List[str]) -> List[List[float]]:
+        if self.async_client is None:
+            raise ValueError("Azure OpenAI async client is not initialized.")
+        embeds = None
+        error_message = ""
+
+        # Exponential backoff
+        for j in range(3):
+            try:
+                embeds = await self.async_client.embeddings.create(
+                    input=docs,
+                    model=str(self.model),
+                    dimensions=self.dimensions,
+                )
+                if embeds.data:
+                    break
+            except OpenAIError as e:
+                # print full traceback
+                import traceback

+                traceback.print_exc()
+                await asleep(2**j)
+                error_message = str(e)
+                logger.warning(f"Retrying in {2**j} seconds...")
+            except Exception as e:
+                logger.error(f"Azure OpenAI API call failed. Error: {error_message}")
+                raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e
+
+        if (
+            not embeds
+            or not isinstance(embeds, CreateEmbeddingResponse)
+            or not embeds.data
+        ):
+            raise ValueError(f"No embeddings returned. Error: {error_message}")
+
+        embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
+        return embeddings
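Because both encoders now expose acall, several embedding requests can be issued concurrently rather than back to back. A sketch using asyncio.gather with the Azure encoder, assuming the constructor accepts the fields shown in the diff (api_key, azure_endpoint, api_version, deployment_name, model); the environment variable names, API version, and deployment name are placeholders, not values from the commit:

import asyncio
import os

from semantic_router.encoders import AzureOpenAIEncoder  # assumed public import path


async def main():
    encoder = AzureOpenAIEncoder(
        api_key=os.environ["AZURE_OPENAI_API_KEY"],          # placeholder env var names
        azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
        api_version="2023-07-01-preview",                    # example API version
        deployment_name="text-embedding-ada-002",            # example deployment
        model="text-embedding-ada-002",
    )
    batches = [["first batch"], ["second batch"], ["third batch"]]
    # run the three embedding calls concurrently instead of sequentially
    results = await asyncio.gather(*(encoder.acall(batch) for batch in batches))
    print([len(r) for r in results])


asyncio.run(main())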