diff --git a/semantic_router/encoders/aurelio.py b/semantic_router/encoders/aurelio.py index 81e9f8e8a81750e91314609f709ca31b4279ee81..5f39a4728aff3841a5f49491c580a18d4f3d1b55 100644 --- a/semantic_router/encoders/aurelio.py +++ b/semantic_router/encoders/aurelio.py @@ -12,6 +12,7 @@ class AurelioSparseEncoder(SparseEncoder): """Sparse encoder using Aurelio Platform's embedding API. Requires an API key from https://platform.aurelio.ai """ + model: Optional[Any] = None client: AurelioClient = Field(default_factory=AurelioClient, exclude=True) async_client: AsyncAurelioClient = Field( diff --git a/semantic_router/encoders/base.py b/semantic_router/encoders/base.py index 864668ee405ac7b582e61c774c26b6eea561ab85..2442a9aeab34587a06a0cbb29a06c84c226f6817 100644 --- a/semantic_router/encoders/base.py +++ b/semantic_router/encoders/base.py @@ -51,8 +51,8 @@ class DenseEncoder(BaseModel): class SparseEncoder(BaseModel): - """An encoder that encodes documents into a sparse format. - """ + """An encoder that encodes documents into a sparse format.""" + name: str type: str = Field(default="base") diff --git a/semantic_router/encoders/bedrock.py b/semantic_router/encoders/bedrock.py index 1ff8723ff89f3ae564994f2505e59d0166d2d190..57779b4e2fcc7bbaae613657030235a1cd31bcd8 100644 --- a/semantic_router/encoders/bedrock.py +++ b/semantic_router/encoders/bedrock.py @@ -32,7 +32,7 @@ class BedrockEncoder(DenseEncoder): """Dense encoder using Amazon Bedrock embedding API. Requires an AWS Access Key ID and AWS Secret Access Key. - The BedrockEncoder class is a subclass of DenseEncoder and utilizes the + The BedrockEncoder class is a subclass of DenseEncoder and utilizes the TextEmbeddingModel from the Amazon's Bedrock Platform to generate embeddings for given documents. It supports customization of the pre-trained model, score threshold, and region. @@ -50,6 +50,7 @@ class BedrockEncoder(DenseEncoder): embeddings = encoder(["document1", "document2"]) ``` """ + client: Any = None type: str = "bedrock" input_type: Optional[str] = "search_query" diff --git a/semantic_router/encoders/clip.py b/semantic_router/encoders/clip.py index 3dc33b9929a9758dfd1bf5550d49943e9c67e2cf..5ad351ef06cebb76e503fa15e1268152af910786 100644 --- a/semantic_router/encoders/clip.py +++ b/semantic_router/encoders/clip.py @@ -31,6 +31,7 @@ class CLIPEncoder(DenseEncoder): :param _Image: The PIL library. :type _Image: Any """ + name: str = "openai/clip-vit-base-patch16" type: str = "huggingface" tokenizer_kwargs: Dict = {} diff --git a/semantic_router/encoders/cohere.py b/semantic_router/encoders/cohere.py index a021e6623fcc586229ac6a93161b29fba29c0a69..0ccbb744415b3f5eaf81c76eb377175c088da62c 100644 --- a/semantic_router/encoders/cohere.py +++ b/semantic_router/encoders/cohere.py @@ -11,6 +11,7 @@ class CohereEncoder(DenseEncoder): """Dense encoder that uses Cohere API to embed documents. Supports text only. Requires a Cohere API key from https://dashboard.cohere.com/api-keys. """ + _client: Any = PrivateAttr() _embed_type: Any = PrivateAttr() type: str = "cohere" diff --git a/semantic_router/encoders/fastembed.py b/semantic_router/encoders/fastembed.py index db84feb056a4007172341e16390687068a83fd92..3d720fa64c0515d3bce89e29c3b3735c184c0f2a 100644 --- a/semantic_router/encoders/fastembed.py +++ b/semantic_router/encoders/fastembed.py @@ -15,6 +15,7 @@ class FastEmbedEncoder(DenseEncoder): :param cache_dir: The directory to cache the embedding model. :param threads: The number of threads to use for the embedding. """ + type: str = "fastembed" name: str = "BAAI/bge-small-en-v1.5" max_length: int = 512 @@ -22,9 +23,7 @@ class FastEmbedEncoder(DenseEncoder): threads: Optional[int] = None _client: Any = PrivateAttr() - def __init__( - self, score_threshold: float = 0.5, **data - ): + def __init__(self, score_threshold: float = 0.5, **data): """Initialize the FastEmbed encoder. :param score_threshold: The threshold for the score of the embedding. @@ -35,8 +34,7 @@ class FastEmbedEncoder(DenseEncoder): self._client = self._initialize_client() def _initialize_client(self): - """Initialize the FastEmbed library. Requires the fastembed package. - """ + """Initialize the FastEmbed library. Requires the fastembed package.""" try: from fastembed import TextEmbedding except ImportError: diff --git a/semantic_router/encoders/huggingface.py b/semantic_router/encoders/huggingface.py index 6e4dbc5ebddfaeb0abf831599c383b5b74cd54e2..a62102e5f72960d8c6e3f141eb43f01c48689577 100644 --- a/semantic_router/encoders/huggingface.py +++ b/semantic_router/encoders/huggingface.py @@ -34,6 +34,7 @@ from semantic_router.utils.logger import logger # TODO: this should support local models, and we should have another class for remote # inference endpoint models + class HuggingFaceEncoder(DenseEncoder): """HuggingFace encoder class for local embedding models. Models can be trained and loaded from private repositories, or from the Huggingface Hub. The class supports @@ -51,6 +52,7 @@ class HuggingFaceEncoder(DenseEncoder): embeddings = encoder(["document1", "document2"]) ``` """ + name: str = "sentence-transformers/all-MiniLM-L6-v2" type: str = "huggingface" tokenizer_kwargs: Dict = {} @@ -180,7 +182,7 @@ class HuggingFaceEncoder(DenseEncoder): class HFEndpointEncoder(DenseEncoder): """HFEndpointEncoder class to embeddings models using Huggingface's inference endpoints. - + The HFEndpointEncoder class is a subclass of DenseEncoder and utilizes a specified Huggingface endpoint to generate embeddings for given documents. It requires the URL of the Huggingface API endpoint and an API key for authentication. The class supports diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py index 2d808ddf42ccaab9fe8e40829cacc0a9a667e462..fc248b3467a45660550ac763801821ce269fad80 100644 --- a/semantic_router/encoders/openai.py +++ b/semantic_router/encoders/openai.py @@ -41,6 +41,7 @@ class OpenAIEncoder(DenseEncoder): to generate embeddings for given documents. It requires an OpenAI API key and supports customization of the score threshold for filtering or processing the embeddings. """ + _client: Optional[openai.Client] = PrivateAttr(default=None) _async_client: Optional[openai.AsyncClient] = PrivateAttr(default=None) dimensions: Union[int, NotGiven] = NotGiven() diff --git a/semantic_router/encoders/vit.py b/semantic_router/encoders/vit.py index 4be220a93890610cef45e757aebbeaf83362f19b..1888e27a66a77fada138ef363bc08d32f68a0d13 100644 --- a/semantic_router/encoders/vit.py +++ b/semantic_router/encoders/vit.py @@ -12,6 +12,7 @@ class VitEncoder(DenseEncoder): model via Hugging Face. It supports various image processing and model initialization options. """ + name: str = "google/vit-base-patch16-224" type: str = "huggingface" processor_kwargs: Dict = {} diff --git a/semantic_router/encoders/zure.py b/semantic_router/encoders/zure.py index 3822a743894a1653cacfaa2c11793c9c910380d2..faab1c90d59336980fa9509620dd9ed2d70502ff 100644 --- a/semantic_router/encoders/zure.py +++ b/semantic_router/encoders/zure.py @@ -19,6 +19,7 @@ class AzureOpenAIEncoder(DenseEncoder): This class provides functionality to encode text documents using the Azure OpenAI API. It supports customization of the score threshold for filtering or processing the embeddings. """ + client: Optional[openai.AzureOpenAI] = None async_client: Optional[openai.AsyncAzureOpenAI] = None dimensions: Union[int, NotGiven] = NotGiven() diff --git a/semantic_router/index/postgres.py b/semantic_router/index/postgres.py index 5700e9a40b13f341e6fa2d8066e0b0b00c67b71a..dda6149e626ea1a2b9af7ee515a1e99c6b52b1d1 100644 --- a/semantic_router/index/postgres.py +++ b/semantic_router/index/postgres.py @@ -15,8 +15,7 @@ if TYPE_CHECKING: class MetricPgVecOperatorMap(Enum): - """Enum to map the metric to PostgreSQL vector operators. - """ + """Enum to map the metric to PostgreSQL vector operators.""" cosine = "<=>" dotproduct = "<#>" # inner product @@ -51,8 +50,7 @@ def clean_route_name(route_name: str) -> str: class PostgresIndexRecord(BaseModel): - """Model to represent a record in the Postgres index. - """ + """Model to represent a record in the Postgres index.""" id: str = "" route: str @@ -90,8 +88,7 @@ class PostgresIndexRecord(BaseModel): class PostgresIndex(BaseIndex): - """Postgres implementation of Index. - """ + """Postgres implementation of Index.""" connection_string: Optional[str] = None index_prefix: str = "semantic_router_" @@ -498,7 +495,6 @@ class PostgresIndex(BaseIndex): return count[0] class Config: - """Configuration for the Pydantic BaseModel. - """ + """Configuration for the Pydantic BaseModel.""" arbitrary_types_allowed = True diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py index 611f1ac62d39901009185e6ecfc52836a44690be..a3e6c12232596da5bf49d60b3a6b25a8b6fcb70f 100644 --- a/semantic_router/llms/base.py +++ b/semantic_router/llms/base.py @@ -13,6 +13,7 @@ class BaseLLM(BaseModel): This class provides a base implementation for LLMs. It defines the common configuration and methods for all LLM classes. """ + name: str temperature: Optional[float] = 0.0 max_tokens: Optional[int] = None diff --git a/semantic_router/llms/cohere.py b/semantic_router/llms/cohere.py index b258ac8f111778efa2e743168cf7e6d5efaf2a2e..707887e389ba360b03b456a61419f4461e52f8ed 100644 --- a/semantic_router/llms/cohere.py +++ b/semantic_router/llms/cohere.py @@ -13,6 +13,7 @@ class CohereLLM(BaseLLM): This class provides functionality to interact with the Cohere API for generating text responses. It extends the BaseLLM class and implements the __call__ method to generate text responses. """ + _client: Any = PrivateAttr() def __init__( diff --git a/semantic_router/llms/llamacpp.py b/semantic_router/llms/llamacpp.py index d60f85dd75a4ee30df5469966346eccc325f13b2..6940907bfad0de63466b9a624d3a7f15050dbb82 100644 --- a/semantic_router/llms/llamacpp.py +++ b/semantic_router/llms/llamacpp.py @@ -13,6 +13,7 @@ class LlamaCppLLM(BaseLLM): """LLM for LlamaCPP. Enables fully local LLM use, helpful for local implementation of dynamic routes. """ + llm: Any grammar: Optional[Any] = None _llama_cpp: Any = PrivateAttr() diff --git a/semantic_router/llms/mistral.py b/semantic_router/llms/mistral.py index 0d827a58ff4ba02a1bbcaea45fa6d583acdf2b62..d66fa5576541418c80879d530ef723be6d98a376 100644 --- a/semantic_router/llms/mistral.py +++ b/semantic_router/llms/mistral.py @@ -10,8 +10,8 @@ from semantic_router.utils.logger import logger class MistralAILLM(BaseLLM): - """LLM for MistralAI. Requires a MistralAI API key from https://console.mistral.ai/api-keys/ - """ + """LLM for MistralAI. Requires a MistralAI API key from https://console.mistral.ai/api-keys/""" + _client: Any = PrivateAttr() _mistralai: Any = PrivateAttr() diff --git a/semantic_router/llms/ollama.py b/semantic_router/llms/ollama.py index 5a29a36021d6e71cc02a7d00ab97efea93d05d74..aa2a38a132d087bb8b3da32fafc9ba18d950ec52 100644 --- a/semantic_router/llms/ollama.py +++ b/semantic_router/llms/ollama.py @@ -11,6 +11,7 @@ class OllamaLLM(BaseLLM): """LLM for Ollama. Enables fully local LLM use, helpful for local implementation of dynamic routes. """ + stream: bool = False def __init__( diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py index 71cc3322c8d2e7961cec5d3ebe1fee750e5cf044..169128477eb330531e52d8522bf6002c8bd6760b 100644 --- a/semantic_router/llms/openai.py +++ b/semantic_router/llms/openai.py @@ -22,8 +22,8 @@ from semantic_router.utils.logger import logger class OpenAILLM(BaseLLM): - """LLM for OpenAI. Requires an OpenAI API key from https://platform.openai.com/api-keys. - """ + """LLM for OpenAI. Requires an OpenAI API key from https://platform.openai.com/api-keys.""" + _client: Optional[openai.OpenAI] = PrivateAttr(default=None) _async_client: Optional[openai.AsyncOpenAI] = PrivateAttr(default=None) diff --git a/semantic_router/llms/openrouter.py b/semantic_router/llms/openrouter.py index b1a5a558d0fcf655b3ffb855b55a08658ee1b017..523528089425fa13b71cb7fdbfd8c278eea71840 100644 --- a/semantic_router/llms/openrouter.py +++ b/semantic_router/llms/openrouter.py @@ -12,6 +12,7 @@ from semantic_router.utils.logger import logger class OpenRouterLLM(BaseLLM): """LLM for OpenRouter. Requires an OpenRouter API key, see here for more information https://openrouter.ai/docs/api-reference/authentication#using-an-api-key""" + _client: Optional[openai.OpenAI] = PrivateAttr(default=None) _base_url: str = PrivateAttr(default="https://openrouter.ai/api/v1") diff --git a/semantic_router/llms/zure.py b/semantic_router/llms/zure.py index 3512e89afd6e90e3cad692f8db9fb776945f4696..143ad63b6195c70291c6d2b3940e5138eedfd857 100644 --- a/semantic_router/llms/zure.py +++ b/semantic_router/llms/zure.py @@ -11,8 +11,8 @@ from semantic_router.utils.logger import logger class AzureOpenAILLM(BaseLLM): - """LLM for Azure OpenAI. Requires an Azure OpenAI API key. - """ + """LLM for Azure OpenAI. Requires an Azure OpenAI API key.""" + _client: Optional[openai.AzureOpenAI] = PrivateAttr(default=None) def __init__( diff --git a/semantic_router/route.py b/semantic_router/route.py index 4294cc72b0d5ceeac46f776be7a6e3e948f7d0ac..4df7f1fcf70d9daca719d6ef8db58eeb8b5e632c 100644 --- a/semantic_router/route.py +++ b/semantic_router/route.py @@ -64,6 +64,14 @@ class Route(BaseModel): :type metadata: Optional[Dict[str, Any]] """ + name: str + utterances: Union[List[str], List[Any]] + description: Optional[str] = None + function_schemas: Optional[List[Dict[str, Any]]] = None + llm: Optional[BaseLLM] = None + score_threshold: Optional[float] = None + metadata: Optional[Dict[str, Any]] = {} + class Config: arbitrary_types_allowed = True diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py index af43913317cb0b14b45378919ee5ef01cee57a40..0689c4093c86b1b200a0de97eb352d1518689f1c 100644 --- a/semantic_router/routers/base.py +++ b/semantic_router/routers/base.py @@ -65,8 +65,7 @@ def is_valid(layer_config: str) -> bool: class RouterConfig: - """Generates a RouterConfig object that can be used for initializing routers. - """ + """Generates a RouterConfig object that can be used for initializing routers.""" routes: List[Route] = Field(default_factory=list) @@ -348,8 +347,7 @@ def xq_reshape(xq: List[float] | np.ndarray) -> np.ndarray: class BaseRouter(BaseModel): - """Base class for all routers. - """ + """Base class for all routers.""" encoder: DenseEncoder = Field(default_factory=OpenAIEncoder) sparse_encoder: Optional[SparseEncoder] = Field(default=None) @@ -480,8 +478,7 @@ class BaseRouter(BaseModel): ) def _init_index_state(self): - """Initializes an index (where required) and runs auto_sync if active. - """ + """Initializes an index (where required) and runs auto_sync if active.""" # initialize index now, check if we need dimensions if self.index.dimensions is None: dims = len(self.encoder(["test"])[0]) @@ -1283,7 +1280,7 @@ class BaseRouter(BaseModel): :type include_metadata: bool :return: A tuple of the route names, utterances, and function schemas. """ - + route_names = [route.name for route in routes for _ in route.utterances] utterances = [utterance for route in routes for utterance in route.utterances] function_schemas = [ @@ -1615,6 +1612,7 @@ class BaseRouter(BaseModel): :param local_execution: Whether to execute the fitting locally. :type local_execution: bool """ + original_index = self.index if local_execution: # Switch to a local index for fitting from semantic_router.index.local import LocalIndex diff --git a/semantic_router/routers/semantic.py b/semantic_router/routers/semantic.py index 19d9903c5cc937cb3f5ada7d3d8405dbf8d4454e..72b28388a5ca6911e6cf92137d29cd11b4ec1dcb 100644 --- a/semantic_router/routers/semantic.py +++ b/semantic_router/routers/semantic.py @@ -11,8 +11,8 @@ from semantic_router.utils.logger import logger class SemanticRouter(BaseRouter): - """A router that uses a dense encoder to encode routes and utterances. - """ + """A router that uses a dense encoder to encode routes and utterances.""" + def __init__( self, encoder: Optional[DenseEncoder] = None, diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 6db78b3118f1102d448c56c4ecf2c1d54e475520..fa6cdd2d40f50a103c69dc7bc2e0039d47f20d50 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -12,8 +12,7 @@ from aurelio_sdk.schema import SparseEmbedding as BM25SparseEmbedding class EncoderType(Enum): - """The type of encoder. - """ + """The type of encoder.""" AURELIO = "aurelio" AZURE = "azure" @@ -31,52 +30,49 @@ class EncoderType(Enum): class EncoderInfo(BaseModel): - """Information about an encoder. - """ + """Information about an encoder.""" + name: str token_limit: int threshold: Optional[float] = None class RouteChoice(BaseModel): - """A route choice typically output by the routers. - """ + """A route choice typically output by the routers.""" + name: Optional[str] = None function_call: Optional[List[Dict]] = None similarity_score: Optional[float] = None class Message(BaseModel): - """A message in a conversation, includes the role and content fields. - """ + """A message in a conversation, includes the role and content fields.""" + role: str content: str def to_openai(self): - """Convert the message to an OpenAI-compatible format. - """ + """Convert the message to an OpenAI-compatible format.""" if self.role.lower() not in ["user", "assistant", "system", "tool"]: - raise ValueError("Role must be either 'user', 'assistant', 'system' or 'tool'") + raise ValueError( + "Role must be either 'user', 'assistant', 'system' or 'tool'" + ) return {"role": self.role, "content": self.content} def to_cohere(self): - """Convert the message to a Cohere-compatible format. - """ + """Convert the message to a Cohere-compatible format.""" return {"role": self.role, "message": self.content} def to_llamacpp(self): - """Convert the message to a LlamaCPP-compatible format. - """ + """Convert the message to a LlamaCPP-compatible format.""" return {"role": self.role, "content": self.content} def to_mistral(self): - """Convert the message to a Mistral-compatible format. - """ + """Convert the message to a Mistral-compatible format.""" return {"role": self.role, "content": self.content} def __str__(self): - """Convert the message to a string. - """ + """Convert the message to a string.""" return f"{self.role}: {self.content}" @@ -84,6 +80,7 @@ class ConfigParameter(BaseModel): """A configuration parameter for a route. Used for remote router metadata such as router hashes, sync locks, etc. """ + field: str value: str scope: Optional[str] = None @@ -118,6 +115,7 @@ class Utterance(BaseModel): """An utterance in a conversation, includes the route, utterance, function schemas, metadata, and diff tag. """ + route: str utterance: Union[str, Any] function_schemas: Optional[List[Dict]] = None @@ -191,8 +189,7 @@ class Utterance(BaseModel): class SyncMode(Enum): - """Synchronization modes for local (route layer) and remote (index) instances. - """ + """Synchronization modes for local (route layer) and remote (index) instances.""" ERROR = "error" REMOTE = "remote" @@ -209,6 +206,7 @@ class UtteranceDiff(BaseModel): """A list of Utterance objects that represent the differences between local and remote utterances. """ + diff: List[Utterance] @classmethod @@ -471,8 +469,8 @@ class UtteranceDiff(BaseModel): class Metric(Enum): - """The metric to use in vector-based similarity search indexes. - """ + """The metric to use in vector-based similarity search indexes.""" + COSINE = "cosine" DOTPRODUCT = "dotproduct" EUCLIDEAN = "euclidean" diff --git a/semantic_router/utils/defaults.py b/semantic_router/utils/defaults.py index 90bfb1b0c7ed843255d48056d10dd8e824ed8621..7915ac9d7125582fdc6c0be871cacd0cf9abd71e 100644 --- a/semantic_router/utils/defaults.py +++ b/semantic_router/utils/defaults.py @@ -3,8 +3,7 @@ from enum import Enum class EncoderDefault(Enum): - """Default model names for each encoder type. - """ + """Default model names for each encoder type.""" FASTEMBED = { "embedding_model": "BAAI/bge-small-en-v1.5", diff --git a/semantic_router/utils/logger.py b/semantic_router/utils/logger.py index 634dbc8c714edc5e11b464f3175cb109e09a9c45..30aedc2c111c9c3cd3abc8c7c6aec2793a8d5ff3 100644 --- a/semantic_router/utils/logger.py +++ b/semantic_router/utils/logger.py @@ -4,8 +4,7 @@ import colorlog class CustomFormatter(colorlog.ColoredFormatter): - """Custom formatter for the logger. - """ + """Custom formatter for the logger.""" def __init__(self): super().__init__( @@ -24,8 +23,7 @@ class CustomFormatter(colorlog.ColoredFormatter): def add_coloured_handler(logger): - """Add a coloured handler to the logger. - """ + """Add a coloured handler to the logger.""" formatter = CustomFormatter() console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) @@ -34,8 +32,7 @@ def add_coloured_handler(logger): def setup_custom_logger(name): - """Setup a custom logger. - """ + """Setup a custom logger.""" logger = logging.getLogger(name) if not logger.hasHandlers():