From ae7db0d0de69dd445ed1070703bcd06fa9409500 Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Thu, 13 Feb 2025 12:06:06 +0400 Subject: [PATCH] fix: lint fixes --- semantic_router/encoders/aurelio.py | 1 + semantic_router/encoders/base.py | 4 +-- semantic_router/encoders/bedrock.py | 3 +- semantic_router/encoders/clip.py | 1 + semantic_router/encoders/cohere.py | 1 + semantic_router/encoders/fastembed.py | 8 ++--- semantic_router/encoders/huggingface.py | 4 ++- semantic_router/encoders/openai.py | 1 + semantic_router/encoders/vit.py | 1 + semantic_router/encoders/zure.py | 1 + semantic_router/index/postgres.py | 12 +++---- semantic_router/llms/base.py | 1 + semantic_router/llms/cohere.py | 1 + semantic_router/llms/llamacpp.py | 1 + semantic_router/llms/mistral.py | 4 +-- semantic_router/llms/ollama.py | 1 + semantic_router/llms/openai.py | 4 +-- semantic_router/llms/openrouter.py | 1 + semantic_router/llms/zure.py | 4 +-- semantic_router/route.py | 8 +++++ semantic_router/routers/base.py | 12 +++---- semantic_router/routers/semantic.py | 4 +-- semantic_router/schema.py | 44 ++++++++++++------------- semantic_router/utils/defaults.py | 3 +- semantic_router/utils/logger.py | 9 ++--- 25 files changed, 71 insertions(+), 63 deletions(-) diff --git a/semantic_router/encoders/aurelio.py b/semantic_router/encoders/aurelio.py index 81e9f8e8..5f39a472 100644 --- a/semantic_router/encoders/aurelio.py +++ b/semantic_router/encoders/aurelio.py @@ -12,6 +12,7 @@ class AurelioSparseEncoder(SparseEncoder): """Sparse encoder using Aurelio Platform's embedding API. Requires an API key from https://platform.aurelio.ai """ + model: Optional[Any] = None client: AurelioClient = Field(default_factory=AurelioClient, exclude=True) async_client: AsyncAurelioClient = Field( diff --git a/semantic_router/encoders/base.py b/semantic_router/encoders/base.py index 864668ee..2442a9ae 100644 --- a/semantic_router/encoders/base.py +++ b/semantic_router/encoders/base.py @@ -51,8 +51,8 @@ class DenseEncoder(BaseModel): class SparseEncoder(BaseModel): - """An encoder that encodes documents into a sparse format. - """ + """An encoder that encodes documents into a sparse format.""" + name: str type: str = Field(default="base") diff --git a/semantic_router/encoders/bedrock.py b/semantic_router/encoders/bedrock.py index 1ff8723f..57779b4e 100644 --- a/semantic_router/encoders/bedrock.py +++ b/semantic_router/encoders/bedrock.py @@ -32,7 +32,7 @@ class BedrockEncoder(DenseEncoder): """Dense encoder using Amazon Bedrock embedding API. Requires an AWS Access Key ID and AWS Secret Access Key. - The BedrockEncoder class is a subclass of DenseEncoder and utilizes the + The BedrockEncoder class is a subclass of DenseEncoder and utilizes the TextEmbeddingModel from the Amazon's Bedrock Platform to generate embeddings for given documents. It supports customization of the pre-trained model, score threshold, and region. @@ -50,6 +50,7 @@ class BedrockEncoder(DenseEncoder): embeddings = encoder(["document1", "document2"]) ``` """ + client: Any = None type: str = "bedrock" input_type: Optional[str] = "search_query" diff --git a/semantic_router/encoders/clip.py b/semantic_router/encoders/clip.py index 3dc33b99..5ad351ef 100644 --- a/semantic_router/encoders/clip.py +++ b/semantic_router/encoders/clip.py @@ -31,6 +31,7 @@ class CLIPEncoder(DenseEncoder): :param _Image: The PIL library. :type _Image: Any """ + name: str = "openai/clip-vit-base-patch16" type: str = "huggingface" tokenizer_kwargs: Dict = {} diff --git a/semantic_router/encoders/cohere.py b/semantic_router/encoders/cohere.py index a021e662..0ccbb744 100644 --- a/semantic_router/encoders/cohere.py +++ b/semantic_router/encoders/cohere.py @@ -11,6 +11,7 @@ class CohereEncoder(DenseEncoder): """Dense encoder that uses Cohere API to embed documents. Supports text only. Requires a Cohere API key from https://dashboard.cohere.com/api-keys. """ + _client: Any = PrivateAttr() _embed_type: Any = PrivateAttr() type: str = "cohere" diff --git a/semantic_router/encoders/fastembed.py b/semantic_router/encoders/fastembed.py index db84feb0..3d720fa6 100644 --- a/semantic_router/encoders/fastembed.py +++ b/semantic_router/encoders/fastembed.py @@ -15,6 +15,7 @@ class FastEmbedEncoder(DenseEncoder): :param cache_dir: The directory to cache the embedding model. :param threads: The number of threads to use for the embedding. """ + type: str = "fastembed" name: str = "BAAI/bge-small-en-v1.5" max_length: int = 512 @@ -22,9 +23,7 @@ class FastEmbedEncoder(DenseEncoder): threads: Optional[int] = None _client: Any = PrivateAttr() - def __init__( - self, score_threshold: float = 0.5, **data - ): + def __init__(self, score_threshold: float = 0.5, **data): """Initialize the FastEmbed encoder. :param score_threshold: The threshold for the score of the embedding. @@ -35,8 +34,7 @@ class FastEmbedEncoder(DenseEncoder): self._client = self._initialize_client() def _initialize_client(self): - """Initialize the FastEmbed library. Requires the fastembed package. - """ + """Initialize the FastEmbed library. Requires the fastembed package.""" try: from fastembed import TextEmbedding except ImportError: diff --git a/semantic_router/encoders/huggingface.py b/semantic_router/encoders/huggingface.py index 6e4dbc5e..a62102e5 100644 --- a/semantic_router/encoders/huggingface.py +++ b/semantic_router/encoders/huggingface.py @@ -34,6 +34,7 @@ from semantic_router.utils.logger import logger # TODO: this should support local models, and we should have another class for remote # inference endpoint models + class HuggingFaceEncoder(DenseEncoder): """HuggingFace encoder class for local embedding models. Models can be trained and loaded from private repositories, or from the Huggingface Hub. The class supports @@ -51,6 +52,7 @@ class HuggingFaceEncoder(DenseEncoder): embeddings = encoder(["document1", "document2"]) ``` """ + name: str = "sentence-transformers/all-MiniLM-L6-v2" type: str = "huggingface" tokenizer_kwargs: Dict = {} @@ -180,7 +182,7 @@ class HuggingFaceEncoder(DenseEncoder): class HFEndpointEncoder(DenseEncoder): """HFEndpointEncoder class to embeddings models using Huggingface's inference endpoints. - + The HFEndpointEncoder class is a subclass of DenseEncoder and utilizes a specified Huggingface endpoint to generate embeddings for given documents. It requires the URL of the Huggingface API endpoint and an API key for authentication. The class supports diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py index 2d808ddf..fc248b34 100644 --- a/semantic_router/encoders/openai.py +++ b/semantic_router/encoders/openai.py @@ -41,6 +41,7 @@ class OpenAIEncoder(DenseEncoder): to generate embeddings for given documents. It requires an OpenAI API key and supports customization of the score threshold for filtering or processing the embeddings. """ + _client: Optional[openai.Client] = PrivateAttr(default=None) _async_client: Optional[openai.AsyncClient] = PrivateAttr(default=None) dimensions: Union[int, NotGiven] = NotGiven() diff --git a/semantic_router/encoders/vit.py b/semantic_router/encoders/vit.py index 4be220a9..1888e27a 100644 --- a/semantic_router/encoders/vit.py +++ b/semantic_router/encoders/vit.py @@ -12,6 +12,7 @@ class VitEncoder(DenseEncoder): model via Hugging Face. It supports various image processing and model initialization options. """ + name: str = "google/vit-base-patch16-224" type: str = "huggingface" processor_kwargs: Dict = {} diff --git a/semantic_router/encoders/zure.py b/semantic_router/encoders/zure.py index 3822a743..faab1c90 100644 --- a/semantic_router/encoders/zure.py +++ b/semantic_router/encoders/zure.py @@ -19,6 +19,7 @@ class AzureOpenAIEncoder(DenseEncoder): This class provides functionality to encode text documents using the Azure OpenAI API. It supports customization of the score threshold for filtering or processing the embeddings. """ + client: Optional[openai.AzureOpenAI] = None async_client: Optional[openai.AsyncAzureOpenAI] = None dimensions: Union[int, NotGiven] = NotGiven() diff --git a/semantic_router/index/postgres.py b/semantic_router/index/postgres.py index 5700e9a4..dda6149e 100644 --- a/semantic_router/index/postgres.py +++ b/semantic_router/index/postgres.py @@ -15,8 +15,7 @@ if TYPE_CHECKING: class MetricPgVecOperatorMap(Enum): - """Enum to map the metric to PostgreSQL vector operators. - """ + """Enum to map the metric to PostgreSQL vector operators.""" cosine = "<=>" dotproduct = "<#>" # inner product @@ -51,8 +50,7 @@ def clean_route_name(route_name: str) -> str: class PostgresIndexRecord(BaseModel): - """Model to represent a record in the Postgres index. - """ + """Model to represent a record in the Postgres index.""" id: str = "" route: str @@ -90,8 +88,7 @@ class PostgresIndexRecord(BaseModel): class PostgresIndex(BaseIndex): - """Postgres implementation of Index. - """ + """Postgres implementation of Index.""" connection_string: Optional[str] = None index_prefix: str = "semantic_router_" @@ -498,7 +495,6 @@ class PostgresIndex(BaseIndex): return count[0] class Config: - """Configuration for the Pydantic BaseModel. - """ + """Configuration for the Pydantic BaseModel.""" arbitrary_types_allowed = True diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py index 611f1ac6..a3e6c122 100644 --- a/semantic_router/llms/base.py +++ b/semantic_router/llms/base.py @@ -13,6 +13,7 @@ class BaseLLM(BaseModel): This class provides a base implementation for LLMs. It defines the common configuration and methods for all LLM classes. """ + name: str temperature: Optional[float] = 0.0 max_tokens: Optional[int] = None diff --git a/semantic_router/llms/cohere.py b/semantic_router/llms/cohere.py index b258ac8f..707887e3 100644 --- a/semantic_router/llms/cohere.py +++ b/semantic_router/llms/cohere.py @@ -13,6 +13,7 @@ class CohereLLM(BaseLLM): This class provides functionality to interact with the Cohere API for generating text responses. It extends the BaseLLM class and implements the __call__ method to generate text responses. """ + _client: Any = PrivateAttr() def __init__( diff --git a/semantic_router/llms/llamacpp.py b/semantic_router/llms/llamacpp.py index d60f85dd..6940907b 100644 --- a/semantic_router/llms/llamacpp.py +++ b/semantic_router/llms/llamacpp.py @@ -13,6 +13,7 @@ class LlamaCppLLM(BaseLLM): """LLM for LlamaCPP. Enables fully local LLM use, helpful for local implementation of dynamic routes. """ + llm: Any grammar: Optional[Any] = None _llama_cpp: Any = PrivateAttr() diff --git a/semantic_router/llms/mistral.py b/semantic_router/llms/mistral.py index 0d827a58..d66fa557 100644 --- a/semantic_router/llms/mistral.py +++ b/semantic_router/llms/mistral.py @@ -10,8 +10,8 @@ from semantic_router.utils.logger import logger class MistralAILLM(BaseLLM): - """LLM for MistralAI. Requires a MistralAI API key from https://console.mistral.ai/api-keys/ - """ + """LLM for MistralAI. Requires a MistralAI API key from https://console.mistral.ai/api-keys/""" + _client: Any = PrivateAttr() _mistralai: Any = PrivateAttr() diff --git a/semantic_router/llms/ollama.py b/semantic_router/llms/ollama.py index 5a29a360..aa2a38a1 100644 --- a/semantic_router/llms/ollama.py +++ b/semantic_router/llms/ollama.py @@ -11,6 +11,7 @@ class OllamaLLM(BaseLLM): """LLM for Ollama. Enables fully local LLM use, helpful for local implementation of dynamic routes. """ + stream: bool = False def __init__( diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py index 71cc3322..16912847 100644 --- a/semantic_router/llms/openai.py +++ b/semantic_router/llms/openai.py @@ -22,8 +22,8 @@ from semantic_router.utils.logger import logger class OpenAILLM(BaseLLM): - """LLM for OpenAI. Requires an OpenAI API key from https://platform.openai.com/api-keys. - """ + """LLM for OpenAI. Requires an OpenAI API key from https://platform.openai.com/api-keys.""" + _client: Optional[openai.OpenAI] = PrivateAttr(default=None) _async_client: Optional[openai.AsyncOpenAI] = PrivateAttr(default=None) diff --git a/semantic_router/llms/openrouter.py b/semantic_router/llms/openrouter.py index b1a5a558..52352808 100644 --- a/semantic_router/llms/openrouter.py +++ b/semantic_router/llms/openrouter.py @@ -12,6 +12,7 @@ from semantic_router.utils.logger import logger class OpenRouterLLM(BaseLLM): """LLM for OpenRouter. Requires an OpenRouter API key, see here for more information https://openrouter.ai/docs/api-reference/authentication#using-an-api-key""" + _client: Optional[openai.OpenAI] = PrivateAttr(default=None) _base_url: str = PrivateAttr(default="https://openrouter.ai/api/v1") diff --git a/semantic_router/llms/zure.py b/semantic_router/llms/zure.py index 3512e89a..143ad63b 100644 --- a/semantic_router/llms/zure.py +++ b/semantic_router/llms/zure.py @@ -11,8 +11,8 @@ from semantic_router.utils.logger import logger class AzureOpenAILLM(BaseLLM): - """LLM for Azure OpenAI. Requires an Azure OpenAI API key. - """ + """LLM for Azure OpenAI. Requires an Azure OpenAI API key.""" + _client: Optional[openai.AzureOpenAI] = PrivateAttr(default=None) def __init__( diff --git a/semantic_router/route.py b/semantic_router/route.py index 4294cc72..4df7f1fc 100644 --- a/semantic_router/route.py +++ b/semantic_router/route.py @@ -64,6 +64,14 @@ class Route(BaseModel): :type metadata: Optional[Dict[str, Any]] """ + name: str + utterances: Union[List[str], List[Any]] + description: Optional[str] = None + function_schemas: Optional[List[Dict[str, Any]]] = None + llm: Optional[BaseLLM] = None + score_threshold: Optional[float] = None + metadata: Optional[Dict[str, Any]] = {} + class Config: arbitrary_types_allowed = True diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py index af439133..0689c409 100644 --- a/semantic_router/routers/base.py +++ b/semantic_router/routers/base.py @@ -65,8 +65,7 @@ def is_valid(layer_config: str) -> bool: class RouterConfig: - """Generates a RouterConfig object that can be used for initializing routers. - """ + """Generates a RouterConfig object that can be used for initializing routers.""" routes: List[Route] = Field(default_factory=list) @@ -348,8 +347,7 @@ def xq_reshape(xq: List[float] | np.ndarray) -> np.ndarray: class BaseRouter(BaseModel): - """Base class for all routers. - """ + """Base class for all routers.""" encoder: DenseEncoder = Field(default_factory=OpenAIEncoder) sparse_encoder: Optional[SparseEncoder] = Field(default=None) @@ -480,8 +478,7 @@ class BaseRouter(BaseModel): ) def _init_index_state(self): - """Initializes an index (where required) and runs auto_sync if active. - """ + """Initializes an index (where required) and runs auto_sync if active.""" # initialize index now, check if we need dimensions if self.index.dimensions is None: dims = len(self.encoder(["test"])[0]) @@ -1283,7 +1280,7 @@ class BaseRouter(BaseModel): :type include_metadata: bool :return: A tuple of the route names, utterances, and function schemas. """ - + route_names = [route.name for route in routes for _ in route.utterances] utterances = [utterance for route in routes for utterance in route.utterances] function_schemas = [ @@ -1615,6 +1612,7 @@ class BaseRouter(BaseModel): :param local_execution: Whether to execute the fitting locally. :type local_execution: bool """ + original_index = self.index if local_execution: # Switch to a local index for fitting from semantic_router.index.local import LocalIndex diff --git a/semantic_router/routers/semantic.py b/semantic_router/routers/semantic.py index 19d9903c..72b28388 100644 --- a/semantic_router/routers/semantic.py +++ b/semantic_router/routers/semantic.py @@ -11,8 +11,8 @@ from semantic_router.utils.logger import logger class SemanticRouter(BaseRouter): - """A router that uses a dense encoder to encode routes and utterances. - """ + """A router that uses a dense encoder to encode routes and utterances.""" + def __init__( self, encoder: Optional[DenseEncoder] = None, diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 6db78b31..fa6cdd2d 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -12,8 +12,7 @@ from aurelio_sdk.schema import SparseEmbedding as BM25SparseEmbedding class EncoderType(Enum): - """The type of encoder. - """ + """The type of encoder.""" AURELIO = "aurelio" AZURE = "azure" @@ -31,52 +30,49 @@ class EncoderType(Enum): class EncoderInfo(BaseModel): - """Information about an encoder. - """ + """Information about an encoder.""" + name: str token_limit: int threshold: Optional[float] = None class RouteChoice(BaseModel): - """A route choice typically output by the routers. - """ + """A route choice typically output by the routers.""" + name: Optional[str] = None function_call: Optional[List[Dict]] = None similarity_score: Optional[float] = None class Message(BaseModel): - """A message in a conversation, includes the role and content fields. - """ + """A message in a conversation, includes the role and content fields.""" + role: str content: str def to_openai(self): - """Convert the message to an OpenAI-compatible format. - """ + """Convert the message to an OpenAI-compatible format.""" if self.role.lower() not in ["user", "assistant", "system", "tool"]: - raise ValueError("Role must be either 'user', 'assistant', 'system' or 'tool'") + raise ValueError( + "Role must be either 'user', 'assistant', 'system' or 'tool'" + ) return {"role": self.role, "content": self.content} def to_cohere(self): - """Convert the message to a Cohere-compatible format. - """ + """Convert the message to a Cohere-compatible format.""" return {"role": self.role, "message": self.content} def to_llamacpp(self): - """Convert the message to a LlamaCPP-compatible format. - """ + """Convert the message to a LlamaCPP-compatible format.""" return {"role": self.role, "content": self.content} def to_mistral(self): - """Convert the message to a Mistral-compatible format. - """ + """Convert the message to a Mistral-compatible format.""" return {"role": self.role, "content": self.content} def __str__(self): - """Convert the message to a string. - """ + """Convert the message to a string.""" return f"{self.role}: {self.content}" @@ -84,6 +80,7 @@ class ConfigParameter(BaseModel): """A configuration parameter for a route. Used for remote router metadata such as router hashes, sync locks, etc. """ + field: str value: str scope: Optional[str] = None @@ -118,6 +115,7 @@ class Utterance(BaseModel): """An utterance in a conversation, includes the route, utterance, function schemas, metadata, and diff tag. """ + route: str utterance: Union[str, Any] function_schemas: Optional[List[Dict]] = None @@ -191,8 +189,7 @@ class Utterance(BaseModel): class SyncMode(Enum): - """Synchronization modes for local (route layer) and remote (index) instances. - """ + """Synchronization modes for local (route layer) and remote (index) instances.""" ERROR = "error" REMOTE = "remote" @@ -209,6 +206,7 @@ class UtteranceDiff(BaseModel): """A list of Utterance objects that represent the differences between local and remote utterances. """ + diff: List[Utterance] @classmethod @@ -471,8 +469,8 @@ class UtteranceDiff(BaseModel): class Metric(Enum): - """The metric to use in vector-based similarity search indexes. - """ + """The metric to use in vector-based similarity search indexes.""" + COSINE = "cosine" DOTPRODUCT = "dotproduct" EUCLIDEAN = "euclidean" diff --git a/semantic_router/utils/defaults.py b/semantic_router/utils/defaults.py index 90bfb1b0..7915ac9d 100644 --- a/semantic_router/utils/defaults.py +++ b/semantic_router/utils/defaults.py @@ -3,8 +3,7 @@ from enum import Enum class EncoderDefault(Enum): - """Default model names for each encoder type. - """ + """Default model names for each encoder type.""" FASTEMBED = { "embedding_model": "BAAI/bge-small-en-v1.5", diff --git a/semantic_router/utils/logger.py b/semantic_router/utils/logger.py index 634dbc8c..30aedc2c 100644 --- a/semantic_router/utils/logger.py +++ b/semantic_router/utils/logger.py @@ -4,8 +4,7 @@ import colorlog class CustomFormatter(colorlog.ColoredFormatter): - """Custom formatter for the logger. - """ + """Custom formatter for the logger.""" def __init__(self): super().__init__( @@ -24,8 +23,7 @@ class CustomFormatter(colorlog.ColoredFormatter): def add_coloured_handler(logger): - """Add a coloured handler to the logger. - """ + """Add a coloured handler to the logger.""" formatter = CustomFormatter() console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) @@ -34,8 +32,7 @@ def add_coloured_handler(logger): def setup_custom_logger(name): - """Setup a custom logger. - """ + """Setup a custom logger.""" logger = logging.getLogger(name) if not logger.hasHandlers(): -- GitLab