diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py index a79fa60595315943591ad5ec4eb2a34e39a8c50c..5efc730398a45fc3a9de5f234a6d43a0e37911be 100644 --- a/semantic_router/encoders/__init__.py +++ b/semantic_router/encoders/__init__.py @@ -1,3 +1,5 @@ +from typing import List, Optional + from semantic_router.encoders.base import BaseEncoder from semantic_router.encoders.bm25 import BM25Encoder from semantic_router.encoders.clip import CLIPEncoder @@ -11,6 +13,7 @@ from semantic_router.encoders.openai import OpenAIEncoder from semantic_router.encoders.tfidf import TfidfEncoder from semantic_router.encoders.vit import VitEncoder from semantic_router.encoders.zure import AzureOpenAIEncoder +from semantic_router.schema import EncoderType __all__ = [ "BaseEncoder", @@ -27,3 +30,45 @@ __all__ = [ "CLIPEncoder", "GoogleEncoder", ] + + +class AutoEncoder: + type: EncoderType + name: Optional[str] + model: BaseEncoder + + def __init__(self, type: str, name: Optional[str]): + self.type = EncoderType(type) + self.name = name + if self.type == EncoderType.AZURE: + # TODO should change `model` to `name` JB + self.model = AzureOpenAIEncoder(model=name) + elif self.type == EncoderType.COHERE: + self.model = CohereEncoder(name=name) + elif self.type == EncoderType.OPENAI: + self.model = OpenAIEncoder(name=name) + elif self.type == EncoderType.BM25: + if name is None: + name = "bm25" + self.model = BM25Encoder(name=name) + elif self.type == EncoderType.TFIDF: + if name is None: + name = "tfidf" + self.model = TfidfEncoder(name=name) + elif self.type == EncoderType.FASTEMBED: + self.model = FastEmbedEncoder(name=name) + elif self.type == EncoderType.HUGGINGFACE: + self.model = HuggingFaceEncoder(name=name) + elif self.type == EncoderType.MISTRAL: + self.model = MistralEncoder(name=name) + elif self.type == EncoderType.VIT: + self.model = VitEncoder(name=name) + elif self.type == EncoderType.CLIP: + self.model = CLIPEncoder(name=name) + elif 
self.type == EncoderType.GOOGLE: + self.model = GoogleEncoder(name=name) + else: + raise ValueError(f"Encoder type '{type}' not supported") + + def __call__(self, texts: List[str]) -> List[List[float]]: + return self.model(texts) diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py index 14b9bb7bb4549443373429cb8bfb6a79530a578f..d56a1e71ecd0ff520edd69ff0d709524617db964 100644 --- a/semantic_router/encoders/openai.py +++ b/semantic_router/encoders/openai.py @@ -1,6 +1,7 @@ import os from time import sleep from typing import Any, List, Optional, Union +from pydantic.v1 import PrivateAttr import openai from openai import OpenAIError @@ -16,28 +17,18 @@ from semantic_router.utils.logger import logger model_configs = { "text-embedding-ada-002": EncoderInfo( - name="text-embedding-ada-002", - type="openai", - token_limit=4000 + name="text-embedding-ada-002", token_limit=8192 ), - "text-embed-3-small": EncoderInfo( - name="text-embed-3-small", - type="openai", - token_limit=8192 - ), - "text-embed-3-large": EncoderInfo( - name="text-embed-3-large", - type="openai", - token_limit=8192 - ) + "text-embedding-3-small": EncoderInfo(name="text-embedding-3-small", token_limit=8192), + "text-embedding-3-large": EncoderInfo(name="text-embedding-3-large", token_limit=8192), } class OpenAIEncoder(BaseEncoder): client: Optional[openai.Client] dimensions: Union[int, NotGiven] = NotGiven() - token_limit: Optional[int] = None - token_encoder: Optional[Any] = None + token_limit: int = 8192 # default value, should be replaced by config + _token_encoder: Any = PrivateAttr() type: str = "openai" def __init__( @@ -71,11 +62,11 @@ class OpenAIEncoder(BaseEncoder): if name in model_configs: self.token_limit = model_configs[name].token_limit # get token encoder - self.token_encoder = tiktoken.encoding_for_model(name) + self._token_encoder = tiktoken.encoding_for_model(name) def __call__(self, docs: List[str], truncate: bool = True) -> List[List[float]]: """Encode a list of text 
documents into embeddings using OpenAI API. - + :param docs: List of text documents to encode. :param truncate: Whether to truncate the documents to token limit. If False and a document exceeds the token limit, an error will be @@ -121,15 +112,15 @@ class OpenAIEncoder(BaseEncoder): embeddings = [embeds_obj.embedding for embeds_obj in embeds.data] return embeddings - + def _truncate(self, text: str) -> str: - tokens = self.token_encoder.encode(text) + tokens = self._token_encoder.encode(text) if len(tokens) > self.token_limit: logger.warning( f"Document exceeds token limit: {len(tokens)} > {self.token_limit}" "\nTruncating document..." ) - text = self.token_encoder.decode(tokens[:self.token_limit-1]) - logger.info(f"Trunc length: {len(self.token_encoder.encode(text))}") + text = self._token_encoder.decode(tokens[: self.token_limit - 1]) + logger.info(f"Trunc length: {len(self._token_encoder.encode(text))}") return text return text diff --git a/semantic_router/encoders/zure.py b/semantic_router/encoders/zure.py index ee1b1fa67fdf917b8c9113b1e89fcd7e82dd6486..df2bf858aa32578e4751fac90d698e4db111f84b 100644 --- a/semantic_router/encoders/zure.py +++ b/semantic_router/encoders/zure.py @@ -26,7 +26,7 @@ class AzureOpenAIEncoder(BaseEncoder): deployment_name: Optional[str] = None, azure_endpoint: Optional[str] = None, api_version: Optional[str] = None, - model: Optional[str] = None, + model: Optional[str] = None, # TODO we should change to `name` JB score_threshold: float = 0.82, ): name = deployment_name diff --git a/semantic_router/layer.py b/semantic_router/layer.py index d9781820bf6bdc7af5c6eac32be87e982a6d04df..02e626f4979b945258b0ef3223e547e05c0218bd 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -8,12 +8,12 @@ import numpy as np import yaml # type: ignore from tqdm.auto import tqdm -from semantic_router.encoders import BaseEncoder, OpenAIEncoder +from semantic_router.encoders import AutoEncoder, BaseEncoder, OpenAIEncoder from 
semantic_router.index.base import BaseIndex from semantic_router.index.local import LocalIndex from semantic_router.llms import BaseLLM, OpenAILLM from semantic_router.route import Route -from semantic_router.schema import Encoder, EncoderType, RouteChoice +from semantic_router.schema import EncoderType, RouteChoice from semantic_router.utils.defaults import EncoderDefault from semantic_router.utils.logger import logger @@ -337,18 +337,18 @@ class RouteLayer: @classmethod def from_json(cls, file_path: str): config = LayerConfig.from_file(file_path) - encoder = Encoder(type=config.encoder_type, name=config.encoder_name).model + encoder = AutoEncoder(type=config.encoder_type, name=config.encoder_name).model return cls(encoder=encoder, routes=config.routes) @classmethod def from_yaml(cls, file_path: str): config = LayerConfig.from_file(file_path) - encoder = Encoder(type=config.encoder_type, name=config.encoder_name).model + encoder = AutoEncoder(type=config.encoder_type, name=config.encoder_name).model return cls(encoder=encoder, routes=config.routes) @classmethod def from_config(cls, config: LayerConfig, index: Optional[BaseIndex] = None): - encoder = Encoder(type=config.encoder_type, name=config.encoder_name).model + encoder = AutoEncoder(type=config.encoder_type, name=config.encoder_name).model return cls(encoder=encoder, routes=config.routes, index=index) def add(self, route: Route): diff --git a/semantic_router/schema.py b/semantic_router/schema.py index daf608131f6bc107c47e60644b93073d1fbf8102..60f61536e08840e2c023fd12867cc960a14d3608 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -5,19 +5,24 @@ from pydantic.v1 import BaseModel class EncoderType(Enum): - HUGGINGFACE = "huggingface" - FASTEMBED = "fastembed" - OPENAI = "openai" + AZURE = "azure" COHERE = "cohere" + OPENAI = "openai" + BM25 = "bm25" + TFIDF = "tfidf" + FASTEMBED = "fastembed" + HUGGINGFACE = "huggingface" MISTRAL = "mistral" + VIT = "vit" + CLIP = "clip" GOOGLE = 
"google" class EncoderInfo(BaseModel): name: str - type: EncoderType token_limit: int + class RouteChoice(BaseModel): name: Optional[str] = None function_call: Optional[dict] = None diff --git a/tests/unit/encoders/test_openai.py b/tests/unit/encoders/test_openai.py index de3594f1838ebf787b2a2d78c2738bb7eae772db..508e9e9e197a8a0b87fb179554d43c18686c1b44 100644 --- a/tests/unit/encoders/test_openai.py +++ b/tests/unit/encoders/test_openai.py @@ -5,6 +5,7 @@ from openai.types.create_embedding_response import Usage from semantic_router.encoders import OpenAIEncoder + @pytest.fixture def openai_encoder(mocker): mocker.patch("openai.Client") diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py index 90f3c949b5009bdb894367fa3ad5d37a1beed889..d0fce781be2bbf948b43b7a48d3047ab7318b92a 100644 --- a/tests/unit/test_schema.py +++ b/tests/unit/test_schema.py @@ -2,46 +2,10 @@ import pytest from pydantic.v1 import ValidationError from semantic_router.schema import ( - CohereEncoder, - Encoder, - EncoderType, Message, - OpenAIEncoder, ) -class TestEncoderDataclass: - def test_encoder_initialization_openai(self, mocker): - mocker.patch.dict("os.environ", {"OPENAI_API_KEY": "test"}) - encoder = Encoder(type="openai", name="test-engine") - assert encoder.type == EncoderType.OPENAI - assert isinstance(encoder.model, OpenAIEncoder) - - def test_encoder_initialization_cohere(self, mocker): - mocker.patch.dict("os.environ", {"COHERE_API_KEY": "test"}) - encoder = Encoder(type="cohere", name="test-engine") - assert encoder.type == EncoderType.COHERE - assert isinstance(encoder.model, CohereEncoder) - - def test_encoder_initialization_unsupported_type(self): - with pytest.raises(ValueError): - Encoder(type="unsupported", name="test-engine") - - def test_encoder_initialization_huggingface(self): - with pytest.raises(NotImplementedError): - Encoder(type="huggingface", name="test-engine") - - def test_encoder_call_method(self, mocker): - mocker.patch.dict("os.environ", 
{"OPENAI_API_KEY": "test"}) - mocker.patch( - "semantic_router.encoders.openai.OpenAIEncoder.__call__", - return_value=[0.1, 0.2, 0.3], - ) - encoder = Encoder(type="openai", name="test-engine") - result = encoder(["test"]) - assert result == [0.1, 0.2, 0.3] - - class TestMessageDataclass: def test_message_creation(self): message = Message(role="user", content="Hello!")