From 26ee2e7218881df546148ad93ea794faaac43fd5 Mon Sep 17 00:00:00 2001
From: James Briggs <james.briggs@hotmail.com>
Date: Wed, 12 Jun 2024 16:42:25 +0800
Subject: [PATCH] feat: add async to azure and oai encoders

---
 semantic_router/encoders/base.py   |  4 +++
 semantic_router/encoders/openai.py | 50 ++++++++++++++++++++++++++++-
 semantic_router/encoders/zure.py   | 51 ++++++++++++++++++++++++++++--
 3 files changed, 102 insertions(+), 3 deletions(-)

diff --git a/semantic_router/encoders/base.py b/semantic_router/encoders/base.py
index b1e1311e..70a3f6ee 100644
--- a/semantic_router/encoders/base.py
+++ b/semantic_router/encoders/base.py
@@ -7,9 +7,13 @@ class BaseEncoder(BaseModel):
     name: str
     score_threshold: float
     type: str = Field(default="base")
+    is_async: bool = Field(default=False)
 
     class Config:
         arbitrary_types_allowed = True
 
     def __call__(self, docs: List[Any]) -> List[List[float]]:
         raise NotImplementedError("Subclasses must implement this method")
+    
+    def acall(self, docs: List[Any]) -> List[List[float]]:
+        raise NotImplementedError("Subclasses must implement this method")
diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py
index bd68727f..96303155 100644
--- a/semantic_router/encoders/openai.py
+++ b/semantic_router/encoders/openai.py
@@ -1,3 +1,4 @@
+from asyncio import sleep as asleep
 import os
 from time import sleep
 from typing import Any, List, Optional, Union
@@ -30,6 +31,7 @@ model_configs = {
 
 class OpenAIEncoder(BaseEncoder):
     client: Optional[openai.Client]
+    async_client: Optional[openai.AsyncClient]
     dimensions: Union[int, NotGiven] = NotGiven()
     token_limit: int = 8192  # default value, should be replaced by config
     _token_encoder: Any = PrivateAttr()
@@ -46,7 +48,10 @@ class OpenAIEncoder(BaseEncoder):
     ):
         if name is None:
             name = EncoderDefault.OPENAI.value["embedding_model"]
-        super().__init__(name=name, score_threshold=score_threshold)
+        super().__init__(
+            name=name,
+            score_threshold=score_threshold,
+        )
         api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
         base_url = openai_base_url or os.getenv("OPENAI_BASE_URL")
         openai_org_id = openai_org_id or os.getenv("OPENAI_ORG_ID")
@@ -56,6 +61,9 @@ class OpenAIEncoder(BaseEncoder):
             self.client = openai.Client(
                 base_url=base_url, api_key=api_key, organization=openai_org_id
             )
+            self.async_client = openai.AsyncClient(
+                base_url=base_url, api_key=api_key, organization=openai_org_id
+            )
         except Exception as e:
             raise ValueError(
                 f"OpenAI API client failed to initialize. Error: {e}"
@@ -126,3 +134,43 @@ class OpenAIEncoder(BaseEncoder):
             logger.info(f"Trunc length: {len(self._token_encoder.encode(text))}")
             return text
         return text
+
+    async def acall(self, docs: List[str], truncate: bool = True) -> List[List[float]]:
+        if self.async_client is None:
+            raise ValueError("OpenAI async client is not initialized.")
+        embeds = None
+        error_message = ""
+
+        if truncate:
+            # check if any document exceeds token limit and truncate if so
+            docs = [self._truncate(doc) for doc in docs]
+
+        # Exponential backoff
+        for j in range(1, 7):
+            try:
+                embeds = await self.async_client.embeddings.create(
+                    input=docs,
+                    model=self.name,
+                    dimensions=self.dimensions,
+                )
+                if embeds.data:
+                    break
+            except OpenAIError as e:
+                await asleep(2**j)
+                error_message = str(e)
+                logger.warning(f"Retrying in {2**j} seconds...")
+            except Exception as e:
+                logger.error(f"OpenAI API call failed. Error: {error_message}")
+                raise ValueError(f"OpenAI API call failed. Error: {e}") from e
+
+        if (
+            not embeds
+            or not isinstance(embeds, CreateEmbeddingResponse)
+            or not embeds.data
+        ):
+            logger.info(f"Returned embeddings: {embeds}")
+            raise ValueError(f"No embeddings returned. Error: {error_message}")
+
+        embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
+        return embeddings
+
diff --git a/semantic_router/encoders/zure.py b/semantic_router/encoders/zure.py
index 55b2a40b..6968583a 100644
--- a/semantic_router/encoders/zure.py
+++ b/semantic_router/encoders/zure.py
@@ -1,3 +1,4 @@
+from asyncio import sleep as asleep
 import os
 from time import sleep
 from typing import List, Optional, Union
@@ -14,6 +15,7 @@ from semantic_router.utils.logger import logger
 
 class AzureOpenAIEncoder(BaseEncoder):
     client: Optional[openai.AzureOpenAI] = None
+    async_client: Optional[openai.AsyncAzureOpenAI] = None
     dimensions: Union[int, NotGiven] = NotGiven()
     type: str = "azure"
     api_key: Optional[str] = None
@@ -77,7 +79,14 @@ class AzureOpenAIEncoder(BaseEncoder):
                 api_key=str(self.api_key),
                 azure_endpoint=str(self.azure_endpoint),
                 api_version=str(self.api_version),
-                # _strict_response_validation=True,
+            )
+            self.async_client = openai.AsyncAzureOpenAI(
+                azure_deployment=(
+                    str(self.deployment_name) if self.deployment_name else None
+                ),
+                api_key=str(self.api_key),
+                azure_endpoint=str(self.azure_endpoint),
+                api_version=str(self.api_version),
             )
         except Exception as e:
             raise ValueError(
@@ -86,7 +95,7 @@ class AzureOpenAIEncoder(BaseEncoder):
 
     def __call__(self, docs: List[str]) -> List[List[float]]:
         if self.client is None:
-            raise ValueError("OpenAI client is not initialized.")
+            raise ValueError("Azure OpenAI client is not initialized.")
         embeds = None
         error_message = ""
 
@@ -121,3 +130,41 @@ class AzureOpenAIEncoder(BaseEncoder):
 
         embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
         return embeddings
+    
+    async def acall(self, docs: List[str]) -> List[List[float]]:
+        if self.async_client is None:
+            raise ValueError("Azure OpenAI async client is not initialized.")
+        embeds = None
+        error_message = ""
+
+        # Exponential backoff
+        for j in range(3):
+            try:
+                embeds = await self.async_client.embeddings.create(
+                    input=docs,
+                    model=str(self.model),
+                    dimensions=self.dimensions,
+                )
+                if embeds.data:
+                    break
+            except OpenAIError as e:
+                # print full traceback
+                import traceback
+
+                traceback.print_exc()
+                await asleep(2**j)
+                error_message = str(e)
+                logger.warning(f"Retrying in {2**j} seconds...")
+            except Exception as e:
+                logger.error(f"Azure OpenAI API call failed. Error: {error_message}")
+                raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e
+
+        if (
+            not embeds
+            or not isinstance(embeds, CreateEmbeddingResponse)
+            or not embeds.data
+        ):
+            raise ValueError(f"No embeddings returned. Error: {error_message}")
+
+        embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
+        return embeddings
-- 
GitLab