From 57d748dde82b3ceaae0f16b9a1b1cdaf4e88e606 Mon Sep 17 00:00:00 2001
From: Ismail Ashraq <ismailashraq@Ismails-MacBook-Pro.local>
Date: Sun, 18 Aug 2024 16:14:17 +0500
Subject: [PATCH] linting

---
 semantic_router/encoders/openai.py | 22 +++++++++++-----------
 semantic_router/encoders/zure.py   | 23 ++++++++++++-----------
 2 files changed, 23 insertions(+), 22 deletions(-)

diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py
index 0425d5df..8ba7ee5b 100644
--- a/semantic_router/encoders/openai.py
+++ b/semantic_router/encoders/openai.py
@@ -42,7 +42,7 @@ class OpenAIEncoder(BaseEncoder):
     token_limit: int = 8192  # default value, should be replaced by config
     _token_encoder: Any = PrivateAttr()
     type: str = "openai"
-    max_retries: int
+    max_retries: int = 3
 
     def __init__(
         self,
@@ -56,9 +56,6 @@ class OpenAIEncoder(BaseEncoder):
     ):
         if name is None:
             name = EncoderDefault.OPENAI.value["embedding_model"]
-
-        max_retries = max_retries if max_retries is not None else 3
-
         if score_threshold is None and name in model_configs:
             set_score_threshold = model_configs[name].threshold
         elif score_threshold is None:
@@ -71,13 +68,14 @@ class OpenAIEncoder(BaseEncoder):
         super().__init__(
             name=name,
             score_threshold=set_score_threshold,
-            max_retries=max_retries,
         )
         api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
         base_url = openai_base_url or os.getenv("OPENAI_BASE_URL")
         openai_org_id = openai_org_id or os.getenv("OPENAI_ORG_ID")
         if api_key is None:
             raise ValueError("OpenAI API key cannot be 'None'.")
+        if max_retries is not None:
+            self.max_retries = max_retries
         try:
             self.client = openai.Client(
                 base_url=base_url, api_key=api_key, organization=openai_org_id
@@ -108,7 +106,6 @@ class OpenAIEncoder(BaseEncoder):
         if self.client is None:
             raise ValueError("OpenAI client is not initialized.")
         embeds = None
-        error_message = ""
 
         if truncate:
             # check if any document exceeds token limit and truncate if so
@@ -129,7 +126,9 @@ class OpenAIEncoder(BaseEncoder):
                 logger.error("Exception occurred", exc_info=True)
                 if self.max_retries != 0:
                     sleep(2**j)
-                    logger.warning(f"Retrying in {2**j} seconds due to OpenAIError: {e}")
+                    logger.warning(
+                        f"Retrying in {2**j} seconds due to OpenAIError: {e}"
+                    )
 
             except Exception as e:
                 logger.error(f"OpenAI API call failed. Error: {e}")
@@ -141,7 +140,7 @@ class OpenAIEncoder(BaseEncoder):
             or not embeds.data
         ):
             logger.info(f"Returned embeddings: {embeds}")
-            raise ValueError(f"No embeddings returned.")
+            raise ValueError("No embeddings returned.")
 
         embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
         return embeddings
@@ -163,7 +162,6 @@ class OpenAIEncoder(BaseEncoder):
         if self.async_client is None:
             raise ValueError("OpenAI async client is not initialized.")
         embeds = None
-        error_message = ""
 
         if truncate:
             # check if any document exceeds token limit and truncate if so
@@ -184,7 +182,9 @@ class OpenAIEncoder(BaseEncoder):
                 logger.error("Exception occurred", exc_info=True)
                 if self.max_retries != 0:
                     await asleep(2**j)
-                    logger.warning(f"Retrying in {2**j} seconds due to OpenAIError: {e}")
+                    logger.warning(
+                        f"Retrying in {2**j} seconds due to OpenAIError: {e}"
+                    )
             except Exception as e:
                 logger.error(f"OpenAI API call failed. Error: {e}")
                 raise ValueError(f"OpenAI API call failed. Error: {e}") from e
@@ -195,7 +195,7 @@ class OpenAIEncoder(BaseEncoder):
             or not embeds.data
         ):
             logger.info(f"Returned embeddings: {embeds}")
-            raise ValueError(f"No embeddings returned.")
+            raise ValueError("No embeddings returned.")
 
         embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
         return embeddings
diff --git a/semantic_router/encoders/zure.py b/semantic_router/encoders/zure.py
index 3c199692..e04b55bf 100644
--- a/semantic_router/encoders/zure.py
+++ b/semantic_router/encoders/zure.py
@@ -23,7 +23,7 @@ class AzureOpenAIEncoder(BaseEncoder):
     azure_endpoint: Optional[str] = None
     api_version: Optional[str] = None
     model: Optional[str] = None
-    max_retries: int
+    max_retries: int = 3
 
     def __init__(
         self,
@@ -39,10 +39,7 @@ class AzureOpenAIEncoder(BaseEncoder):
         name = deployment_name
         if name is None:
             name = EncoderDefault.AZURE.value["embedding_model"]
-
-        max_retries = max_retries if max_retries is not None else 3
-        
-        super().__init__(name=name, score_threshold=score_threshold, max_retries=max_retries)
+        super().__init__(name=name, score_threshold=score_threshold)
         self.api_key = api_key
         self.deployment_name = deployment_name
         self.azure_endpoint = azure_endpoint
@@ -54,6 +51,8 @@ class AzureOpenAIEncoder(BaseEncoder):
             self.api_key = os.getenv("AZURE_OPENAI_API_KEY")
             if self.api_key is None:
                 raise ValueError("No Azure OpenAI API key provided.")
+        if max_retries is not None:
+            self.max_retries = max_retries
         if self.deployment_name is None:
             self.deployment_name = EncoderDefault.AZURE.value["deployment_name"]
         # deployment_name may still be None, but it is optional in the API
@@ -102,7 +101,6 @@ class AzureOpenAIEncoder(BaseEncoder):
         if self.client is None:
             raise ValueError("Azure OpenAI client is not initialized.")
         embeds = None
-        error_message = ""
 
         # Exponential backoff
         for j in range(self.max_retries + 1):
@@ -119,7 +117,9 @@ class AzureOpenAIEncoder(BaseEncoder):
                 logger.error("Exception occurred", exc_info=True)
                 if self.max_retries != 0:
                     sleep(2**j)
-                    logger.warning(f"Retrying in {2**j} seconds due to OpenAIError: {e}")
+                    logger.warning(
+                        f"Retrying in {2**j} seconds due to OpenAIError: {e}"
+                    )
             except Exception as e:
                 logger.error(f"Azure OpenAI API call failed. Error: {e}")
                 raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e
@@ -129,7 +129,7 @@ class AzureOpenAIEncoder(BaseEncoder):
             or not isinstance(embeds, CreateEmbeddingResponse)
             or not embeds.data
         ):
-            raise ValueError(f"No embeddings returned.")
+            raise ValueError("No embeddings returned.")
 
         embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
         return embeddings
@@ -138,7 +138,6 @@ class AzureOpenAIEncoder(BaseEncoder):
         if self.async_client is None:
             raise ValueError("Azure OpenAI async client is not initialized.")
         embeds = None
-        error_message = ""
 
         # Exponential backoff
         for j in range(self.max_retries + 1):
@@ -156,7 +155,9 @@ class AzureOpenAIEncoder(BaseEncoder):
                 logger.error("Exception occurred", exc_info=True)
                 if self.max_retries != 0:
                     await asleep(2**j)
-                    logger.warning(f"Retrying in {2**j} seconds due to OpenAIError: {e}")
+                    logger.warning(
+                        f"Retrying in {2**j} seconds due to OpenAIError: {e}"
+                    )
             except Exception as e:
                 logger.error(f"Azure OpenAI API call failed. Error: {e}")
                 raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e
@@ -166,7 +167,7 @@ class AzureOpenAIEncoder(BaseEncoder):
             or not isinstance(embeds, CreateEmbeddingResponse)
             or not embeds.data
         ):
-            raise ValueError(f"No embeddings returned.")
+            raise ValueError("No embeddings returned.")
 
         embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
         return embeddings
-- 
GitLab