From f16f620fb63f7b3264a6da275239441d88e74183 Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Thu, 13 Feb 2025 10:22:03 +0400
Subject: [PATCH] chore: add docstrings for few encoders

---
 semantic_router/encoders/aurelio.py |  33 +++++++
 semantic_router/encoders/base.py    |  54 ++++++++++-
 semantic_router/encoders/bedrock.py | 139 ++++++++++++++++------------
 semantic_router/encoders/clip.py    |  71 ++++++++++++++
 4 files changed, 236 insertions(+), 61 deletions(-)

diff --git a/semantic_router/encoders/aurelio.py b/semantic_router/encoders/aurelio.py
index d1d40aa7..81e9f8e8 100644
--- a/semantic_router/encoders/aurelio.py
+++ b/semantic_router/encoders/aurelio.py
@@ -9,6 +9,9 @@ from semantic_router.schema import SparseEmbedding
 
 
 class AurelioSparseEncoder(SparseEncoder):
+    """Sparse encoder using Aurelio Platform's embedding API. Requires an API key from
+    https://platform.aurelio.ai
+    """
     model: Optional[Any] = None
     client: AurelioClient = Field(default_factory=AurelioClient, exclude=True)
     async_client: AsyncAurelioClient = Field(
@@ -21,6 +24,13 @@ class AurelioSparseEncoder(SparseEncoder):
         name: str | None = None,
         api_key: Optional[str] = None,
     ):
+        """Initialize the AurelioSparseEncoder.
+
+        :param name: The name of the model to use.
+        :type name: str | None
+        :param api_key: The API key to use.
+        :type api_key: str | None
+        """
         if name is None:
             name = "bm25"
         super().__init__(name=name)
@@ -32,11 +42,28 @@ class AurelioSparseEncoder(SparseEncoder):
         self.async_client = AsyncAurelioClient(api_key=api_key)
 
     def __call__(self, docs: list[str]) -> list[SparseEmbedding]:
+        """Encode a list of documents using the Aurelio Platform embedding API. Documents
+        must be strings, sparse encoders do not support other types.
+
+        :param docs: The documents to encode.
+        :type docs: list[str]
+        :return: The encoded documents.
+        :rtype: list[SparseEmbedding]
+        """
         res: EmbeddingResponse = self.client.embedding(input=docs, model=self.name)
         embeds = [SparseEmbedding.from_aurelio(r.embedding) for r in res.data]
         return embeds
 
     async def acall(self, docs: list[str]) -> list[SparseEmbedding]:
+        """Asynchronously encode a list of documents using the Aurelio Platform
+        embedding API. Documents must be strings, sparse encoders do not support other
+        types.
+
+        :param docs: The documents to encode.
+        :type docs: list[str]
+        :return: The encoded documents.
+        :rtype: list[SparseEmbedding]
+        """
         res: EmbeddingResponse = await self.async_client.embedding(
             input=docs, model=self.name
         )
@@ -44,4 +71,10 @@ class AurelioSparseEncoder(SparseEncoder):
         return embeds
 
     def fit(self, docs: List[str]):
+        """Fit the encoder to a list of documents. AurelioSparseEncoder does not support
+        fit yet.
+
+        :param docs: The documents to fit the encoder to.
+        :type docs: list[str]
+        """
         raise NotImplementedError("AurelioSparseEncoder does not support fit.")
diff --git a/semantic_router/encoders/base.py b/semantic_router/encoders/base.py
index 3e8cba21..864668ee 100644
--- a/semantic_router/encoders/base.py
+++ b/semantic_router/encoders/base.py
@@ -15,17 +15,44 @@ class DenseEncoder(BaseModel):
         arbitrary_types_allowed = True
 
     @field_validator("score_threshold")
-    def set_score_threshold(cls, v):
+    def set_score_threshold(cls, v: float | None) -> float | None:
+        """Set the score threshold. If None, the score threshold is not used.
+
+        :param v: The score threshold.
+        :type v: float | None
+        :return: The score threshold.
+        :rtype: float | None
+        """
         return float(v) if v is not None else None
 
     def __call__(self, docs: List[Any]) -> List[List[float]]:
+        """Encode a list of documents. Documents can be any type, but the encoder must
+        be built to handle that data type. Typically, these types are strings or
+        arrays representing images.
+
+        :param docs: The documents to encode.
+        :type docs: List[Any]
+        :return: The encoded documents.
+        :rtype: List[List[float]]
+        """
         raise NotImplementedError("Subclasses must implement this method")
 
     def acall(self, docs: List[Any]) -> Coroutine[Any, Any, List[List[float]]]:
+        """Encode a list of documents asynchronously. Documents can be any type, but the
+        encoder must be built to handle that data type. Typically, these types are
+        strings or arrays representing images.
+
+        :param docs: The documents to encode.
+        :type docs: List[Any]
+        :return: The encoded documents.
+        :rtype: List[List[float]]
+        """
         raise NotImplementedError("Subclasses must implement this method")
 
 
 class SparseEncoder(BaseModel):
+    """An encoder that encodes documents into a sparse format.
+    """
     name: str
     type: str = Field(default="base")
 
@@ -33,15 +60,38 @@ class SparseEncoder(BaseModel):
         arbitrary_types_allowed = True
 
     def __call__(self, docs: List[str]) -> List[SparseEmbedding]:
+        """Encode a list of documents. Documents must be strings, sparse encoders do not
+        support other types.
+
+        :param docs: The documents to encode.
+        :type docs: List[str]
+        :return: The encoded documents.
+        :rtype: List[SparseEmbedding]
+        """
         raise NotImplementedError("Subclasses must implement this method")
 
     async def acall(self, docs: List[str]) -> list[SparseEmbedding]:
+        """Asynchronously encode a list of documents. Documents must be strings,
+        sparse encoders do not support other types.
+
+        :param docs: The documents to encode.
+        :type docs: List[str]
+        :return: The encoded documents.
+        :rtype: List[SparseEmbedding]
+        """
         raise NotImplementedError("Subclasses must implement this method")
 
     def _array_to_sparse_embeddings(
         self, sparse_arrays: np.ndarray
     ) -> List[SparseEmbedding]:
-        """Consumes several sparse vectors containing zero-values and returns a compact array."""
+        """Consumes several sparse vectors containing zero-values and returns a compact
+        array.
+
+        :param sparse_arrays: The sparse arrays to compact.
+        :type sparse_arrays: np.ndarray
+        :return: The compact array.
+        :rtype: List[SparseEmbedding]
+        """
         if sparse_arrays.ndim != 2:
             raise ValueError(f"Expected a 2D array, got a {sparse_arrays.ndim}D array.")
         # get coordinates of non-zero values
diff --git a/semantic_router/encoders/bedrock.py b/semantic_router/encoders/bedrock.py
index f200bea8..1ff8723f 100644
--- a/semantic_router/encoders/bedrock.py
+++ b/semantic_router/encoders/bedrock.py
@@ -29,6 +29,27 @@ from semantic_router.utils.logger import logger
 
 
 class BedrockEncoder(DenseEncoder):
+    """Dense encoder using Amazon Bedrock embedding API. Requires an AWS Access Key ID
+    and AWS Secret Access Key.
+
+    The BedrockEncoder class is a subclass of DenseEncoder and utilizes the
+    TextEmbeddingModel from Amazon's Bedrock Platform to generate embeddings for
+    given documents. It supports customization of the pre-trained model, score
+    threshold, and region.
+
+    Example usage:
+
+    ```python
+    from semantic_router.encoders.bedrock import BedrockEncoder
+
+    encoder = BedrockEncoder(
+        access_key_id="your-access-key-id",
+        secret_access_key="your-secret-key",
+        region="your-region"
+    )
+    embeddings = encoder(["document1", "document2"])
+    ```
+    """
     client: Any = None
     type: str = "bedrock"
     input_type: Optional[str] = "search_query"
@@ -50,26 +71,32 @@ class BedrockEncoder(DenseEncoder):
     ):
         """Initializes the BedrockEncoder.
 
-        Args:
-            name: The name of the pre-trained model to use for embedding.
-                If not provided, the default model specified in EncoderDefault will
-                be used.
-            score_threshold: The threshold for similarity scores.
-            access_key_id: The AWS access key id for an IAM principle.
-                If not provided, it will be retrieved from the access_key_id
-                environment variable.
-            secret_access_key: The secret access key for an IAM principle.
-                If not provided, it will be retrieved from the AWS_SECRET_KEY
-                environment variable.
-            session_token: The session token for an IAM principle.
-                If not provided, it will be retrieved from the AWS_SESSION_TOKEN
-                environment variable.
-            region: The location of the Bedrock resources.
-                If not provided, it will be retrieved from the AWS_REGION
-                environment variable, defaulting to "us-west-1"
-
-        Raises:
-            ValueError: If the Bedrock Platform client fails to initialize.
+        :param name: The name of the pre-trained model to use for embedding.
+            If not provided, the default model specified in EncoderDefault will
+            be used.
+        :type name: str
+        :param input_type: The type of input to use for the embedding.
+            If not provided, the default input type specified in EncoderDefault will
+            be used.
+        :type input_type: str
+        :param score_threshold: The threshold for similarity scores.
+        :type score_threshold: float
+        :param access_key_id: The AWS access key id for an IAM principal.
+            If not provided, it will be retrieved from the AWS_ACCESS_KEY_ID
+            environment variable.
+        :type access_key_id: str
+        :param secret_access_key: The secret access key for an IAM principal.
+            If not provided, it will be retrieved from the AWS_SECRET_KEY
+            environment variable.
+        :type secret_access_key: str
+        :param session_token: The session token for an IAM principal.
+            If not provided, it will be retrieved from the AWS_SESSION_TOKEN
+            environment variable.
+        :param region: The location of the Bedrock resources.
+            If not provided, it will be retrieved from the AWS_REGION
+            environment variable, defaulting to "us-west-1"
+        :type region: str
+        :raises ValueError: If the Bedrock Platform client fails to initialize.
         """
         super().__init__(name=name, score_threshold=score_threshold)
         self.access_key_id = self.get_env_variable("AWS_ACCESS_KEY_ID", access_key_id)
@@ -96,16 +123,15 @@ class BedrockEncoder(DenseEncoder):
     ):
         """Initializes the Bedrock client.
 
-        Args:
-            access_key_id: The Amazon access key ID.
-            secret_access_key: The Amazon secret key.
-            region: The location of the AI Platform resources.
-
-        Returns:
-            An instance of the TextEmbeddingModel client.
-
-        Raises:
-            ImportError: If the required Bedrock libraries are not
+        :param access_key_id: The Amazon access key ID.
+        :type access_key_id: str
+        :param secret_access_key: The Amazon secret key.
+        :type secret_access_key: str
+        :param region: The location of the Bedrock resources.
+        :type region: str
+        :returns: An instance of the TextEmbeddingModel client.
+        :rtype: Any
+        :raises ImportError: If the required Bedrock libraries are not
             installed.
-            ValueError: If the Bedrock client fails to initialize.
+        :raises ValueError: If the Bedrock client fails to initialize.
         """
@@ -145,16 +171,14 @@ class BedrockEncoder(DenseEncoder):
     ) -> List[List[float]]:
         """Generates embeddings for the given documents.
 
-        Args:
-            docs: A list of strings representing the documents to embed.
-            model_kwargs: A dictionary of model-specific inference parameters.
-
-        Returns:
-            A list of lists, where each inner list contains the embedding values for a
+        :param docs: A list of strings representing the documents to embed.
+        :type docs: list[str]
+        :param model_kwargs: A dictionary of model-specific inference parameters.
+        :type model_kwargs: dict
+        :returns: A list of lists, where each inner list contains the embedding values for a
             document.
-
-        Raises:
-            ValueError: If the Bedrock Platform client is not initialized or if the
+        :rtype: list[list[float]]
+        :raises ValueError: If the Bedrock Platform client is not initialized or if the
             API call fails.
         """
         try:
@@ -255,15 +279,14 @@ class BedrockEncoder(DenseEncoder):
         raise ValueError("Bedrock call failed to return embeddings.")
 
     def chunk_strings(self, strings, MAX_WORDS=20):
-        """
-        Breaks up a list of strings into smaller chunks.
-
-        Args:
-            strings (list): A list of strings to be chunked.
-            max_chunk_size (int): The maximum size of each chunk. Default is 20.
-
-        Returns:
-            list: A list of lists, where each inner list contains a chunk of strings.
+        """Breaks up a list of strings into smaller chunks.
+
+        :param strings: A list of strings to be chunked.
+        :type strings: list
+        :param MAX_WORDS: The maximum size of each chunk. Default is 20.
+        :type MAX_WORDS: int
+        :returns: A list of lists, where each inner list contains a chunk of strings.
+        :rtype: list[list[str]]
         """
         encoding = tiktoken.get_encoding("cl100k_base")
         chunked_strings = []
@@ -280,17 +303,15 @@ class BedrockEncoder(DenseEncoder):
     def get_env_variable(var_name, provided_value, default=None):
         """Retrieves environment variable or uses a provided value.
 
-        Args:
-            var_name (str): The name of the environment variable.
-            provided_value (Optional[str]): The provided value to use if not None.
-            default (Optional[str]): The default value if the environment variable is not set.
-
-        Returns:
-            str: The value of the environment variable or the provided/default value.
-            None: Where AWS_SESSION_TOKEN is not set or provided
-
-        Raises:
-            ValueError: If no value is provided and the environment variable is not set.
+        :param var_name: The name of the environment variable.
+        :type var_name: str
+        :param provided_value: The provided value to use if not None.
+        :type provided_value: Optional[str]
+        :param default: The default value if the environment variable is not set.
+        :type default: Optional[str]
+        :returns: The value of the environment variable or the provided/default value.
+        :rtype: Optional[str]
+        :raises ValueError: If no value is provided and the environment variable is not set.
         """
         if provided_value is not None:
             return provided_value
diff --git a/semantic_router/encoders/clip.py b/semantic_router/encoders/clip.py
index 1ecfe41c..3dc33b99 100644
--- a/semantic_router/encoders/clip.py
+++ b/semantic_router/encoders/clip.py
@@ -7,6 +7,30 @@ from semantic_router.encoders import DenseEncoder
 
 
 class CLIPEncoder(DenseEncoder):
+    """Multi-modal dense encoder for text and images using CLIP-type models via
+    HuggingFace.
+
+    :param name: The name of the model to use.
+    :type name: str
+    :param tokenizer_kwargs: Keyword arguments for the tokenizer.
+    :type tokenizer_kwargs: Dict
+    :param processor_kwargs: Keyword arguments for the processor.
+    :type processor_kwargs: Dict
+    :param model_kwargs: Keyword arguments for the model.
+    :type model_kwargs: Dict
+    :param device: The device to use for the model.
+    :type device: Optional[str]
+    :param _tokenizer: The tokenizer for the model.
+    :type _tokenizer: Any
+    :param _processor: The processor for the model.
+    :type _processor: Any
+    :param _model: The model.
+    :type _model: Any
+    :param _torch: The torch library.
+    :type _torch: Any
+    :param _Image: The PIL library.
+    :type _Image: Any
+    """
     name: str = "openai/clip-vit-base-patch16"
     type: str = "huggingface"
     tokenizer_kwargs: Dict = {}
@@ -20,6 +44,11 @@ class CLIPEncoder(DenseEncoder):
     _Image: Any = PrivateAttr()
 
     def __init__(self, **data):
+        """Initialize the CLIPEncoder.
+
+        :param data: Keyword arguments for the encoder.
+        :type data: Dict
+        """
         if data.get("score_threshold") is None:
             data["score_threshold"] = 0.2
         super().__init__(**data)
@@ -31,6 +60,17 @@ class CLIPEncoder(DenseEncoder):
         batch_size: int = 32,
         normalize_embeddings: bool = True,
     ) -> List[List[float]]:
+        """Encode a list of documents. Can handle both text and images.
+
+        :param docs: The documents to encode.
+        :type docs: List[Any]
+        :param batch_size: The batch size for the encoding.
+        :type batch_size: int
+        :param normalize_embeddings: Whether to normalize the embeddings.
+        :type normalize_embeddings: bool
+        :returns: A list of embeddings.
+        :rtype: List[List[float]]
+        """
         all_embeddings = []
         if isinstance(docs[0], str):
             text = True
@@ -50,6 +90,11 @@ class CLIPEncoder(DenseEncoder):
         return all_embeddings
 
     def _initialize_hf_model(self):
+        """Initialize the HuggingFace model.
+
+        :returns: A tuple of the tokenizer, processor, and model.
+        :rtype: Tuple[Any, Any, Any]
+        """
         try:
             from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast
         except ImportError:
@@ -92,6 +137,11 @@ class CLIPEncoder(DenseEncoder):
         return tokenizer, processor, model
 
     def _get_device(self) -> str:
+        """Get the device to use for the model. Returns either cuda, mps, or cpu.
+
+        :returns: The device to use for the model.
+        :rtype: str
+        """
         if self.device:
             device = self.device
         elif self._torch.cuda.is_available():
@@ -103,6 +153,13 @@ class CLIPEncoder(DenseEncoder):
         return device
 
     def _encode_text(self, docs: List[str]) -> Any:
+        """Encode a list of text documents.
+
+        :param docs: The documents to encode.
+        :type docs: List[str]
+        :returns: The embeddings for the documents.
+        :rtype: Any
+        """
         inputs = self._tokenizer(
             docs, return_tensors="pt", padding=True, truncation=True
         ).to(self.device)
@@ -112,6 +169,13 @@ class CLIPEncoder(DenseEncoder):
         return embeds
 
     def _encode_image(self, images: List[Any]) -> Any:
+        """Encode a list of image documents.
+
+        :param images: The images to encode.
+        :type images: List[Any]
+        :returns: The embeddings for the images.
+        :rtype: Any
+        """
         rgb_images = [self._ensure_rgb(img) for img in images]
         inputs = self._processor(text=None, images=rgb_images, return_tensors="pt")[
             "pixel_values"
@@ -122,6 +186,13 @@ class CLIPEncoder(DenseEncoder):
         return embeds
 
     def _ensure_rgb(self, img: Any):
+        """Ensure the image is in RGB format.
+
+        :param img: The image to ensure is in RGB format.
+        :type img: Any
+        :returns: The image in RGB format.
+        :rtype: Any
+        """
         rgbimg = self._Image.new("RGB", img.size)
         rgbimg.paste(img)
         return rgbimg
-- 
GitLab