Skip to content
Snippets Groups Projects
Commit f16f620f authored by James Briggs
Browse files

chore: add docstrings for few encoders

parent bae46188
No related branches found
No related tags found
No related merge requests found
...@@ -9,6 +9,9 @@ from semantic_router.schema import SparseEmbedding ...@@ -9,6 +9,9 @@ from semantic_router.schema import SparseEmbedding
class AurelioSparseEncoder(SparseEncoder): class AurelioSparseEncoder(SparseEncoder):
"""Sparse encoder using Aurelio Platform's embedding API. Requires an API key from
https://platform.aurelio.ai
"""
model: Optional[Any] = None model: Optional[Any] = None
client: AurelioClient = Field(default_factory=AurelioClient, exclude=True) client: AurelioClient = Field(default_factory=AurelioClient, exclude=True)
async_client: AsyncAurelioClient = Field( async_client: AsyncAurelioClient = Field(
...@@ -21,6 +24,13 @@ class AurelioSparseEncoder(SparseEncoder): ...@@ -21,6 +24,13 @@ class AurelioSparseEncoder(SparseEncoder):
name: str | None = None, name: str | None = None,
api_key: Optional[str] = None, api_key: Optional[str] = None,
): ):
"""Initialize the AurelioSparseEncoder.
:param name: The name of the model to use.
:type name: str | None
:param api_key: The API key to use.
:type api_key: str | None
"""
if name is None: if name is None:
name = "bm25" name = "bm25"
super().__init__(name=name) super().__init__(name=name)
...@@ -32,11 +42,28 @@ class AurelioSparseEncoder(SparseEncoder): ...@@ -32,11 +42,28 @@ class AurelioSparseEncoder(SparseEncoder):
self.async_client = AsyncAurelioClient(api_key=api_key) self.async_client = AsyncAurelioClient(api_key=api_key)
def __call__(self, docs: list[str]) -> list[SparseEmbedding]: def __call__(self, docs: list[str]) -> list[SparseEmbedding]:
"""Encode a list of documents using the Aurelio Platform embedding API. Documents
must be strings, sparse encoders do not support other types.
:param docs: The documents to encode.
:type docs: list[str]
:return: The encoded documents.
:rtype: list[SparseEmbedding]
"""
res: EmbeddingResponse = self.client.embedding(input=docs, model=self.name) res: EmbeddingResponse = self.client.embedding(input=docs, model=self.name)
embeds = [SparseEmbedding.from_aurelio(r.embedding) for r in res.data] embeds = [SparseEmbedding.from_aurelio(r.embedding) for r in res.data]
return embeds return embeds
async def acall(self, docs: list[str]) -> list[SparseEmbedding]: async def acall(self, docs: list[str]) -> list[SparseEmbedding]:
"""Asynchronously encode a list of documents using the Aurelio Platform
embedding API. Documents must be strings, sparse encoders do not support other
types.
:param docs: The documents to encode.
:type docs: list[str]
:return: The encoded documents.
:rtype: list[SparseEmbedding]
"""
res: EmbeddingResponse = await self.async_client.embedding( res: EmbeddingResponse = await self.async_client.embedding(
input=docs, model=self.name input=docs, model=self.name
) )
...@@ -44,4 +71,10 @@ class AurelioSparseEncoder(SparseEncoder): ...@@ -44,4 +71,10 @@ class AurelioSparseEncoder(SparseEncoder):
return embeds return embeds
def fit(self, docs: List[str]): def fit(self, docs: List[str]):
"""Fit the encoder to a list of documents. AurelioSparseEncoder does not support
fit yet.
:param docs: The documents to fit the encoder to.
:type docs: list[str]
"""
raise NotImplementedError("AurelioSparseEncoder does not support fit.") raise NotImplementedError("AurelioSparseEncoder does not support fit.")
...@@ -15,17 +15,44 @@ class DenseEncoder(BaseModel): ...@@ -15,17 +15,44 @@ class DenseEncoder(BaseModel):
arbitrary_types_allowed = True arbitrary_types_allowed = True
@field_validator("score_threshold") @field_validator("score_threshold")
def set_score_threshold(cls, v): def set_score_threshold(cls, v: float | None) -> float | None:
"""Set the score threshold. If None, the score threshold is not used.
:param v: The score threshold.
:type v: float | None
:return: The score threshold.
:rtype: float | None
"""
return float(v) if v is not None else None return float(v) if v is not None else None
def __call__(self, docs: List[Any]) -> List[List[float]]: def __call__(self, docs: List[Any]) -> List[List[float]]:
"""Encode a list of documents. Documents can be any type, but the encoder must
be built to handle that data type. Typically, these types are strings or
arrays representing images.
:param docs: The documents to encode.
:type docs: List[Any]
:return: The encoded documents.
:rtype: List[List[float]]
"""
raise NotImplementedError("Subclasses must implement this method") raise NotImplementedError("Subclasses must implement this method")
def acall(self, docs: List[Any]) -> Coroutine[Any, Any, List[List[float]]]: def acall(self, docs: List[Any]) -> Coroutine[Any, Any, List[List[float]]]:
"""Encode a list of documents asynchronously. Documents can be any type, but the
encoder must be built to handle that data type. Typically, these types are
strings or arrays representing images.
:param docs: The documents to encode.
:type docs: List[Any]
:return: The encoded documents.
:rtype: List[List[float]]
"""
raise NotImplementedError("Subclasses must implement this method") raise NotImplementedError("Subclasses must implement this method")
class SparseEncoder(BaseModel): class SparseEncoder(BaseModel):
"""An encoder that encodes documents into a sparse format.
"""
name: str name: str
type: str = Field(default="base") type: str = Field(default="base")
...@@ -33,15 +60,38 @@ class SparseEncoder(BaseModel): ...@@ -33,15 +60,38 @@ class SparseEncoder(BaseModel):
arbitrary_types_allowed = True arbitrary_types_allowed = True
def __call__(self, docs: List[str]) -> List[SparseEmbedding]: def __call__(self, docs: List[str]) -> List[SparseEmbedding]:
"""Encode a list of documents. Documents must be strings, sparse encoders do not
support other types.
:param docs: The documents to encode.
:type docs: List[str]
:return: The encoded documents.
:rtype: List[SparseEmbedding]
"""
raise NotImplementedError("Subclasses must implement this method") raise NotImplementedError("Subclasses must implement this method")
async def acall(self, docs: List[str]) -> list[SparseEmbedding]: async def acall(self, docs: List[str]) -> list[SparseEmbedding]:
"""Asynchronously encode a list of documents. Documents must be strings, sparse
encoders do not support other types.
:param docs: The documents to encode.
:type docs: List[str]
:return: The encoded documents.
:rtype: List[SparseEmbedding]
"""
raise NotImplementedError("Subclasses must implement this method") raise NotImplementedError("Subclasses must implement this method")
def _array_to_sparse_embeddings( def _array_to_sparse_embeddings(
self, sparse_arrays: np.ndarray self, sparse_arrays: np.ndarray
) -> List[SparseEmbedding]: ) -> List[SparseEmbedding]:
"""Consumes several sparse vectors containing zero-values and returns a compact array.""" """Consumes several sparse vectors containing zero-values and returns a compact
array.
:param sparse_arrays: The sparse arrays to compact.
:type sparse_arrays: np.ndarray
:return: The compact array.
:rtype: List[SparseEmbedding]
"""
if sparse_arrays.ndim != 2: if sparse_arrays.ndim != 2:
raise ValueError(f"Expected a 2D array, got a {sparse_arrays.ndim}D array.") raise ValueError(f"Expected a 2D array, got a {sparse_arrays.ndim}D array.")
# get coordinates of non-zero values # get coordinates of non-zero values
......
...@@ -29,6 +29,27 @@ from semantic_router.utils.logger import logger ...@@ -29,6 +29,27 @@ from semantic_router.utils.logger import logger
class BedrockEncoder(DenseEncoder): class BedrockEncoder(DenseEncoder):
"""Dense encoder using Amazon Bedrock embedding API. Requires an AWS Access Key ID
and AWS Secret Access Key.
The BedrockEncoder class is a subclass of DenseEncoder and utilizes the
TextEmbeddingModel from Amazon's Bedrock Platform to generate embeddings for
given documents. It supports customization of the pre-trained model, score
threshold, and region.
Example usage:
```python
from semantic_router.encoders.bedrock_encoder import BedrockEncoder
encoder = BedrockEncoder(
access_key_id="your-access-key-id",
secret_access_key="your-secret-key",
region="your-region"
)
embeddings = encoder(["document1", "document2"])
```
"""
client: Any = None client: Any = None
type: str = "bedrock" type: str = "bedrock"
input_type: Optional[str] = "search_query" input_type: Optional[str] = "search_query"
...@@ -50,26 +71,32 @@ class BedrockEncoder(DenseEncoder): ...@@ -50,26 +71,32 @@ class BedrockEncoder(DenseEncoder):
): ):
"""Initializes the BedrockEncoder. """Initializes the BedrockEncoder.
Args: :param name: The name of the pre-trained model to use for embedding.
name: The name of the pre-trained model to use for embedding. If not provided, the default model specified in EncoderDefault will
If not provided, the default model specified in EncoderDefault will be used.
be used. :type name: str
score_threshold: The threshold for similarity scores. :param input_type: The type of input to use for the embedding.
access_key_id: The AWS access key id for an IAM principle. If not provided, the default input type specified in EncoderDefault will
If not provided, it will be retrieved from the access_key_id be used.
environment variable. :type input_type: str
secret_access_key: The secret access key for an IAM principle. :param score_threshold: The threshold for similarity scores.
If not provided, it will be retrieved from the AWS_SECRET_KEY :type score_threshold: float
environment variable. :param access_key_id: The AWS access key id for an IAM principal.
session_token: The session token for an IAM principle. If not provided, it will be retrieved from the AWS_ACCESS_KEY_ID
If not provided, it will be retrieved from the AWS_SESSION_TOKEN environment variable.
environment variable. :type access_key_id: str
region: The location of the Bedrock resources. :param secret_access_key: The secret access key for an IAM principal.
If not provided, it will be retrieved from the AWS_REGION If not provided, it will be retrieved from the AWS_SECRET_KEY
environment variable, defaulting to "us-west-1" environment variable.
:type secret_access_key: str
Raises: :param session_token: The session token for an IAM principal.
ValueError: If the Bedrock Platform client fails to initialize. If not provided, it will be retrieved from the AWS_SESSION_TOKEN
environment variable.
:param region: The location of the Bedrock resources.
If not provided, it will be retrieved from the AWS_REGION
environment variable, defaulting to "us-west-1"
:type region: str
:raises ValueError: If the Bedrock Platform client fails to initialize.
""" """
super().__init__(name=name, score_threshold=score_threshold) super().__init__(name=name, score_threshold=score_threshold)
self.access_key_id = self.get_env_variable("AWS_ACCESS_KEY_ID", access_key_id) self.access_key_id = self.get_env_variable("AWS_ACCESS_KEY_ID", access_key_id)
...@@ -96,16 +123,15 @@ class BedrockEncoder(DenseEncoder): ...@@ -96,16 +123,15 @@ class BedrockEncoder(DenseEncoder):
): ):
"""Initializes the Bedrock client. """Initializes the Bedrock client.
Args: :param access_key_id: The Amazon access key ID.
access_key_id: The Amazon access key ID. :type access_key_id: str
secret_access_key: The Amazon secret key. :param secret_access_key: The Amazon secret key.
region: The location of the AI Platform resources. :type secret_access_key: str
:param region: The location of the AI Platform resources.
Returns: :type region: str
An instance of the TextEmbeddingModel client. :returns: An instance of the TextEmbeddingModel client.
:rtype: Any
Raises: :raises ImportError: If the required Bedrock libraries are not
ImportError: If the required Bedrock libraries are not
installed. installed.
ValueError: If the Bedrock client fails to initialize. :raises ValueError: If the Bedrock client fails to initialize.
""" """
...@@ -145,16 +171,14 @@ class BedrockEncoder(DenseEncoder): ...@@ -145,16 +171,14 @@ class BedrockEncoder(DenseEncoder):
) -> List[List[float]]: ) -> List[List[float]]:
"""Generates embeddings for the given documents. """Generates embeddings for the given documents.
Args: :param docs: A list of strings representing the documents to embed.
docs: A list of strings representing the documents to embed. :type docs: list[str]
model_kwargs: A dictionary of model-specific inference parameters. :param model_kwargs: A dictionary of model-specific inference parameters.
:type model_kwargs: dict
Returns: :returns: A list of lists, where each inner list contains the embedding values for a
A list of lists, where each inner list contains the embedding values for a
document. document.
:rtype: list[list[float]]
Raises: :raises ValueError: If the Bedrock Platform client is not initialized or if the
ValueError: If the Bedrock Platform client is not initialized or if the
API call fails. API call fails.
""" """
try: try:
...@@ -255,15 +279,14 @@ class BedrockEncoder(DenseEncoder): ...@@ -255,15 +279,14 @@ class BedrockEncoder(DenseEncoder):
raise ValueError("Bedrock call failed to return embeddings.") raise ValueError("Bedrock call failed to return embeddings.")
def chunk_strings(self, strings, MAX_WORDS=20): def chunk_strings(self, strings, MAX_WORDS=20):
""" """Breaks up a list of strings into smaller chunks.
Breaks up a list of strings into smaller chunks.
:param strings: A list of strings to be chunked.
Args: :type strings: list
strings (list): A list of strings to be chunked. :param MAX_WORDS: The maximum size of each chunk. Default is 20.
max_chunk_size (int): The maximum size of each chunk. Default is 20. :type MAX_WORDS: int
:returns: A list of lists, where each inner list contains a chunk of strings.
Returns: :rtype: list[list[str]]
list: A list of lists, where each inner list contains a chunk of strings.
""" """
encoding = tiktoken.get_encoding("cl100k_base") encoding = tiktoken.get_encoding("cl100k_base")
chunked_strings = [] chunked_strings = []
...@@ -280,17 +303,15 @@ class BedrockEncoder(DenseEncoder): ...@@ -280,17 +303,15 @@ class BedrockEncoder(DenseEncoder):
def get_env_variable(var_name, provided_value, default=None): def get_env_variable(var_name, provided_value, default=None):
"""Retrieves environment variable or uses a provided value. """Retrieves environment variable or uses a provided value.
Args: :param var_name: The name of the environment variable.
var_name (str): The name of the environment variable. :type var_name: str
provided_value (Optional[str]): The provided value to use if not None. :param provided_value: The provided value to use if not None.
default (Optional[str]): The default value if the environment variable is not set. :type provided_value: Optional[str]
:param default: The default value if the environment variable is not set.
Returns: :type default: Optional[str]
str: The value of the environment variable or the provided/default value. :returns: The value of the environment variable or the provided/default value.
None: Where AWS_SESSION_TOKEN is not set or provided :rtype: Optional[str]
:raises ValueError: If no value is provided and the environment variable is not set.
Raises:
ValueError: If no value is provided and the environment variable is not set.
""" """
if provided_value is not None: if provided_value is not None:
return provided_value return provided_value
......
...@@ -7,6 +7,30 @@ from semantic_router.encoders import DenseEncoder ...@@ -7,6 +7,30 @@ from semantic_router.encoders import DenseEncoder
class CLIPEncoder(DenseEncoder): class CLIPEncoder(DenseEncoder):
"""Multi-modal dense encoder for text and images using CLIP-type models via
HuggingFace.
:param name: The name of the model to use.
:type name: str
:param tokenizer_kwargs: Keyword arguments for the tokenizer.
:type tokenizer_kwargs: Dict
:param processor_kwargs: Keyword arguments for the processor.
:type processor_kwargs: Dict
:param model_kwargs: Keyword arguments for the model.
:type model_kwargs: Dict
:param device: The device to use for the model.
:type device: Optional[str]
:param _tokenizer: The tokenizer for the model.
:type _tokenizer: Any
:param _processor: The processor for the model.
:type _processor: Any
:param _model: The model.
:type _model: Any
:param _torch: The torch library.
:type _torch: Any
:param _Image: The PIL library.
:type _Image: Any
"""
name: str = "openai/clip-vit-base-patch16" name: str = "openai/clip-vit-base-patch16"
type: str = "huggingface" type: str = "huggingface"
tokenizer_kwargs: Dict = {} tokenizer_kwargs: Dict = {}
...@@ -20,6 +44,11 @@ class CLIPEncoder(DenseEncoder): ...@@ -20,6 +44,11 @@ class CLIPEncoder(DenseEncoder):
_Image: Any = PrivateAttr() _Image: Any = PrivateAttr()
def __init__(self, **data): def __init__(self, **data):
"""Initialize the CLIPEncoder.
:param **data: Keyword arguments for the encoder.
:type **data: Dict
"""
if data.get("score_threshold") is None: if data.get("score_threshold") is None:
data["score_threshold"] = 0.2 data["score_threshold"] = 0.2
super().__init__(**data) super().__init__(**data)
...@@ -31,6 +60,17 @@ class CLIPEncoder(DenseEncoder): ...@@ -31,6 +60,17 @@ class CLIPEncoder(DenseEncoder):
batch_size: int = 32, batch_size: int = 32,
normalize_embeddings: bool = True, normalize_embeddings: bool = True,
) -> List[List[float]]: ) -> List[List[float]]:
"""Encode a list of documents. Can handle both text and images.
:param docs: The documents to encode.
:type docs: List[Any]
:param batch_size: The batch size for the encoding.
:type batch_size: int
:param normalize_embeddings: Whether to normalize the embeddings.
:type normalize_embeddings: bool
:returns: A list of embeddings.
:rtype: List[List[float]]
"""
all_embeddings = [] all_embeddings = []
if isinstance(docs[0], str): if isinstance(docs[0], str):
text = True text = True
...@@ -50,6 +90,11 @@ class CLIPEncoder(DenseEncoder): ...@@ -50,6 +90,11 @@ class CLIPEncoder(DenseEncoder):
return all_embeddings return all_embeddings
def _initialize_hf_model(self): def _initialize_hf_model(self):
"""Initialize the HuggingFace model.
:returns: A tuple of the tokenizer, processor, and model.
:rtype: Tuple[Any, Any, Any]
"""
try: try:
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast
except ImportError: except ImportError:
...@@ -92,6 +137,11 @@ class CLIPEncoder(DenseEncoder): ...@@ -92,6 +137,11 @@ class CLIPEncoder(DenseEncoder):
return tokenizer, processor, model return tokenizer, processor, model
def _get_device(self) -> str: def _get_device(self) -> str:
"""Get the device to use for the model. Returns either cuda, mps, or cpu.
:returns: The device to use for the model.
:rtype: str
"""
if self.device: if self.device:
device = self.device device = self.device
elif self._torch.cuda.is_available(): elif self._torch.cuda.is_available():
...@@ -103,6 +153,13 @@ class CLIPEncoder(DenseEncoder): ...@@ -103,6 +153,13 @@ class CLIPEncoder(DenseEncoder):
return device return device
def _encode_text(self, docs: List[str]) -> Any: def _encode_text(self, docs: List[str]) -> Any:
"""Encode a list of text documents.
:param docs: The documents to encode.
:type docs: List[str]
:returns: The embeddings for the documents.
:rtype: Any
"""
inputs = self._tokenizer( inputs = self._tokenizer(
docs, return_tensors="pt", padding=True, truncation=True docs, return_tensors="pt", padding=True, truncation=True
).to(self.device) ).to(self.device)
...@@ -112,6 +169,13 @@ class CLIPEncoder(DenseEncoder): ...@@ -112,6 +169,13 @@ class CLIPEncoder(DenseEncoder):
return embeds return embeds
def _encode_image(self, images: List[Any]) -> Any: def _encode_image(self, images: List[Any]) -> Any:
"""Encode a list of image documents.
:param images: The images to encode.
:type images: List[Any]
:returns: The embeddings for the images.
:rtype: Any
"""
rgb_images = [self._ensure_rgb(img) for img in images] rgb_images = [self._ensure_rgb(img) for img in images]
inputs = self._processor(text=None, images=rgb_images, return_tensors="pt")[ inputs = self._processor(text=None, images=rgb_images, return_tensors="pt")[
"pixel_values" "pixel_values"
...@@ -122,6 +186,13 @@ class CLIPEncoder(DenseEncoder): ...@@ -122,6 +186,13 @@ class CLIPEncoder(DenseEncoder):
return embeds return embeds
def _ensure_rgb(self, img: Any): def _ensure_rgb(self, img: Any):
"""Ensure the image is in RGB format.
:param img: The image to ensure is in RGB format.
:type img: Any
:returns: The image in RGB format.
:rtype: Any
"""
rgbimg = self._Image.new("RGB", img.size) rgbimg = self._Image.new("RGB", img.size)
rgbimg.paste(img) rgbimg.paste(img)
return rgbimg return rgbimg
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment