Skip to content
Snippets Groups Projects
Commit 57d748dd authored by Ismail Ashraq's avatar Ismail Ashraq
Browse files

linting

parent 0483b467
No related branches found
No related tags found
No related merge requests found
...@@ -42,7 +42,7 @@ class OpenAIEncoder(BaseEncoder): ...@@ -42,7 +42,7 @@ class OpenAIEncoder(BaseEncoder):
token_limit: int = 8192 # default value, should be replaced by config token_limit: int = 8192 # default value, should be replaced by config
_token_encoder: Any = PrivateAttr() _token_encoder: Any = PrivateAttr()
type: str = "openai" type: str = "openai"
max_retries: int max_retries: int = 3
def __init__( def __init__(
self, self,
...@@ -56,9 +56,6 @@ class OpenAIEncoder(BaseEncoder): ...@@ -56,9 +56,6 @@ class OpenAIEncoder(BaseEncoder):
): ):
if name is None: if name is None:
name = EncoderDefault.OPENAI.value["embedding_model"] name = EncoderDefault.OPENAI.value["embedding_model"]
max_retries = max_retries if max_retries is not None else 3
if score_threshold is None and name in model_configs: if score_threshold is None and name in model_configs:
set_score_threshold = model_configs[name].threshold set_score_threshold = model_configs[name].threshold
elif score_threshold is None: elif score_threshold is None:
...@@ -71,13 +68,14 @@ class OpenAIEncoder(BaseEncoder): ...@@ -71,13 +68,14 @@ class OpenAIEncoder(BaseEncoder):
super().__init__( super().__init__(
name=name, name=name,
score_threshold=set_score_threshold, score_threshold=set_score_threshold,
max_retries=max_retries,
) )
api_key = openai_api_key or os.getenv("OPENAI_API_KEY") api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
base_url = openai_base_url or os.getenv("OPENAI_BASE_URL") base_url = openai_base_url or os.getenv("OPENAI_BASE_URL")
openai_org_id = openai_org_id or os.getenv("OPENAI_ORG_ID") openai_org_id = openai_org_id or os.getenv("OPENAI_ORG_ID")
if api_key is None: if api_key is None:
raise ValueError("OpenAI API key cannot be 'None'.") raise ValueError("OpenAI API key cannot be 'None'.")
if max_retries is not None:
self.max_retries = max_retries
try: try:
self.client = openai.Client( self.client = openai.Client(
base_url=base_url, api_key=api_key, organization=openai_org_id base_url=base_url, api_key=api_key, organization=openai_org_id
...@@ -108,7 +106,6 @@ class OpenAIEncoder(BaseEncoder): ...@@ -108,7 +106,6 @@ class OpenAIEncoder(BaseEncoder):
if self.client is None: if self.client is None:
raise ValueError("OpenAI client is not initialized.") raise ValueError("OpenAI client is not initialized.")
embeds = None embeds = None
error_message = ""
if truncate: if truncate:
# check if any document exceeds token limit and truncate if so # check if any document exceeds token limit and truncate if so
...@@ -129,7 +126,9 @@ class OpenAIEncoder(BaseEncoder): ...@@ -129,7 +126,9 @@ class OpenAIEncoder(BaseEncoder):
logger.error("Exception occurred", exc_info=True) logger.error("Exception occurred", exc_info=True)
if self.max_retries != 0: if self.max_retries != 0:
sleep(2**j) sleep(2**j)
logger.warning(f"Retrying in {2**j} seconds due to OpenAIError: {e}") logger.warning(
f"Retrying in {2**j} seconds due to OpenAIError: {e}"
)
except Exception as e: except Exception as e:
logger.error(f"OpenAI API call failed. Error: {e}") logger.error(f"OpenAI API call failed. Error: {e}")
...@@ -141,7 +140,7 @@ class OpenAIEncoder(BaseEncoder): ...@@ -141,7 +140,7 @@ class OpenAIEncoder(BaseEncoder):
or not embeds.data or not embeds.data
): ):
logger.info(f"Returned embeddings: {embeds}") logger.info(f"Returned embeddings: {embeds}")
raise ValueError(f"No embeddings returned.") raise ValueError("No embeddings returned.")
embeddings = [embeds_obj.embedding for embeds_obj in embeds.data] embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
return embeddings return embeddings
...@@ -163,7 +162,6 @@ class OpenAIEncoder(BaseEncoder): ...@@ -163,7 +162,6 @@ class OpenAIEncoder(BaseEncoder):
if self.async_client is None: if self.async_client is None:
raise ValueError("OpenAI async client is not initialized.") raise ValueError("OpenAI async client is not initialized.")
embeds = None embeds = None
error_message = ""
if truncate: if truncate:
# check if any document exceeds token limit and truncate if so # check if any document exceeds token limit and truncate if so
...@@ -184,7 +182,9 @@ class OpenAIEncoder(BaseEncoder): ...@@ -184,7 +182,9 @@ class OpenAIEncoder(BaseEncoder):
logger.error("Exception occurred", exc_info=True) logger.error("Exception occurred", exc_info=True)
if self.max_retries != 0: if self.max_retries != 0:
await asleep(2**j) await asleep(2**j)
logger.warning(f"Retrying in {2**j} seconds due to OpenAIError: {e}") logger.warning(
f"Retrying in {2**j} seconds due to OpenAIError: {e}"
)
except Exception as e: except Exception as e:
logger.error(f"OpenAI API call failed. Error: {e}") logger.error(f"OpenAI API call failed. Error: {e}")
raise ValueError(f"OpenAI API call failed. Error: {e}") from e raise ValueError(f"OpenAI API call failed. Error: {e}") from e
...@@ -195,7 +195,7 @@ class OpenAIEncoder(BaseEncoder): ...@@ -195,7 +195,7 @@ class OpenAIEncoder(BaseEncoder):
or not embeds.data or not embeds.data
): ):
logger.info(f"Returned embeddings: {embeds}") logger.info(f"Returned embeddings: {embeds}")
raise ValueError(f"No embeddings returned.") raise ValueError("No embeddings returned.")
embeddings = [embeds_obj.embedding for embeds_obj in embeds.data] embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
return embeddings return embeddings
...@@ -23,7 +23,7 @@ class AzureOpenAIEncoder(BaseEncoder): ...@@ -23,7 +23,7 @@ class AzureOpenAIEncoder(BaseEncoder):
azure_endpoint: Optional[str] = None azure_endpoint: Optional[str] = None
api_version: Optional[str] = None api_version: Optional[str] = None
model: Optional[str] = None model: Optional[str] = None
max_retries: int max_retries: int = 3
def __init__( def __init__(
self, self,
...@@ -39,10 +39,7 @@ class AzureOpenAIEncoder(BaseEncoder): ...@@ -39,10 +39,7 @@ class AzureOpenAIEncoder(BaseEncoder):
name = deployment_name name = deployment_name
if name is None: if name is None:
name = EncoderDefault.AZURE.value["embedding_model"] name = EncoderDefault.AZURE.value["embedding_model"]
super().__init__(name=name, score_threshold=score_threshold)
max_retries = max_retries if max_retries is not None else 3
super().__init__(name=name, score_threshold=score_threshold, max_retries=max_retries)
self.api_key = api_key self.api_key = api_key
self.deployment_name = deployment_name self.deployment_name = deployment_name
self.azure_endpoint = azure_endpoint self.azure_endpoint = azure_endpoint
...@@ -54,6 +51,8 @@ class AzureOpenAIEncoder(BaseEncoder): ...@@ -54,6 +51,8 @@ class AzureOpenAIEncoder(BaseEncoder):
self.api_key = os.getenv("AZURE_OPENAI_API_KEY") self.api_key = os.getenv("AZURE_OPENAI_API_KEY")
if self.api_key is None: if self.api_key is None:
raise ValueError("No Azure OpenAI API key provided.") raise ValueError("No Azure OpenAI API key provided.")
if max_retries is not None:
self.max_retries = max_retries
if self.deployment_name is None: if self.deployment_name is None:
self.deployment_name = EncoderDefault.AZURE.value["deployment_name"] self.deployment_name = EncoderDefault.AZURE.value["deployment_name"]
# deployment_name may still be None, but it is optional in the API # deployment_name may still be None, but it is optional in the API
...@@ -102,7 +101,6 @@ class AzureOpenAIEncoder(BaseEncoder): ...@@ -102,7 +101,6 @@ class AzureOpenAIEncoder(BaseEncoder):
if self.client is None: if self.client is None:
raise ValueError("Azure OpenAI client is not initialized.") raise ValueError("Azure OpenAI client is not initialized.")
embeds = None embeds = None
error_message = ""
# Exponential backoff # Exponential backoff
for j in range(self.max_retries + 1): for j in range(self.max_retries + 1):
...@@ -119,7 +117,9 @@ class AzureOpenAIEncoder(BaseEncoder): ...@@ -119,7 +117,9 @@ class AzureOpenAIEncoder(BaseEncoder):
logger.error("Exception occurred", exc_info=True) logger.error("Exception occurred", exc_info=True)
if self.max_retries != 0: if self.max_retries != 0:
sleep(2**j) sleep(2**j)
logger.warning(f"Retrying in {2**j} seconds due to OpenAIError: {e}") logger.warning(
f"Retrying in {2**j} seconds due to OpenAIError: {e}"
)
except Exception as e: except Exception as e:
logger.error(f"Azure OpenAI API call failed. Error: {e}") logger.error(f"Azure OpenAI API call failed. Error: {e}")
raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e
...@@ -129,7 +129,7 @@ class AzureOpenAIEncoder(BaseEncoder): ...@@ -129,7 +129,7 @@ class AzureOpenAIEncoder(BaseEncoder):
or not isinstance(embeds, CreateEmbeddingResponse) or not isinstance(embeds, CreateEmbeddingResponse)
or not embeds.data or not embeds.data
): ):
raise ValueError(f"No embeddings returned.") raise ValueError("No embeddings returned.")
embeddings = [embeds_obj.embedding for embeds_obj in embeds.data] embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
return embeddings return embeddings
...@@ -138,7 +138,6 @@ class AzureOpenAIEncoder(BaseEncoder): ...@@ -138,7 +138,6 @@ class AzureOpenAIEncoder(BaseEncoder):
if self.async_client is None: if self.async_client is None:
raise ValueError("Azure OpenAI async client is not initialized.") raise ValueError("Azure OpenAI async client is not initialized.")
embeds = None embeds = None
error_message = ""
# Exponential backoff # Exponential backoff
for j in range(self.max_retries + 1): for j in range(self.max_retries + 1):
...@@ -156,7 +155,9 @@ class AzureOpenAIEncoder(BaseEncoder): ...@@ -156,7 +155,9 @@ class AzureOpenAIEncoder(BaseEncoder):
logger.error("Exception occurred", exc_info=True) logger.error("Exception occurred", exc_info=True)
if self.max_retries != 0: if self.max_retries != 0:
await asleep(2**j) await asleep(2**j)
logger.warning(f"Retrying in {2**j} seconds due to OpenAIError: {e}") logger.warning(
f"Retrying in {2**j} seconds due to OpenAIError: {e}"
)
except Exception as e: except Exception as e:
logger.error(f"Azure OpenAI API call failed. Error: {e}") logger.error(f"Azure OpenAI API call failed. Error: {e}")
raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e
...@@ -166,7 +167,7 @@ class AzureOpenAIEncoder(BaseEncoder): ...@@ -166,7 +167,7 @@ class AzureOpenAIEncoder(BaseEncoder):
or not isinstance(embeds, CreateEmbeddingResponse) or not isinstance(embeds, CreateEmbeddingResponse)
or not embeds.data or not embeds.data
): ):
raise ValueError(f"No embeddings returned.") raise ValueError("No embeddings returned.")
embeddings = [embeds_obj.embedding for embeds_obj in embeds.data] embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
return embeddings return embeddings
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment