Skip to content
Snippets Groups Projects
Commit 0483b467 authored by Ismail Ashraq's avatar Ismail Ashraq
Browse files

add max_retries to openai and azure encoders

parent 153bcebf
No related branches found
No related tags found
No related merge requests found
...@@ -42,6 +42,7 @@ class OpenAIEncoder(BaseEncoder): ...@@ -42,6 +42,7 @@ class OpenAIEncoder(BaseEncoder):
token_limit: int = 8192 # default value, should be replaced by config token_limit: int = 8192 # default value, should be replaced by config
_token_encoder: Any = PrivateAttr() _token_encoder: Any = PrivateAttr()
type: str = "openai" type: str = "openai"
max_retries: int
def __init__( def __init__(
self, self,
...@@ -51,9 +52,13 @@ class OpenAIEncoder(BaseEncoder): ...@@ -51,9 +52,13 @@ class OpenAIEncoder(BaseEncoder):
openai_org_id: Optional[str] = None, openai_org_id: Optional[str] = None,
score_threshold: Optional[float] = None, score_threshold: Optional[float] = None,
dimensions: Union[int, NotGiven] = NotGiven(), dimensions: Union[int, NotGiven] = NotGiven(),
max_retries: int | None = None,
): ):
if name is None: if name is None:
name = EncoderDefault.OPENAI.value["embedding_model"] name = EncoderDefault.OPENAI.value["embedding_model"]
max_retries = max_retries if max_retries is not None else 3
if score_threshold is None and name in model_configs: if score_threshold is None and name in model_configs:
set_score_threshold = model_configs[name].threshold set_score_threshold = model_configs[name].threshold
elif score_threshold is None: elif score_threshold is None:
...@@ -66,6 +71,7 @@ class OpenAIEncoder(BaseEncoder): ...@@ -66,6 +71,7 @@ class OpenAIEncoder(BaseEncoder):
super().__init__( super().__init__(
name=name, name=name,
score_threshold=set_score_threshold, score_threshold=set_score_threshold,
max_retries=max_retries,
) )
api_key = openai_api_key or os.getenv("OPENAI_API_KEY") api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
base_url = openai_base_url or os.getenv("OPENAI_BASE_URL") base_url = openai_base_url or os.getenv("OPENAI_BASE_URL")
...@@ -109,8 +115,9 @@ class OpenAIEncoder(BaseEncoder): ...@@ -109,8 +115,9 @@ class OpenAIEncoder(BaseEncoder):
docs = [self._truncate(doc) for doc in docs] docs = [self._truncate(doc) for doc in docs]
# Exponential backoff # Exponential backoff
for j in range(1, 7): for j in range(self.max_retries + 1):
try: try:
raise OpenAIError("Test")
embeds = self.client.embeddings.create( embeds = self.client.embeddings.create(
input=docs, input=docs,
model=self.name, model=self.name,
...@@ -119,12 +126,14 @@ class OpenAIEncoder(BaseEncoder): ...@@ -119,12 +126,14 @@ class OpenAIEncoder(BaseEncoder):
if embeds.data: if embeds.data:
break break
except OpenAIError as e: except OpenAIError as e:
sleep(2**j) logger.error("Exception occurred", exc_info=True)
error_message = str(e) if self.max_retries != 0:
logger.warning(f"Retrying in {2**j} seconds...") sleep(2**j)
logger.warning(f"Retrying in {2**j} seconds due to OpenAIError: {e}")
except Exception as e: except Exception as e:
logger.error(f"OpenAI API call failed. Error: {error_message}") logger.error(f"OpenAI API call failed. Error: {e}")
raise ValueError(f"OpenAI API call failed. Error: {e}") from e raise ValueError(f"OpenAI API call failed. Error: {str(e)}") from e
if ( if (
not embeds not embeds
...@@ -132,7 +141,7 @@ class OpenAIEncoder(BaseEncoder): ...@@ -132,7 +141,7 @@ class OpenAIEncoder(BaseEncoder):
or not embeds.data or not embeds.data
): ):
logger.info(f"Returned embeddings: {embeds}") logger.info(f"Returned embeddings: {embeds}")
raise ValueError(f"No embeddings returned. Error: {error_message}") raise ValueError(f"No embeddings returned.")
embeddings = [embeds_obj.embedding for embeds_obj in embeds.data] embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
return embeddings return embeddings
...@@ -161,8 +170,9 @@ class OpenAIEncoder(BaseEncoder): ...@@ -161,8 +170,9 @@ class OpenAIEncoder(BaseEncoder):
docs = [self._truncate(doc) for doc in docs] docs = [self._truncate(doc) for doc in docs]
# Exponential backoff # Exponential backoff
for j in range(1, 7): for j in range(self.max_retries + 1):
try: try:
raise OpenAIError("Test")
embeds = await self.async_client.embeddings.create( embeds = await self.async_client.embeddings.create(
input=docs, input=docs,
model=self.name, model=self.name,
...@@ -171,11 +181,12 @@ class OpenAIEncoder(BaseEncoder): ...@@ -171,11 +181,12 @@ class OpenAIEncoder(BaseEncoder):
if embeds.data: if embeds.data:
break break
except OpenAIError as e: except OpenAIError as e:
await asleep(2**j) logger.error("Exception occurred", exc_info=True)
error_message = str(e) if self.max_retries != 0:
logger.warning(f"Retrying in {2**j} seconds...") await asleep(2**j)
logger.warning(f"Retrying in {2**j} seconds due to OpenAIError: {e}")
except Exception as e: except Exception as e:
logger.error(f"OpenAI API call failed. Error: {error_message}") logger.error(f"OpenAI API call failed. Error: {e}")
raise ValueError(f"OpenAI API call failed. Error: {e}") from e raise ValueError(f"OpenAI API call failed. Error: {e}") from e
if ( if (
...@@ -184,7 +195,7 @@ class OpenAIEncoder(BaseEncoder): ...@@ -184,7 +195,7 @@ class OpenAIEncoder(BaseEncoder):
or not embeds.data or not embeds.data
): ):
logger.info(f"Returned embeddings: {embeds}") logger.info(f"Returned embeddings: {embeds}")
raise ValueError(f"No embeddings returned. Error: {error_message}") raise ValueError(f"No embeddings returned.")
embeddings = [embeds_obj.embedding for embeds_obj in embeds.data] embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
return embeddings return embeddings
...@@ -23,6 +23,7 @@ class AzureOpenAIEncoder(BaseEncoder): ...@@ -23,6 +23,7 @@ class AzureOpenAIEncoder(BaseEncoder):
azure_endpoint: Optional[str] = None azure_endpoint: Optional[str] = None
api_version: Optional[str] = None api_version: Optional[str] = None
model: Optional[str] = None model: Optional[str] = None
max_retries: int
def __init__( def __init__(
self, self,
...@@ -33,11 +34,15 @@ class AzureOpenAIEncoder(BaseEncoder): ...@@ -33,11 +34,15 @@ class AzureOpenAIEncoder(BaseEncoder):
model: Optional[str] = None, # TODO we should change to `name` JB model: Optional[str] = None, # TODO we should change to `name` JB
score_threshold: float = 0.82, score_threshold: float = 0.82,
dimensions: Union[int, NotGiven] = NotGiven(), dimensions: Union[int, NotGiven] = NotGiven(),
max_retries: int | None = None,
): ):
name = deployment_name name = deployment_name
if name is None: if name is None:
name = EncoderDefault.AZURE.value["embedding_model"] name = EncoderDefault.AZURE.value["embedding_model"]
super().__init__(name=name, score_threshold=score_threshold)
max_retries = max_retries if max_retries is not None else 3
super().__init__(name=name, score_threshold=score_threshold, max_retries=max_retries)
self.api_key = api_key self.api_key = api_key
self.deployment_name = deployment_name self.deployment_name = deployment_name
self.azure_endpoint = azure_endpoint self.azure_endpoint = azure_endpoint
...@@ -100,8 +105,9 @@ class AzureOpenAIEncoder(BaseEncoder): ...@@ -100,8 +105,9 @@ class AzureOpenAIEncoder(BaseEncoder):
error_message = "" error_message = ""
# Exponential backoff # Exponential backoff
for j in range(3): for j in range(self.max_retries + 1):
try: try:
raise OpenAIError("Test")
embeds = self.client.embeddings.create( embeds = self.client.embeddings.create(
input=docs, input=docs,
model=str(self.model), model=str(self.model),
...@@ -110,15 +116,12 @@ class AzureOpenAIEncoder(BaseEncoder): ...@@ -110,15 +116,12 @@ class AzureOpenAIEncoder(BaseEncoder):
if embeds.data: if embeds.data:
break break
except OpenAIError as e: except OpenAIError as e:
# print full traceback logger.error("Exception occurred", exc_info=True)
import traceback if self.max_retries != 0:
sleep(2**j)
traceback.print_exc() logger.warning(f"Retrying in {2**j} seconds due to OpenAIError: {e}")
sleep(2**j)
error_message = str(e)
logger.warning(f"Retrying in {2**j} seconds...")
except Exception as e: except Exception as e:
logger.error(f"Azure OpenAI API call failed. Error: {error_message}") logger.error(f"Azure OpenAI API call failed. Error: {e}")
raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e
if ( if (
...@@ -126,7 +129,7 @@ class AzureOpenAIEncoder(BaseEncoder): ...@@ -126,7 +129,7 @@ class AzureOpenAIEncoder(BaseEncoder):
or not isinstance(embeds, CreateEmbeddingResponse) or not isinstance(embeds, CreateEmbeddingResponse)
or not embeds.data or not embeds.data
): ):
raise ValueError(f"No embeddings returned. Error: {error_message}") raise ValueError(f"No embeddings returned.")
embeddings = [embeds_obj.embedding for embeds_obj in embeds.data] embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
return embeddings return embeddings
...@@ -138,8 +141,9 @@ class AzureOpenAIEncoder(BaseEncoder): ...@@ -138,8 +141,9 @@ class AzureOpenAIEncoder(BaseEncoder):
error_message = "" error_message = ""
# Exponential backoff # Exponential backoff
for j in range(3): for j in range(self.max_retries + 1):
try: try:
raise OpenAIError("Test")
embeds = await self.async_client.embeddings.create( embeds = await self.async_client.embeddings.create(
input=docs, input=docs,
model=str(self.model), model=str(self.model),
...@@ -147,16 +151,14 @@ class AzureOpenAIEncoder(BaseEncoder): ...@@ -147,16 +151,14 @@ class AzureOpenAIEncoder(BaseEncoder):
) )
if embeds.data: if embeds.data:
break break
except OpenAIError as e:
# print full traceback
import traceback
traceback.print_exc() except OpenAIError as e:
await asleep(2**j) logger.error("Exception occurred", exc_info=True)
error_message = str(e) if self.max_retries != 0:
logger.warning(f"Retrying in {2**j} seconds...") await asleep(2**j)
logger.warning(f"Retrying in {2**j} seconds due to OpenAIError: {e}")
except Exception as e: except Exception as e:
logger.error(f"Azure OpenAI API call failed. Error: {error_message}") logger.error(f"Azure OpenAI API call failed. Error: {e}")
raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e
if ( if (
...@@ -164,7 +166,7 @@ class AzureOpenAIEncoder(BaseEncoder): ...@@ -164,7 +166,7 @@ class AzureOpenAIEncoder(BaseEncoder):
or not isinstance(embeds, CreateEmbeddingResponse) or not isinstance(embeds, CreateEmbeddingResponse)
or not embeds.data or not embeds.data
): ):
raise ValueError(f"No embeddings returned. Error: {error_message}") raise ValueError(f"No embeddings returned.")
embeddings = [embeds_obj.embedding for embeds_obj in embeds.data] embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
return embeddings return embeddings
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment