diff --git a/semantic_router/encoders/base.py b/semantic_router/encoders/base.py
index b1e1311ec6b27fcf6984d5451696f61425b2103e..70a3f6ee12904d19e23466b98a82e2f25f03543e 100644
--- a/semantic_router/encoders/base.py
+++ b/semantic_router/encoders/base.py
@@ -7,9 +7,13 @@ class BaseEncoder(BaseModel):
     name: str
     score_threshold: float
     type: str = Field(default="base")
+    is_async: bool = Field(default=False)
 
     class Config:
         arbitrary_types_allowed = True
 
     def __call__(self, docs: List[Any]) -> List[List[float]]:
         raise NotImplementedError("Subclasses must implement this method")
+
+    async def acall(self, docs: List[Any]) -> List[List[float]]:
+        raise NotImplementedError("Subclasses must implement this method")
diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py
index bd68727f1a0d616932486d5221cd96763a4bb7ae..96303155151c4e43b5e3937f801f42120414be7f 100644
--- a/semantic_router/encoders/openai.py
+++ b/semantic_router/encoders/openai.py
@@ -1,3 +1,4 @@
+from asyncio import sleep as asleep
 import os
 from time import sleep
 from typing import Any, List, Optional, Union
@@ -30,6 +31,7 @@ model_configs = {
 
 class OpenAIEncoder(BaseEncoder):
     client: Optional[openai.Client]
+    async_client: Optional[openai.AsyncClient]
     dimensions: Union[int, NotGiven] = NotGiven()
     token_limit: int = 8192  # default value, should be replaced by config
     _token_encoder: Any = PrivateAttr()
@@ -46,7 +48,10 @@ class OpenAIEncoder(BaseEncoder):
     ):
         if name is None:
             name = EncoderDefault.OPENAI.value["embedding_model"]
-        super().__init__(name=name, score_threshold=score_threshold)
+        super().__init__(
+            name=name,
+            score_threshold=score_threshold,
+        )
         api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
         base_url = openai_base_url or os.getenv("OPENAI_BASE_URL")
         openai_org_id = openai_org_id or os.getenv("OPENAI_ORG_ID")
@@ -56,6 +61,9 @@
             self.client = openai.Client(
                 base_url=base_url, api_key=api_key, organization=openai_org_id
             )
+            self.async_client = openai.AsyncClient(
+                base_url=base_url, api_key=api_key, organization=openai_org_id
+            )
         except Exception as e:
             raise ValueError(
                 f"OpenAI API client failed to initialize. Error: {e}"
@@ -126,3 +134,43 @@ class OpenAIEncoder(BaseEncoder):
             logger.info(f"Trunc length: {len(self._token_encoder.encode(text))}")
             return text
         return text
+
+    async def acall(self, docs: List[str], truncate: bool = True) -> List[List[float]]:
+        if self.async_client is None:
+            raise ValueError("OpenAI async client is not initialized.")
+        embeds = None
+        error_message = ""
+
+        if truncate:
+            # check if any document exceeds token limit and truncate if so
+            docs = [self._truncate(doc) for doc in docs]
+
+        # Exponential backoff
+        for j in range(1, 7):
+            try:
+                embeds = await self.async_client.embeddings.create(
+                    input=docs,
+                    model=self.name,
+                    dimensions=self.dimensions,
+                )
+                if embeds.data:
+                    break
+            except OpenAIError as e:
+                await asleep(2**j)
+                error_message = str(e)
+                logger.warning(f"Retrying in {2**j} seconds...")
+            except Exception as e:
+                logger.error(f"OpenAI API call failed. Error: {error_message}")
+                raise ValueError(f"OpenAI API call failed. Error: {e}") from e
+
+        if (
+            not embeds
+            or not isinstance(embeds, CreateEmbeddingResponse)
+            or not embeds.data
+        ):
+            logger.info(f"Returned embeddings: {embeds}")
+            raise ValueError(f"No embeddings returned. Error: {error_message}")
+
+        embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
+        return embeddings
+
diff --git a/semantic_router/encoders/zure.py b/semantic_router/encoders/zure.py
index 55b2a40bc8508b2b7f296598165f59a298619490..6968583aacde861250a2a8155ad414732bb2b9eb 100644
--- a/semantic_router/encoders/zure.py
+++ b/semantic_router/encoders/zure.py
@@ -1,3 +1,4 @@
+from asyncio import sleep as asleep
 import os
 from time import sleep
 from typing import List, Optional, Union
@@ -14,6 +15,7 @@ from semantic_router.utils.logger import logger
 
 class AzureOpenAIEncoder(BaseEncoder):
     client: Optional[openai.AzureOpenAI] = None
+    async_client: Optional[openai.AsyncAzureOpenAI] = None
     dimensions: Union[int, NotGiven] = NotGiven()
     type: str = "azure"
     api_key: Optional[str] = None
@@ -77,7 +79,14 @@
                 api_key=str(self.api_key),
                 azure_endpoint=str(self.azure_endpoint),
                 api_version=str(self.api_version),
-                # _strict_response_validation=True,
             )
+            self.async_client = openai.AsyncAzureOpenAI(
+                azure_deployment=(
+                    str(self.deployment_name) if self.deployment_name else None
+                ),
+                api_key=str(self.api_key),
+                azure_endpoint=str(self.azure_endpoint),
+                api_version=str(self.api_version),
+            )
         except Exception as e:
             raise ValueError(
@@ -86,7 +95,7 @@
 
     def __call__(self, docs: List[str]) -> List[List[float]]:
         if self.client is None:
-            raise ValueError("OpenAI client is not initialized.")
+            raise ValueError("Azure OpenAI client is not initialized.")
         embeds = None
         error_message = ""
 
@@ -121,3 +130,41 @@
 
         embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
         return embeddings
+
+    async def acall(self, docs: List[str]) -> List[List[float]]:
+        if self.async_client is None:
+            raise ValueError("Azure OpenAI async client is not initialized.")
+        embeds = None
+        error_message = ""
+
+        # Exponential backoff
+        for j in range(3):
+            try:
+                embeds = await self.async_client.embeddings.create(
+                    input=docs,
+                    model=str(self.model),
+                    dimensions=self.dimensions,
+                )
+                if embeds.data:
+                    break
+            except OpenAIError as e:
+                # Log the full traceback for diagnosis; the call is
+                # retried with backoff below and the last error is
+                # surfaced to the caller if every retry fails.
+                logger.exception("Azure OpenAI API call failed.")
+                await asleep(2**j)
+                error_message = str(e)
+                logger.warning(f"Retrying in {2**j} seconds...")
+            except Exception as e:
+                logger.error(f"Azure OpenAI API call failed. Error: {error_message}")
+                raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e
+
+        if (
+            not embeds
+            or not isinstance(embeds, CreateEmbeddingResponse)
+            or not embeds.data
+        ):
+            raise ValueError(f"No embeddings returned. Error: {error_message}")
+
+        embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
+        return embeddings