From 78856bc365c7678b1defed95ed4ad35f29149eaf Mon Sep 17 00:00:00 2001
From: James Briggs <james.briggs@hotmail.com>
Date: Sun, 28 Apr 2024 16:21:00 +0800
Subject: [PATCH] feat: refactor for AutoEncoder

---
 semantic_router/encoders/__init__.py | 45 ++++++++++++++++++++++++++++
 semantic_router/encoders/openai.py   | 33 ++++++++------------
 semantic_router/encoders/zure.py     |  2 +-
 semantic_router/layer.py             | 10 +++----
 semantic_router/schema.py            | 13 +++++---
 tests/unit/encoders/test_openai.py   |  1 +
 tests/unit/test_schema.py            | 36 ----------------------
 7 files changed, 73 insertions(+), 67 deletions(-)

diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py
index a79fa605..5efc7303 100644
--- a/semantic_router/encoders/__init__.py
+++ b/semantic_router/encoders/__init__.py
@@ -1,3 +1,5 @@
+from typing import List, Optional
+
 from semantic_router.encoders.base import BaseEncoder
 from semantic_router.encoders.bm25 import BM25Encoder
 from semantic_router.encoders.clip import CLIPEncoder
@@ -11,6 +13,7 @@ from semantic_router.encoders.openai import OpenAIEncoder
 from semantic_router.encoders.tfidf import TfidfEncoder
 from semantic_router.encoders.vit import VitEncoder
 from semantic_router.encoders.zure import AzureOpenAIEncoder
+from semantic_router.schema import EncoderType
 
 __all__ = [
     "BaseEncoder",
@@ -27,3 +30,45 @@ __all__ = [
     "CLIPEncoder",
     "GoogleEncoder",
 ]
+
+
+class AutoEncoder:
+    type: EncoderType
+    name: Optional[str]
+    model: BaseEncoder
+
+    def __init__(self, type: str, name: Optional[str]):
+        self.type = EncoderType(type)
+        self.name = name
+        if self.type == EncoderType.AZURE:
+            # TODO should change `model` to `name` JB
+            self.model = AzureOpenAIEncoder(model=name)
+        elif self.type == EncoderType.COHERE:
+            self.model = CohereEncoder(name=name)
+        elif self.type == EncoderType.OPENAI:
+            self.model = OpenAIEncoder(name=name)
+        elif self.type == EncoderType.BM25:
+            if name is None:
+                name = "bm25"
+            self.model = BM25Encoder(name=name)
+        elif self.type == EncoderType.TFIDF:
+            if name is None:
+                name = "tfidf"
+            self.model = TfidfEncoder(name=name)
+        elif self.type == EncoderType.FASTEMBED:
+            self.model = FastEmbedEncoder(name=name)
+        elif self.type == EncoderType.HUGGINGFACE:
+            self.model = HuggingFaceEncoder(name=name)
+        elif self.type == EncoderType.MISTRAL:
+            self.model = MistralEncoder(name=name)
+        elif self.type == EncoderType.VIT:
+            self.model = VitEncoder(name=name)
+        elif self.type == EncoderType.CLIP:
+            self.model = CLIPEncoder(name=name)
+        elif self.type == EncoderType.GOOGLE:
+            self.model = GoogleEncoder(name=name)
+        else:
+            raise ValueError(f"Encoder type '{type}' not supported")
+
+    def __call__(self, texts: List[str]) -> List[List[float]]:
+        return self.model(texts)
diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py
index 14b9bb7b..d56a1e71 100644
--- a/semantic_router/encoders/openai.py
+++ b/semantic_router/encoders/openai.py
@@ -1,6 +1,7 @@
 import os
 from time import sleep
 from typing import Any, List, Optional, Union
+from pydantic.v1 import PrivateAttr
 
 import openai
 from openai import OpenAIError
@@ -16,28 +17,18 @@ from semantic_router.utils.logger import logger
 
 model_configs = {
     "text-embedding-ada-002": EncoderInfo(
-        name="text-embedding-ada-002",
-        type="openai",
-        token_limit=4000
+        name="text-embedding-ada-002", token_limit=8192
     ),
-    "text-embed-3-small": EncoderInfo(
-        name="text-embed-3-small",
-        type="openai",
-        token_limit=8192
-    ),
-    "text-embed-3-large": EncoderInfo(
-        name="text-embed-3-large",
-        type="openai",
-        token_limit=8192
-    )
+    "text-embed-3-small": EncoderInfo(name="text-embed-3-small", token_limit=8192),
+    "text-embed-3-large": EncoderInfo(name="text-embed-3-large", token_limit=8192),
 }
 
 
 class OpenAIEncoder(BaseEncoder):
     client: Optional[openai.Client]
     dimensions: Union[int, NotGiven] = NotGiven()
-    token_limit: Optional[int] = None
-    token_encoder: Optional[Any] = None
+    token_limit: int = 8192  # default value, should be replaced by config
+    _token_encoder: Any = PrivateAttr()
     type: str = "openai"
 
     def __init__(
@@ -71,11 +62,11 @@ class OpenAIEncoder(BaseEncoder):
         if name in model_configs:
             self.token_limit = model_configs[name].token_limit
         # get token encoder
-        self.token_encoder = tiktoken.encoding_for_model(name)
+        self._token_encoder = tiktoken.encoding_for_model(name)
 
     def __call__(self, docs: List[str], truncate: bool = True) -> List[List[float]]:
         """Encode a list of text documents into embeddings using OpenAI API.
-        
+
         :param docs: List of text documents to encode.
         :param truncate: Whether to truncate the documents to token limit. If
             False and a document exceeds the token limit, an error will be
@@ -121,15 +112,15 @@ class OpenAIEncoder(BaseEncoder):
 
         embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
         return embeddings
-    
+
     def _truncate(self, text: str) -> str:
-        tokens = self.token_encoder.encode(text)
+        tokens = self._token_encoder.encode(text)
         if len(tokens) > self.token_limit:
             logger.warning(
                 f"Document exceeds token limit: {len(tokens)} > {self.token_limit}"
                 "\nTruncating document..."
             )
-            text = self.token_encoder.decode(tokens[:self.token_limit-1])
-            logger.info(f"Trunc length: {len(self.token_encoder.encode(text))}")
+            text = self._token_encoder.decode(tokens[: self.token_limit - 1])
+            logger.info(f"Trunc length: {len(self._token_encoder.encode(text))}")
             return text
         return text
diff --git a/semantic_router/encoders/zure.py b/semantic_router/encoders/zure.py
index ee1b1fa6..df2bf858 100644
--- a/semantic_router/encoders/zure.py
+++ b/semantic_router/encoders/zure.py
@@ -26,7 +26,7 @@ class AzureOpenAIEncoder(BaseEncoder):
         deployment_name: Optional[str] = None,
         azure_endpoint: Optional[str] = None,
         api_version: Optional[str] = None,
-        model: Optional[str] = None,
+        model: Optional[str] = None,  # TODO we should change to `name` JB
         score_threshold: float = 0.82,
     ):
         name = deployment_name
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index d9781820..02e626f4 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -8,12 +8,12 @@ import numpy as np
 import yaml  # type: ignore
 from tqdm.auto import tqdm
 
-from semantic_router.encoders import BaseEncoder, OpenAIEncoder
+from semantic_router.encoders import AutoEncoder, BaseEncoder, OpenAIEncoder
 from semantic_router.index.base import BaseIndex
 from semantic_router.index.local import LocalIndex
 from semantic_router.llms import BaseLLM, OpenAILLM
 from semantic_router.route import Route
-from semantic_router.schema import Encoder, EncoderType, RouteChoice
+from semantic_router.schema import EncoderType, RouteChoice
 from semantic_router.utils.defaults import EncoderDefault
 from semantic_router.utils.logger import logger
 
@@ -337,18 +337,18 @@ class RouteLayer:
     @classmethod
     def from_json(cls, file_path: str):
         config = LayerConfig.from_file(file_path)
-        encoder = Encoder(type=config.encoder_type, name=config.encoder_name).model
+        encoder = AutoEncoder(type=config.encoder_type, name=config.encoder_name).model
         return cls(encoder=encoder, routes=config.routes)
 
     @classmethod
     def from_yaml(cls, file_path: str):
         config = LayerConfig.from_file(file_path)
-        encoder = Encoder(type=config.encoder_type, name=config.encoder_name).model
+        encoder = AutoEncoder(type=config.encoder_type, name=config.encoder_name).model
         return cls(encoder=encoder, routes=config.routes)
 
     @classmethod
     def from_config(cls, config: LayerConfig, index: Optional[BaseIndex] = None):
-        encoder = Encoder(type=config.encoder_type, name=config.encoder_name).model
+        encoder = AutoEncoder(type=config.encoder_type, name=config.encoder_name).model
         return cls(encoder=encoder, routes=config.routes, index=index)
 
     def add(self, route: Route):
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index daf60813..60f61536 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -5,19 +5,24 @@ from pydantic.v1 import BaseModel
 
 
 class EncoderType(Enum):
-    HUGGINGFACE = "huggingface"
-    FASTEMBED = "fastembed"
-    OPENAI = "openai"
+    AZURE = "azure"
     COHERE = "cohere"
+    OPENAI = "openai"
+    BM25 = "bm25"
+    TFIDF = "tfidf"
+    FASTEMBED = "fastembed"
+    HUGGINGFACE = "huggingface"
     MISTRAL = "mistral"
+    VIT = "vit"
+    CLIP = "clip"
     GOOGLE = "google"
 
 
 class EncoderInfo(BaseModel):
     name: str
-    type: EncoderType
     token_limit: int
 
+
 class RouteChoice(BaseModel):
     name: Optional[str] = None
     function_call: Optional[dict] = None
diff --git a/tests/unit/encoders/test_openai.py b/tests/unit/encoders/test_openai.py
index de3594f1..508e9e9e 100644
--- a/tests/unit/encoders/test_openai.py
+++ b/tests/unit/encoders/test_openai.py
@@ -5,6 +5,7 @@ from openai.types.create_embedding_response import Usage
 
 from semantic_router.encoders import OpenAIEncoder
 
+
 @pytest.fixture
 def openai_encoder(mocker):
     mocker.patch("openai.Client")
diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py
index 90f3c949..d0fce781 100644
--- a/tests/unit/test_schema.py
+++ b/tests/unit/test_schema.py
@@ -2,46 +2,10 @@ import pytest
 from pydantic.v1 import ValidationError
 
 from semantic_router.schema import (
-    CohereEncoder,
-    Encoder,
-    EncoderType,
     Message,
-    OpenAIEncoder,
 )
 
 
-class TestEncoderDataclass:
-    def test_encoder_initialization_openai(self, mocker):
-        mocker.patch.dict("os.environ", {"OPENAI_API_KEY": "test"})
-        encoder = Encoder(type="openai", name="test-engine")
-        assert encoder.type == EncoderType.OPENAI
-        assert isinstance(encoder.model, OpenAIEncoder)
-
-    def test_encoder_initialization_cohere(self, mocker):
-        mocker.patch.dict("os.environ", {"COHERE_API_KEY": "test"})
-        encoder = Encoder(type="cohere", name="test-engine")
-        assert encoder.type == EncoderType.COHERE
-        assert isinstance(encoder.model, CohereEncoder)
-
-    def test_encoder_initialization_unsupported_type(self):
-        with pytest.raises(ValueError):
-            Encoder(type="unsupported", name="test-engine")
-
-    def test_encoder_initialization_huggingface(self):
-        with pytest.raises(NotImplementedError):
-            Encoder(type="huggingface", name="test-engine")
-
-    def test_encoder_call_method(self, mocker):
-        mocker.patch.dict("os.environ", {"OPENAI_API_KEY": "test"})
-        mocker.patch(
-            "semantic_router.encoders.openai.OpenAIEncoder.__call__",
-            return_value=[0.1, 0.2, 0.3],
-        )
-        encoder = Encoder(type="openai", name="test-engine")
-        result = encoder(["test"])
-        assert result == [0.1, 0.2, 0.3]
-
-
 class TestMessageDataclass:
     def test_message_creation(self):
         message = Message(role="user", content="Hello!")
-- 
GitLab