From f16f620fb63f7b3264a6da275239441d88e74183 Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Thu, 13 Feb 2025 10:22:03 +0400 Subject: [PATCH] chore: add docstrings for few encoders --- semantic_router/encoders/aurelio.py | 33 +++++++ semantic_router/encoders/base.py | 54 ++++++++++- semantic_router/encoders/bedrock.py | 139 ++++++++++++++++------------ semantic_router/encoders/clip.py | 71 ++++++++++++++ 4 files changed, 236 insertions(+), 61 deletions(-) diff --git a/semantic_router/encoders/aurelio.py b/semantic_router/encoders/aurelio.py index d1d40aa7..81e9f8e8 100644 --- a/semantic_router/encoders/aurelio.py +++ b/semantic_router/encoders/aurelio.py @@ -9,6 +9,9 @@ from semantic_router.schema import SparseEmbedding class AurelioSparseEncoder(SparseEncoder): + """Sparse encoder using Aurelio Platform's embedding API. Requires an API key from + https://platform.aurelio.ai + """ model: Optional[Any] = None client: AurelioClient = Field(default_factory=AurelioClient, exclude=True) async_client: AsyncAurelioClient = Field( @@ -21,6 +24,13 @@ class AurelioSparseEncoder(SparseEncoder): name: str | None = None, api_key: Optional[str] = None, ): + """Initialize the AurelioSparseEncoder. + + :param name: The name of the model to use. + :type name: str | None + :param api_key: The API key to use. + :type api_key: str | None + """ if name is None: name = "bm25" super().__init__(name=name) @@ -32,11 +42,28 @@ class AurelioSparseEncoder(SparseEncoder): self.async_client = AsyncAurelioClient(api_key=api_key) def __call__(self, docs: list[str]) -> list[SparseEmbedding]: + """Encode a list of documents using the Aurelio Platform embedding API. Documents + must be strings, sparse encoders do not support other types. + + :param docs: The documents to encode. + :type docs: list[str] + :return: The encoded documents. 
+ :rtype: list[SparseEmbedding] + """ res: EmbeddingResponse = self.client.embedding(input=docs, model=self.name) embeds = [SparseEmbedding.from_aurelio(r.embedding) for r in res.data] return embeds async def acall(self, docs: list[str]) -> list[SparseEmbedding]: + """Asynchronously encode a list of documents using the Aurelio Platform + embedding API. Documents must be strings, sparse encoders do not support other + types. + + :param docs: The documents to encode. + :type docs: list[str] + :return: The encoded documents. + :rtype: list[SparseEmbedding] + """ res: EmbeddingResponse = await self.async_client.embedding( input=docs, model=self.name ) @@ -44,4 +71,10 @@ class AurelioSparseEncoder(SparseEncoder): return embeds def fit(self, docs: List[str]): + """Fit the encoder to a list of documents. AurelioSparseEncoder does not support + fit yet. + + :param docs: The documents to fit the encoder to. + :type docs: list[str] + """ raise NotImplementedError("AurelioSparseEncoder does not support fit.") diff --git a/semantic_router/encoders/base.py b/semantic_router/encoders/base.py index 3e8cba21..864668ee 100644 --- a/semantic_router/encoders/base.py +++ b/semantic_router/encoders/base.py @@ -15,17 +15,44 @@ class DenseEncoder(BaseModel): arbitrary_types_allowed = True @field_validator("score_threshold") - def set_score_threshold(cls, v): + def set_score_threshold(cls, v: float | None) -> float | None: + """Set the score threshold. If None, the score threshold is not used. + + :param v: The score threshold. + :type v: float | None + :return: The score threshold. + :rtype: float | None + """ return float(v) if v is not None else None def __call__(self, docs: List[Any]) -> List[List[float]]: + """Encode a list of documents. Documents can be any type, but the encoder must + be built to handle that data type. Typically, these types are strings or + arrays representing images. + + :param docs: The documents to encode. 
+ :type docs: List[Any] + :return: The encoded documents. + :rtype: List[List[float]] + """ raise NotImplementedError("Subclasses must implement this method") def acall(self, docs: List[Any]) -> Coroutine[Any, Any, List[List[float]]]: + """Encode a list of documents asynchronously. Documents can be any type, but the + encoder must be built to handle that data type. Typically, these types are + strings or arrays representing images. + + :param docs: The documents to encode. + :type docs: List[Any] + :return: The encoded documents. + :rtype: Coroutine[Any, Any, List[List[float]]] + """ raise NotImplementedError("Subclasses must implement this method") class SparseEncoder(BaseModel): + """An encoder that encodes documents into a sparse format. + """ name: str type: str = Field(default="base") @@ -33,15 +60,38 @@ class SparseEncoder(BaseModel): arbitrary_types_allowed = True def __call__(self, docs: List[str]) -> List[SparseEmbedding]: + """Encode a list of documents. Documents must be strings, sparse encoders do not + support other types. + + :param docs: The documents to encode. + :type docs: List[str] + :return: The encoded documents. + :rtype: List[SparseEmbedding] + """ raise NotImplementedError("Subclasses must implement this method") async def acall(self, docs: List[str]) -> list[SparseEmbedding]: + """Asynchronously encode a list of documents. Documents must be strings, sparse + encoders do not support other types. + + :param docs: The documents to encode. + :type docs: List[str] + :return: The encoded documents. + :rtype: List[SparseEmbedding] + """ raise NotImplementedError("Subclasses must implement this method") def _array_to_sparse_embeddings( self, sparse_arrays: np.ndarray ) -> List[SparseEmbedding]: - """Consumes several sparse vectors containing zero-values and returns a compact array.""" + """Consumes several sparse vectors containing zero-values and returns a compact + array. + + :param sparse_arrays: The sparse arrays to compact.
+ :type sparse_arrays: np.ndarray + :return: The compact array. + :rtype: List[SparseEmbedding] + """ if sparse_arrays.ndim != 2: raise ValueError(f"Expected a 2D array, got a {sparse_arrays.ndim}D array.") # get coordinates of non-zero values diff --git a/semantic_router/encoders/bedrock.py b/semantic_router/encoders/bedrock.py index f200bea8..1ff8723f 100644 --- a/semantic_router/encoders/bedrock.py +++ b/semantic_router/encoders/bedrock.py @@ -29,6 +29,27 @@ from semantic_router.utils.logger import logger class BedrockEncoder(DenseEncoder): + """Dense encoder using Amazon Bedrock embedding API. Requires an AWS Access Key ID + and AWS Secret Access Key. + + The BedrockEncoder class is a subclass of DenseEncoder and utilizes the + TextEmbeddingModel from Amazon's Bedrock Platform to generate embeddings for + given documents. It supports customization of the pre-trained model, score + threshold, and region. + + Example usage: + + ```python + from semantic_router.encoders.bedrock import BedrockEncoder + + encoder = BedrockEncoder( + access_key_id="your-access-key-id", + secret_access_key="your-secret-key", + region="your-region" + ) + embeddings = encoder(["document1", "document2"]) + ``` + """ client: Any = None type: str = "bedrock" input_type: Optional[str] = "search_query" @@ -50,26 +71,32 @@ class BedrockEncoder(DenseEncoder): ): """Initializes the BedrockEncoder. - Args: - name: The name of the pre-trained model to use for embedding. - If not provided, the default model specified in EncoderDefault will - be used. - score_threshold: The threshold for similarity scores. - access_key_id: The AWS access key id for an IAM principle. - If not provided, it will be retrieved from the access_key_id - environment variable. - secret_access_key: The secret access key for an IAM principle. - If not provided, it will be retrieved from the AWS_SECRET_KEY - environment variable. - session_token: The session token for an IAM principle.
- If not provided, it will be retrieved from the AWS_SESSION_TOKEN - environment variable. - region: The location of the Bedrock resources. - If not provided, it will be retrieved from the AWS_REGION - environment variable, defaulting to "us-west-1" - - Raises: - ValueError: If the Bedrock Platform client fails to initialize. + :param name: The name of the pre-trained model to use for embedding. + If not provided, the default model specified in EncoderDefault will + be used. + :type name: str + :param input_type: The type of input to use for the embedding. + If not provided, the default input type specified in EncoderDefault will + be used. + :type input_type: str + :param score_threshold: The threshold for similarity scores. + :type score_threshold: float + :param access_key_id: The AWS access key id for an IAM principal. + If not provided, it will be retrieved from the AWS_ACCESS_KEY_ID + environment variable. + :type access_key_id: str + :param secret_access_key: The secret access key for an IAM principal. + If not provided, it will be retrieved from the AWS_SECRET_KEY + environment variable. + :type secret_access_key: str + :param session_token: The session token for an IAM principal. + If not provided, it will be retrieved from the AWS_SESSION_TOKEN + environment variable. + :param region: The location of the Bedrock resources. + If not provided, it will be retrieved from the AWS_REGION + environment variable, defaulting to "us-west-1" + :type region: str + :raises ValueError: If the Bedrock Platform client fails to initialize. """ super().__init__(name=name, score_threshold=score_threshold) self.access_key_id = self.get_env_variable("AWS_ACCESS_KEY_ID", access_key_id) @@ -96,16 +123,15 @@ class BedrockEncoder(DenseEncoder): ): """Initializes the Bedrock client. - Args: - access_key_id: The Amazon access key ID. - secret_access_key: The Amazon secret key. - region: The location of the AI Platform resources.
- - Returns: - An instance of the TextEmbeddingModel client. - - Raises: - ImportError: If the required Bedrock libraries are not + :param access_key_id: The Amazon access key ID. + :type access_key_id: str + :param secret_access_key: The Amazon secret key. + :type secret_access_key: str + :param region: The location of the AI Platform resources. + :type region: str + :returns: An instance of the TextEmbeddingModel client. + :rtype: Any + :raises ImportError: If the required Bedrock libraries are not installed. ValueError: If the Bedrock client fails to initialize. """ @@ -145,16 +171,14 @@ class BedrockEncoder(DenseEncoder): ) -> List[List[float]]: """Generates embeddings for the given documents. - Args: - docs: A list of strings representing the documents to embed. - model_kwargs: A dictionary of model-specific inference parameters. - - Returns: - A list of lists, where each inner list contains the embedding values for a + :param docs: A list of strings representing the documents to embed. + :type docs: list[str] + :param model_kwargs: A dictionary of model-specific inference parameters. + :type model_kwargs: dict + :returns: A list of lists, where each inner list contains the embedding values for a document. - - Raises: - ValueError: If the Bedrock Platform client is not initialized or if the + :rtype: list[list[float]] + :raises ValueError: If the Bedrock Platform client is not initialized or if the API call fails. """ try: @@ -255,15 +279,14 @@ class BedrockEncoder(DenseEncoder): raise ValueError("Bedrock call failed to return embeddings.") def chunk_strings(self, strings, MAX_WORDS=20): - """ - Breaks up a list of strings into smaller chunks. - - Args: - strings (list): A list of strings to be chunked. - max_chunk_size (int): The maximum size of each chunk. Default is 20. - - Returns: - list: A list of lists, where each inner list contains a chunk of strings. + """Breaks up a list of strings into smaller chunks. 
+ + :param strings: A list of strings to be chunked. + :type strings: list + :param MAX_WORDS: The maximum size of each chunk. Default is 20. + :type MAX_WORDS: int + :returns: A list of lists, where each inner list contains a chunk of strings. + :rtype: list[list[str]] """ encoding = tiktoken.get_encoding("cl100k_base") chunked_strings = [] @@ -280,17 +303,15 @@ class BedrockEncoder(DenseEncoder): def get_env_variable(var_name, provided_value, default=None): """Retrieves environment variable or uses a provided value. - Args: - var_name (str): The name of the environment variable. - provided_value (Optional[str]): The provided value to use if not None. - default (Optional[str]): The default value if the environment variable is not set. - - Returns: - str: The value of the environment variable or the provided/default value. - None: Where AWS_SESSION_TOKEN is not set or provided - - Raises: - ValueError: If no value is provided and the environment variable is not set. + :param var_name: The name of the environment variable. + :type var_name: str + :param provided_value: The provided value to use if not None. + :type provided_value: Optional[str] + :param default: The default value if the environment variable is not set. + :type default: Optional[str] + :returns: The value of the environment variable or the provided/default value. + :rtype: Optional[str] + :raises ValueError: If no value is provided and the environment variable is not set. """ if provided_value is not None: return provided_value diff --git a/semantic_router/encoders/clip.py b/semantic_router/encoders/clip.py index 1ecfe41c..3dc33b99 100644 --- a/semantic_router/encoders/clip.py +++ b/semantic_router/encoders/clip.py @@ -7,6 +7,30 @@ from semantic_router.encoders import DenseEncoder class CLIPEncoder(DenseEncoder): + """Multi-modal dense encoder for text and images using CLIP-type models via + HuggingFace. + + :param name: The name of the model to use.
+ :type name: str + :param tokenizer_kwargs: Keyword arguments for the tokenizer. + :type tokenizer_kwargs: Dict + :param processor_kwargs: Keyword arguments for the processor. + :type processor_kwargs: Dict + :param model_kwargs: Keyword arguments for the model. + :type model_kwargs: Dict + :param device: The device to use for the model. + :type device: Optional[str] + :param _tokenizer: The tokenizer for the model. + :type _tokenizer: Any + :param _processor: The processor for the model. + :type _processor: Any + :param _model: The model. + :type _model: Any + :param _torch: The torch library. + :type _torch: Any + :param _Image: The PIL library. + :type _Image: Any + """ name: str = "openai/clip-vit-base-patch16" type: str = "huggingface" tokenizer_kwargs: Dict = {} @@ -20,6 +44,11 @@ class CLIPEncoder(DenseEncoder): _Image: Any = PrivateAttr() def __init__(self, **data): + """Initialize the CLIPEncoder. + + :param **data: Keyword arguments for the encoder. + :type **data: Dict + """ if data.get("score_threshold") is None: data["score_threshold"] = 0.2 super().__init__(**data) @@ -31,6 +60,17 @@ class CLIPEncoder(DenseEncoder): batch_size: int = 32, normalize_embeddings: bool = True, ) -> List[List[float]]: + """Encode a list of documents. Can handle both text and images. + + :param docs: The documents to encode. + :type docs: List[Any] + :param batch_size: The batch size for the encoding. + :type batch_size: int + :param normalize_embeddings: Whether to normalize the embeddings. + :type normalize_embeddings: bool + :returns: A list of embeddings. + :rtype: List[List[float]] + """ all_embeddings = [] if isinstance(docs[0], str): text = True @@ -50,6 +90,11 @@ class CLIPEncoder(DenseEncoder): return all_embeddings def _initialize_hf_model(self): + """Initialize the HuggingFace model. + + :returns: A tuple of the tokenizer, processor, and model. 
+ :rtype: Tuple[Any, Any, Any] + """ try: from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast except ImportError: @@ -92,6 +137,11 @@ class CLIPEncoder(DenseEncoder): return tokenizer, processor, model def _get_device(self) -> str: + """Get the device to use for the model. Returns either cuda, mps, or cpu. + + :returns: The device to use for the model. + :rtype: str + """ if self.device: device = self.device elif self._torch.cuda.is_available(): @@ -103,6 +153,13 @@ class CLIPEncoder(DenseEncoder): return device def _encode_text(self, docs: List[str]) -> Any: + """Encode a list of text documents. + + :param docs: The documents to encode. + :type docs: List[str] + :returns: The embeddings for the documents. + :rtype: Any + """ inputs = self._tokenizer( docs, return_tensors="pt", padding=True, truncation=True ).to(self.device) @@ -112,6 +169,13 @@ class CLIPEncoder(DenseEncoder): return embeds def _encode_image(self, images: List[Any]) -> Any: + """Encode a list of image documents. + + :param images: The images to encode. + :type images: List[Any] + :returns: The embeddings for the images. + :rtype: Any + """ rgb_images = [self._ensure_rgb(img) for img in images] inputs = self._processor(text=None, images=rgb_images, return_tensors="pt")[ "pixel_values" @@ -122,6 +186,13 @@ class CLIPEncoder(DenseEncoder): return embeds def _ensure_rgb(self, img: Any): + """Ensure the image is in RGB format. + + :param img: The image to ensure is in RGB format. + :type img: Any + :returns: The image in RGB format. + :rtype: Any + """ rgbimg = self._Image.new("RGB", img.size) rgbimg.paste(img) return rgbimg -- GitLab