diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py index e4acc5a493d26161cb8727f478a68491ead624ae..0425d5df80e4cf586fc5e668067d2d3f5928bad1 100644 --- a/semantic_router/encoders/openai.py +++ b/semantic_router/encoders/openai.py @@ -42,6 +42,7 @@ class OpenAIEncoder(BaseEncoder): token_limit: int = 8192 # default value, should be replaced by config _token_encoder: Any = PrivateAttr() type: str = "openai" + max_retries: int def __init__( self, @@ -51,9 +52,13 @@ class OpenAIEncoder(BaseEncoder): openai_org_id: Optional[str] = None, score_threshold: Optional[float] = None, dimensions: Union[int, NotGiven] = NotGiven(), + max_retries: int | None = None, ): if name is None: name = EncoderDefault.OPENAI.value["embedding_model"] + + max_retries = max_retries if max_retries is not None else 3 + if score_threshold is None and name in model_configs: set_score_threshold = model_configs[name].threshold elif score_threshold is None: @@ -66,6 +71,7 @@ class OpenAIEncoder(BaseEncoder): super().__init__( name=name, score_threshold=set_score_threshold, + max_retries=max_retries, ) api_key = openai_api_key or os.getenv("OPENAI_API_KEY") base_url = openai_base_url or os.getenv("OPENAI_BASE_URL") @@ -109,8 +115,9 @@ class OpenAIEncoder(BaseEncoder): docs = [self._truncate(doc) for doc in docs] # Exponential backoff - for j in range(1, 7): + for j in range(self.max_retries + 1): try: + raise OpenAIError("Test") embeds = self.client.embeddings.create( input=docs, model=self.name, @@ -119,12 +126,14 @@ class OpenAIEncoder(BaseEncoder): if embeds.data: break except OpenAIError as e: - sleep(2**j) - error_message = str(e) - logger.warning(f"Retrying in {2**j} seconds...") + logger.error("Exception occurred", exc_info=True) + if self.max_retries != 0: + sleep(2**j) + logger.warning(f"Retrying in {2**j} seconds due to OpenAIError: {e}") + except Exception as e: - logger.error(f"OpenAI API call failed. Error: {error_message}") - raise ValueError(f"OpenAI API call failed. Error: {e}") from e + logger.error(f"OpenAI API call failed. Error: {e}") + raise ValueError(f"OpenAI API call failed. Error: {str(e)}") from e if ( not embeds @@ -132,7 +141,7 @@ class OpenAIEncoder(BaseEncoder): or not embeds.data ): logger.info(f"Returned embeddings: {embeds}") - raise ValueError(f"No embeddings returned. Error: {error_message}") + raise ValueError(f"No embeddings returned.") embeddings = [embeds_obj.embedding for embeds_obj in embeds.data] return embeddings @@ -161,8 +170,9 @@ class OpenAIEncoder(BaseEncoder): docs = [self._truncate(doc) for doc in docs] # Exponential backoff - for j in range(1, 7): + for j in range(self.max_retries + 1): try: + raise OpenAIError("Test") embeds = await self.async_client.embeddings.create( input=docs, model=self.name, @@ -171,11 +181,12 @@ class OpenAIEncoder(BaseEncoder): if embeds.data: break except OpenAIError as e: - await asleep(2**j) - error_message = str(e) - logger.warning(f"Retrying in {2**j} seconds...") + logger.error("Exception occurred", exc_info=True) + if self.max_retries != 0: + await asleep(2**j) + logger.warning(f"Retrying in {2**j} seconds due to OpenAIError: {e}") except Exception as e: - logger.error(f"OpenAI API call failed. Error: {error_message}") + logger.error(f"OpenAI API call failed. Error: {e}") raise ValueError(f"OpenAI API call failed. Error: {e}") from e if ( @@ -184,7 +195,7 @@ class OpenAIEncoder(BaseEncoder): or not embeds.data ): logger.info(f"Returned embeddings: {embeds}") - raise ValueError(f"No embeddings returned. Error: {error_message}") + raise ValueError(f"No embeddings returned.") embeddings = [embeds_obj.embedding for embeds_obj in embeds.data] return embeddings diff --git a/semantic_router/encoders/zure.py b/semantic_router/encoders/zure.py index dba936004d3f9faf83fa61be9f87056aac3f7021..3c1996923c958f5b26728765bf097c83f6deaf07 100644 --- a/semantic_router/encoders/zure.py +++ b/semantic_router/encoders/zure.py @@ -23,6 +23,7 @@ class AzureOpenAIEncoder(BaseEncoder): azure_endpoint: Optional[str] = None api_version: Optional[str] = None model: Optional[str] = None + max_retries: int def __init__( self, @@ -33,11 +34,15 @@ class AzureOpenAIEncoder(BaseEncoder): model: Optional[str] = None, # TODO we should change to `name` JB score_threshold: float = 0.82, dimensions: Union[int, NotGiven] = NotGiven(), + max_retries: int | None = None, ): name = deployment_name if name is None: name = EncoderDefault.AZURE.value["embedding_model"] - super().__init__(name=name, score_threshold=score_threshold) + + max_retries = max_retries if max_retries is not None else 3 + + super().__init__(name=name, score_threshold=score_threshold, max_retries=max_retries) self.api_key = api_key self.deployment_name = deployment_name self.azure_endpoint = azure_endpoint @@ -100,8 +105,9 @@ class AzureOpenAIEncoder(BaseEncoder): error_message = "" # Exponential backoff - for j in range(3): + for j in range(self.max_retries + 1): try: + raise OpenAIError("Test") embeds = self.client.embeddings.create( input=docs, model=str(self.model), @@ -110,15 +116,12 @@ class AzureOpenAIEncoder(BaseEncoder): if embeds.data: break except OpenAIError as e: - # print full traceback - import traceback - - traceback.print_exc() - sleep(2**j) - error_message = str(e) - logger.warning(f"Retrying in {2**j} seconds...") + logger.error("Exception occurred", exc_info=True) + if self.max_retries != 0: + sleep(2**j) + logger.warning(f"Retrying in {2**j} seconds due to OpenAIError: {e}") except Exception as e: - logger.error(f"Azure OpenAI API call failed. Error: {error_message}") + logger.error(f"Azure OpenAI API call failed. Error: {e}") raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e if ( @@ -126,7 +129,7 @@ class AzureOpenAIEncoder(BaseEncoder): or not isinstance(embeds, CreateEmbeddingResponse) or not embeds.data ): - raise ValueError(f"No embeddings returned. Error: {error_message}") + raise ValueError(f"No embeddings returned.") embeddings = [embeds_obj.embedding for embeds_obj in embeds.data] return embeddings @@ -138,8 +141,9 @@ class AzureOpenAIEncoder(BaseEncoder): error_message = "" # Exponential backoff - for j in range(3): + for j in range(self.max_retries + 1): try: + raise OpenAIError("Test") embeds = await self.async_client.embeddings.create( input=docs, model=str(self.model), @@ -147,16 +151,14 @@ class AzureOpenAIEncoder(BaseEncoder): ) if embeds.data: break - except OpenAIError as e: - # print full traceback - import traceback - traceback.print_exc() - await asleep(2**j) - error_message = str(e) - logger.warning(f"Retrying in {2**j} seconds...") + except OpenAIError as e: + logger.error("Exception occurred", exc_info=True) + if self.max_retries != 0: + await asleep(2**j) + logger.warning(f"Retrying in {2**j} seconds due to OpenAIError: {e}") except Exception as e: - logger.error(f"Azure OpenAI API call failed. Error: {error_message}") + logger.error(f"Azure OpenAI API call failed. Error: {e}") raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e if ( @@ -164,7 +166,7 @@ class AzureOpenAIEncoder(BaseEncoder): or not isinstance(embeds, CreateEmbeddingResponse) or not embeds.data ): - raise ValueError(f"No embeddings returned. Error: {error_message}") + raise ValueError(f"No embeddings returned.") embeddings = [embeds_obj.embedding for embeds_obj in embeds.data] return embeddings