From bcaa22ec5dc3e2c6b3501fb607e954dd8c112bb7 Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Thu, 13 Feb 2025 12:00:17 +0400 Subject: [PATCH] feat: further docstrings and cleanup --- semantic_router/encoders/cohere.py | 27 +++++- semantic_router/encoders/fastembed.py | 26 ++++- semantic_router/encoders/google.py | 101 +++++++++---------- semantic_router/encoders/huggingface.py | 118 +++++++++++++++-------- semantic_router/encoders/mistral.py | 25 ++++- semantic_router/encoders/openai.py | 37 +++++++ semantic_router/encoders/vit.py | 44 +++++++++ semantic_router/encoders/zure.py | 39 +++++++- semantic_router/linear.py | 18 ++-- semantic_router/llms/base.py | 78 ++++++++++++++- semantic_router/llms/cohere.py | 27 ++++++ semantic_router/llms/llamacpp.py | 37 +++++++ semantic_router/llms/mistral.py | 28 +++++- semantic_router/llms/ollama.py | 27 ++++++ semantic_router/llms/openai.py | 89 ++++++++++++++++- semantic_router/llms/openrouter.py | 22 +++++ semantic_router/llms/zure.py | 24 +++++ semantic_router/route.py | 87 ++++++++++++++--- semantic_router/schema.py | 123 ++++++++++++++++++++++-- semantic_router/utils/defaults.py | 3 + semantic_router/utils/function_call.py | 75 +++++++++++++++ semantic_router/utils/llm.py | 65 ------------- semantic_router/utils/logger.py | 7 ++ 23 files changed, 926 insertions(+), 201 deletions(-) delete mode 100644 semantic_router/utils/llm.py diff --git a/semantic_router/encoders/cohere.py b/semantic_router/encoders/cohere.py index e919bae1..a021e662 100644 --- a/semantic_router/encoders/cohere.py +++ b/semantic_router/encoders/cohere.py @@ -8,6 +8,9 @@ from semantic_router.utils.defaults import EncoderDefault class CohereEncoder(DenseEncoder): + """Dense encoder that uses Cohere API to embed documents. Supports text only. Requires + a Cohere API key from https://dashboard.cohere.com/api-keys. + """ _client: Any = PrivateAttr() _embed_type: Any = PrivateAttr() type: str = "cohere" @@ -20,6 +23,18 @@ class CohereEncoder(DenseEncoder): score_threshold: float = 0.3, input_type: Optional[str] = "search_query", ): + """Initialize the Cohere encoder. + + :param name: The name of the embedding model to use. + :type name: str + :param cohere_api_key: The API key for the Cohere client, can also + be set via the COHERE_API_KEY environment variable. + :type cohere_api_key: str + :param score_threshold: The threshold for the score of the embedding. + :type score_threshold: float + :param input_type: The type of input to embed. + :type input_type: str + """ if name is None: name = EncoderDefault.COHERE.value["embedding_model"] super().__init__( @@ -34,9 +49,10 @@ class CohereEncoder(DenseEncoder): """Initializes the Cohere client. :param cohere_api_key: The API key for the Cohere client, can also - be set via the COHERE_API_KEY environment variable. - + be set via the COHERE_API_KEY environment variable. + :type cohere_api_key: str :return: An instance of the Cohere client. + :rtype: cohere.Client """ try: import cohere @@ -61,6 +77,13 @@ class CohereEncoder(DenseEncoder): return client def __call__(self, docs: List[str]) -> List[List[float]]: + """Embed a list of documents. Supports text only. + + :param docs: The documents to embed. + :type docs: List[str] + :return: The vector embeddings of the documents. 
+ :rtype: List[List[float]] + """ if self._client is None: raise ValueError("Cohere client is not initialized.") try: diff --git a/semantic_router/encoders/fastembed.py b/semantic_router/encoders/fastembed.py index 2c997795..db84feb0 100644 --- a/semantic_router/encoders/fastembed.py +++ b/semantic_router/encoders/fastembed.py @@ -7,6 +7,14 @@ from semantic_router.encoders import DenseEncoder class FastEmbedEncoder(DenseEncoder): + """Dense encoder that uses local FastEmbed to embed documents. Supports text only. + Requires the fastembed package which can be installed with `pip install 'semantic-router[fastembed]'` + + :param name: The name of the embedding model to use. + :param max_length: The maximum length of the input text. + :param cache_dir: The directory to cache the embedding model. + :param threads: The number of threads to use for the embedding. + """ type: str = "fastembed" name: str = "BAAI/bge-small-en-v1.5" max_length: int = 512 @@ -16,11 +24,19 @@ class FastEmbedEncoder(DenseEncoder): def __init__( self, score_threshold: float = 0.5, **data - ): # TODO default score_threshold not thoroughly tested, should optimize + ): + """Initialize the FastEmbed encoder. + + :param score_threshold: The threshold for the score of the embedding. + :type score_threshold: float + """ + # TODO default score_threshold not thoroughly tested, should optimize super().__init__(score_threshold=score_threshold, **data) self._client = self._initialize_client() def _initialize_client(self): + """Initialize the FastEmbed library. Requires the fastembed package. + """ try: from fastembed import TextEmbedding except ImportError: @@ -43,6 +59,14 @@ class FastEmbedEncoder(DenseEncoder): return embedding def __call__(self, docs: List[str]) -> List[List[float]]: + """Embed a list of documents. Supports text only. + + :param docs: The documents to embed. + :type docs: List[str] + :raise ValueError: If the embedding fails. + :return: The vector embeddings of the documents. + :rtype: List[List[float]] + """ try: embeds: List[np.ndarray] = list(self._client.embed(docs)) embeddings: List[List[float]] = [e.tolist() for e in embeds] diff --git a/semantic_router/encoders/google.py b/semantic_router/encoders/google.py index 5d50a0e1..5284a747 100644 --- a/semantic_router/encoders/google.py +++ b/semantic_router/encoders/google.py @@ -1,21 +1,3 @@ -""" -This module provides the GoogleEncoder class for generating embeddings using Google's AI Platform. - -The GoogleEncoder class is a subclass of DenseEncoder and utilizes the TextEmbeddingModel from the -Google AI Platform to generate embeddings for given documents. It requires a Google Cloud project ID -and supports customization of the pre-trained model, score threshold, location, and API endpoint. - -Example usage: - - from semantic_router.encoders.google_encoder import GoogleEncoder - - encoder = GoogleEncoder(project_id="your-project-id") - embeddings = encoder(["document1", "document2"]) - -Classes: - GoogleEncoder: A class for generating embeddings using Google's AI Platform. -""" - import os from typing import Any, List, Optional @@ -26,6 +8,19 @@ from semantic_router.utils.defaults import EncoderDefault class GoogleEncoder(DenseEncoder): """GoogleEncoder class for generating embeddings using Google's AI Platform. + The GoogleEncoder class is a subclass of DenseEncoder and utilizes the TextEmbeddingModel from the + Google AI Platform to generate embeddings for given documents. 
It requires a Google Cloud project ID
+    and supports customization of the pre-trained model, score threshold, location, and API endpoint.
+
+    Example usage:
+
+    ```python
+    from semantic_router.encoders.google import GoogleEncoder
+
+    encoder = GoogleEncoder(project_id="your-project-id")
+    embeddings = encoder(["document1", "document2"])
+    ```
+
     Attributes:
         client: An instance of the TextEmbeddingModel client.
         type: The type of the encoder, which is "google".
@@ -44,23 +39,25 @@ class GoogleEncoder(DenseEncoder):
     ):
         """Initializes the GoogleEncoder.

-        Args:
-            model_name: The name of the pre-trained model to use for embedding.
-                If not provided, the default model specified in EncoderDefault will
-                be used.
-            score_threshold: The threshold for similarity scores.
-            project_id: The Google Cloud project ID.
-                If not provided, it will be retrieved from the GOOGLE_PROJECT_ID
-                environment variable.
-            location: The location of the AI Platform resources.
-                If not provided, it will be retrieved from the GOOGLE_LOCATION
-                environment variable, defaulting to "us-central1".
-            api_endpoint: The API endpoint for the AI Platform.
-                If not provided, it will be retrieved from the GOOGLE_API_ENDPOINT
-                environment variable.
-
-        Raises:
-            ValueError: If the Google Project ID is not provided or if the AI Platform
+        :param name: The name of the pre-trained model to use for embedding.
+            If not provided, the default model specified in EncoderDefault will
+            be used.
+        :type name: str
+        :param score_threshold: The threshold for similarity scores.
+        :type score_threshold: float
+        :param project_id: The Google Cloud project ID.
+            If not provided, it will be retrieved from the GOOGLE_PROJECT_ID
+            environment variable.
+        :type project_id: str
+        :param location: The location of the AI Platform resources.
+            If not provided, it will be retrieved from the GOOGLE_LOCATION
+            environment variable, defaulting to "us-central1".
+        :type location: str
+        :param api_endpoint: The API endpoint for the AI Platform.
+            If not provided, it will be retrieved from the GOOGLE_API_ENDPOINT
+            environment variable.
+        :type api_endpoint: str
+        :raise ValueError: If the Google Project ID is not provided or if the AI Platform
             client fails to initialize.
         """
         if name is None:
@@ -73,18 +70,17 @@ class GoogleEncoder(DenseEncoder):
     def _initialize_client(self, project_id, location, api_endpoint):
         """Initializes the Google AI Platform client.

-        Args:
-            project_id: The Google Cloud project ID.
-            location: The location of the AI Platform resources.
-            api_endpoint: The API endpoint for the AI Platform.
-
-        Returns:
-            An instance of the TextEmbeddingModel client.
-
-        Raises:
-            ImportError: If the required Google Cloud or Vertex AI libraries are not
+        :param project_id: The Google Cloud project ID.
+        :type project_id: str
+        :param location: The location of the AI Platform resources.
+        :type location: str
+        :param api_endpoint: The API endpoint for the AI Platform.
+        :type api_endpoint: str
+        :return: An instance of the TextEmbeddingModel client.
+        :rtype: TextEmbeddingModel
+        :raise ImportError: If the required Google Cloud or Vertex AI libraries are not
             installed.
-            ValueError: If the Google Project ID is not provided or if the AI Platform
+        :raise ValueError: If the Google Project ID is not provided or if the AI Platform
             client fails to initialize.
         """
         try:
@@ -119,15 +115,12 @@ class GoogleEncoder(DenseEncoder):
     def __call__(self, docs: List[str]) -> List[List[float]]:
         """Generates embeddings for the given documents.
- Args: - docs: A list of strings representing the documents to embed. - - Returns: - A list of lists, where each inner list contains the embedding values for a + :param docs: A list of strings representing the documents to embed. + :type docs: List[str] + :return: A list of lists, where each inner list contains the embedding values for a document. - - Raises: - ValueError: If the Google AI Platform client is not initialized or if the + :rtype: List[List[float]] + :raise ValueError: If the Google AI Platform client is not initialized or if the API call fails. """ if self.client is None: diff --git a/semantic_router/encoders/huggingface.py b/semantic_router/encoders/huggingface.py index b659ebac..6e4dbc5e 100644 --- a/semantic_router/encoders/huggingface.py +++ b/semantic_router/encoders/huggingface.py @@ -31,7 +31,26 @@ from semantic_router.encoders import DenseEncoder from semantic_router.utils.logger import logger +# TODO: this should support local models, and we should have another class for remote +# inference endpoint models + class HuggingFaceEncoder(DenseEncoder): + """HuggingFace encoder class for local embedding models. Models can be trained and + loaded from private repositories, or from the Huggingface Hub. The class supports + customization of the score threshold for filtering or processing the embeddings. + + Example usage: + + ```python + from semantic_router.encoders import HuggingFaceEncoder + + encoder = HuggingFaceEncoder( + name="sentence-transformers/all-MiniLM-L6-v2", + device="cuda" + ) + embeddings = encoder(["document1", "document2"]) + ``` + """ name: str = "sentence-transformers/all-MiniLM-L6-v2" type: str = "huggingface" tokenizer_kwargs: Dict = {} @@ -92,6 +111,12 @@ class HuggingFaceEncoder(DenseEncoder): normalize_embeddings: bool = True, pooling_strategy: str = "mean", ) -> List[List[float]]: + """Encode a list of documents into embeddings using the local Hugging Face model. + + :param docs: A list of documents to encode. + :type docs: List[str] + :param batch_size: The batch size for encoding. + """ all_embeddings = [] for i in range(0, len(docs), batch_size): batch_docs = docs[i : i + batch_size] @@ -124,6 +149,12 @@ class HuggingFaceEncoder(DenseEncoder): return all_embeddings def _mean_pooling(self, model_output, attention_mask): + """Perform mean pooling on the token embeddings. + + :param model_output: The output of the model. + :type model_output: torch.Tensor + :param attention_mask: The attention mask. + """ token_embeddings = model_output[0] input_mask_expanded = ( attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() @@ -133,6 +164,12 @@ class HuggingFaceEncoder(DenseEncoder): ) / self._torch.clamp(input_mask_expanded.sum(1), min=1e-9) def _max_pooling(self, model_output, attention_mask): + """Perform max pooling on the token embeddings. + + :param model_output: The output of the model. + :type model_output: torch.Tensor + :param attention_mask: The attention mask. + """ token_embeddings = model_output[0] input_mask_expanded = ( attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() @@ -142,13 +179,24 @@ class HuggingFaceEncoder(DenseEncoder): class HFEndpointEncoder(DenseEncoder): - """ - A class to encode documents using a Hugging Face transformer model endpoint. + """HFEndpointEncoder class to embeddings models using Huggingface's inference endpoints. + + The HFEndpointEncoder class is a subclass of DenseEncoder and utilizes a specified + Huggingface endpoint to generate embeddings for given documents. 
It requires the URL + of the Huggingface API endpoint and an API key for authentication. The class supports + customization of the score threshold for filtering or processing the embeddings. - Attributes: - huggingface_url (str): The URL of the Hugging Face API endpoint. - huggingface_api_key (str): The API key for authenticating with the Hugging Face API. - score_threshold (float): A threshold value used for filtering or processing the embeddings. + Example usage: + + ```python + from semantic_router.encoders import HFEndpointEncoder + + encoder = HFEndpointEncoder( + huggingface_url="https://api-inference.huggingface.co/models/BAAI/bge-large-en-v1.5", + huggingface_api_key="your-hugging-face-api-key" + ) + embeddings = encoder(["document1", "document2"]) + ``` """ name: str = "hugging_face_custom_endpoint" @@ -162,21 +210,17 @@ class HFEndpointEncoder(DenseEncoder): huggingface_api_key: Optional[str] = None, score_threshold: float = 0.8, ): - """ - Initializes the HFEndpointEncoder with the specified parameters. - - Args: - name (str, optional): The name of the encoder. Defaults to - "hugging_face_custom_endpoint". - huggingface_url (str, optional): The URL of the Hugging Face API endpoint. - Cannot be None. - huggingface_api_key (str, optional): The API key for the Hugging Face API. - Cannot be None. - score_threshold (float, optional): A threshold for processing the embeddings. - Defaults to 0.8. - - Raises: - ValueError: If either `huggingface_url` or `huggingface_api_key` is None. + """Initializes the HFEndpointEncoder with the specified parameters. + + :param name: The name of the encoder. + :type name: str + :param huggingface_url: The URL of the Hugging Face API endpoint. + :type huggingface_url: str + :param huggingface_api_key: The API key for the Hugging Face API. + :type huggingface_api_key: str + :param score_threshold: A threshold for processing the embeddings. + :type score_threshold: float + :raise ValueError: If either `huggingface_url` or `huggingface_api_key` is None. """ huggingface_url = huggingface_url or os.getenv("HF_API_URL") huggingface_api_key = huggingface_api_key or os.getenv("HF_API_KEY") @@ -201,17 +245,13 @@ class HFEndpointEncoder(DenseEncoder): ) from e def __call__(self, docs: List[str]) -> List[List[float]]: - """ - Encodes a list of documents into embeddings using the Hugging Face API. - - Args: - docs (List[str]): A list of documents to encode. - - Returns: - List[List[float]]: A list of embeddings for the given documents. + """Encodes a list of documents into embeddings using the Hugging Face API. - Raises: - ValueError: If no embeddings are returned for a document. + :param docs: A list of documents to encode. + :type docs: List[str] + :return: A list of embeddings for the given documents. + :rtype: List[List[float]] + :raise ValueError: If no embeddings are returned for a document. """ embeddings = [] for d in docs: @@ -228,17 +268,13 @@ class HFEndpointEncoder(DenseEncoder): return embeddings def query(self, payload, max_retries=3, retry_interval=5): - """ - Sends a query to the Hugging Face API and returns the response. - - Args: - payload (dict): The payload to send in the request. - - Returns: - dict: The response from the Hugging Face API. + """Sends a query to the Hugging Face API and returns the response. - Raises: - ValueError: If the query fails or the response status is not 200. + :param payload: The payload to send in the request. + :type payload: dict + :return: The response from the Hugging Face API. 
+ :rtype: dict + :raise ValueError: If the query fails or the response status is not 200. """ headers = { "Accept": "application/json", diff --git a/semantic_router/encoders/mistral.py b/semantic_router/encoders/mistral.py index 6c3a2f5e..f1267a7b 100644 --- a/semantic_router/encoders/mistral.py +++ b/semantic_router/encoders/mistral.py @@ -11,7 +11,8 @@ from semantic_router.utils.defaults import EncoderDefault class MistralEncoder(DenseEncoder): - """Class to encode text using MistralAI""" + """Class to encode text using MistralAI. Requires a MistralAI API key from + https://console.mistral.ai/api-keys/""" _client: Any = PrivateAttr() _mistralai: Any = PrivateAttr() @@ -23,12 +24,27 @@ class MistralEncoder(DenseEncoder): mistralai_api_key: Optional[str] = None, score_threshold: float = 0.82, ): + """Initialize the MistralEncoder. + + :param name: The name of the embedding model to use. + :type name: str + :param mistralai_api_key: The MistralAI API key. + :type mistralai_api_key: str + :param score_threshold: The score threshold for the embeddings. + """ if name is None: name = EncoderDefault.MISTRAL.value["embedding_model"] super().__init__(name=name, score_threshold=score_threshold) self._client, self._mistralai = self._initialize_client(mistralai_api_key) def _initialize_client(self, api_key): + """Initialize the MistralAI client. + + :param api_key: The MistralAI API key. + :type api_key: str + :return: The MistralAI client. + :rtype: MistralClient + """ try: import mistralai from mistralai.client import MistralClient @@ -49,6 +65,13 @@ class MistralEncoder(DenseEncoder): return client, mistralai def __call__(self, docs: List[str]) -> List[List[float]]: + """Encode a list of documents into embeddings using MistralAI. + + :param docs: The documents to encode. + :type docs: List[str] + :return: The embeddings for the documents. + :rtype: List[List[float]] + """ if self._client is None: raise ValueError("Mistral client not initialized") embeds = None diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py index 4865cc29..2d808ddf 100644 --- a/semantic_router/encoders/openai.py +++ b/semantic_router/encoders/openai.py @@ -35,6 +35,12 @@ model_configs = { class OpenAIEncoder(DenseEncoder): + """OpenAI encoder class for generating embeddings using OpenAI API. + + The OpenAIEncoder class is a subclass of DenseEncoder and utilizes the OpenAI API + to generate embeddings for given documents. It requires an OpenAI API key and + supports customization of the score threshold for filtering or processing the embeddings. + """ _client: Optional[openai.Client] = PrivateAttr(default=None) _async_client: Optional[openai.AsyncClient] = PrivateAttr(default=None) dimensions: Union[int, NotGiven] = NotGiven() @@ -53,6 +59,23 @@ class OpenAIEncoder(DenseEncoder): dimensions: Union[int, NotGiven] = NotGiven(), max_retries: int = 3, ): + """Initialize the OpenAIEncoder. + + :param name: The name of the embedding model to use. + :type name: str + :param openai_base_url: The base URL for the OpenAI API. + :type openai_base_url: str + :param openai_api_key: The OpenAI API key. + :type openai_api_key: str + :param openai_org_id: The OpenAI organization ID. + :type openai_org_id: str + :param score_threshold: The score threshold for the embeddings. + :type score_threshold: float + :param dimensions: The dimensions of the embeddings. + :type dimensions: int + :param max_retries: The maximum number of retries for the OpenAI API call. 
+ :type max_retries: int + """ if name is None: name = EncoderDefault.OPENAI.value["embedding_model"] if score_threshold is None and name in model_configs: @@ -146,6 +169,13 @@ class OpenAIEncoder(DenseEncoder): return embeddings def _truncate(self, text: str) -> str: + """Truncate a document to the token limit. + + :param text: The document to truncate. + :type text: str + :return: The truncated document. + :rtype: str + """ # we use encode_ordinary as faster equivalent to encode(text, disallowed_special=()) tokens = self._token_encoder.encode_ordinary(text) if len(tokens) > self.token_limit: @@ -159,6 +189,13 @@ class OpenAIEncoder(DenseEncoder): return text async def acall(self, docs: List[str], truncate: bool = True) -> List[List[float]]: + """Encode a list of text documents into embeddings using OpenAI API asynchronously. + + :param docs: List of text documents to encode. + :param truncate: Whether to truncate the documents to token limit. If + False and a document exceeds the token limit, an error will be + raised. + :return: List of embeddings for each document.""" if self._async_client is None: raise ValueError("OpenAI async client is not initialized.") embeds = None diff --git a/semantic_router/encoders/vit.py b/semantic_router/encoders/vit.py index dec768e4..4be220a9 100644 --- a/semantic_router/encoders/vit.py +++ b/semantic_router/encoders/vit.py @@ -6,6 +6,12 @@ from semantic_router.encoders import DenseEncoder class VitEncoder(DenseEncoder): + """Encoder for Vision Transformer models. + + This class provides functionality to encode images using a Vision Transformer + model via Hugging Face. It supports various image processing and model initialization + options. + """ name: str = "google/vit-base-patch16-224" type: str = "huggingface" processor_kwargs: Dict = {} @@ -18,12 +24,22 @@ class VitEncoder(DenseEncoder): _Image: Any = PrivateAttr() def __init__(self, **data): + """Initialize the VitEncoder. + + :param **data: Additional keyword arguments for the encoder. + :type **data: dict + """ if data.get("score_threshold") is None: data["score_threshold"] = 0.5 super().__init__(**data) self._processor, self._model = self._initialize_hf_model() def _initialize_hf_model(self): + """Initialize the Hugging Face model. + + :return: The processor and model. + :rtype: tuple + """ try: from transformers import ViTImageProcessor, ViTModel except ImportError: @@ -68,6 +84,11 @@ class VitEncoder(DenseEncoder): return processor, model def _get_device(self) -> str: + """Get the device to use for the model. + + :return: The device to use for the model. + :rtype: str + """ if self.device: device = self.device elif self._torch.cuda.is_available(): @@ -79,12 +100,26 @@ class VitEncoder(DenseEncoder): return device def _process_images(self, images: List[Any]): + """Process the images for the model. + + :param images: The images to process. + :type images: List[Any] + :return: The processed images. + :rtype: Any + """ rgb_images = [self._ensure_rgb(img) for img in images] processed_images = self._processor(images=rgb_images, return_tensors="pt") processed_images = processed_images.to(self.device) return processed_images def _ensure_rgb(self, img: Any): + """Ensure the image is in RGB format. + + :param img: The image to ensure is in RGB format. + :type img: Any + :return: The image in RGB format. 
+ :rtype: Any + """ rgbimg = self._Image.new("RGB", img.size) rgbimg.paste(img) return rgbimg @@ -94,6 +129,15 @@ class VitEncoder(DenseEncoder): imgs: List[Any], batch_size: int = 32, ) -> List[List[float]]: + """Encode a list of images into embeddings using the Vision Transformer model. + + :param imgs: The images to encode. + :type imgs: List[Any] + :param batch_size: The batch size for encoding. + :type batch_size: int + :return: The embeddings for the images. + :rtype: List[List[float]] + """ all_embeddings = [] for i in range(0, len(imgs), batch_size): batch_imgs = imgs[i : i + batch_size] diff --git a/semantic_router/encoders/zure.py b/semantic_router/encoders/zure.py index f16213dc..3822a743 100644 --- a/semantic_router/encoders/zure.py +++ b/semantic_router/encoders/zure.py @@ -14,6 +14,11 @@ from semantic_router.utils.logger import logger class AzureOpenAIEncoder(DenseEncoder): + """Encoder for Azure OpenAI API. + + This class provides functionality to encode text documents using the Azure OpenAI API. + It supports customization of the score threshold for filtering or processing the embeddings. + """ client: Optional[openai.AzureOpenAI] = None async_client: Optional[openai.AsyncAzureOpenAI] = None dimensions: Union[int, NotGiven] = NotGiven() @@ -36,6 +41,25 @@ class AzureOpenAIEncoder(DenseEncoder): dimensions: Union[int, NotGiven] = NotGiven(), max_retries: int = 3, ): + """Initialize the AzureOpenAIEncoder. + + :param api_key: The API key for the Azure OpenAI API. + :type api_key: str + :param deployment_name: The name of the deployment to use. + :type deployment_name: str + :param azure_endpoint: The endpoint for the Azure OpenAI API. + :type azure_endpoint: str + :param api_version: The version of the API to use. + :type api_version: str + :param model: The model to use. + :type model: str + :param score_threshold: The score threshold for the embeddings. + :type score_threshold: float + :param dimensions: The dimensions of the embeddings. + :type dimensions: int + :param max_retries: The maximum number of retries for the API call. + :type max_retries: int + """ name = deployment_name if name is None: name = EncoderDefault.AZURE.value["embedding_model"] @@ -98,6 +122,13 @@ class AzureOpenAIEncoder(DenseEncoder): ) from e def __call__(self, docs: List[str]) -> List[List[float]]: + """Encode a list of documents into embeddings using the Azure OpenAI API. + + :param docs: The documents to encode. + :type docs: List[str] + :return: The embeddings for the documents. + :rtype: List[List[float]] + """ if self.client is None: raise ValueError("Azure OpenAI client is not initialized.") embeds = None @@ -136,10 +167,16 @@ class AzureOpenAIEncoder(DenseEncoder): return embeddings async def acall(self, docs: List[str]) -> List[List[float]]: + """Encode a list of documents into embeddings using the Azure OpenAI API asynchronously. + + :param docs: The documents to encode. + :type docs: List[str] + :return: The embeddings for the documents. 
+ :rtype: List[List[float]] + """ if self.async_client is None: raise ValueError("Azure OpenAI async client is not initialized.") embeds = None - # Exponential backoff for j in range(self.max_retries + 1): try: diff --git a/semantic_router/linear.py b/semantic_router/linear.py index 1c13262f..a308dc5a 100644 --- a/semantic_router/linear.py +++ b/semantic_router/linear.py @@ -7,12 +7,10 @@ from numpy.linalg import norm def similarity_matrix(xq: np.ndarray, index: np.ndarray) -> np.ndarray: """Compute the similarity scores between a query vector and a set of vectors. - Args: - xq: A query vector (1d ndarray) - index: A set of vectors. - - Returns: - The similarity between the query vector and the set of vectors. + :param xq: A query vector (1d ndarray) + :param index: A set of vectors. + :return: The similarity between the query vector and the set of vectors. + :rtype: np.ndarray """ index_norm = norm(index, axis=1) @@ -22,7 +20,13 @@ def similarity_matrix(xq: np.ndarray, index: np.ndarray) -> np.ndarray: def top_scores(sim: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, np.ndarray]: - # get indices of top_k records + """Get the top scores and indices from a similarity matrix. + + :param sim: A similarity matrix. + :param top_k: The number of top scores to get. + :return: The top scores and indices. + :rtype: Tuple[np.ndarray, np.ndarray] + """ top_k = min(top_k, sim.shape[0]) idx = np.argpartition(sim, -top_k)[-top_k:] scores = sim[idx] diff --git a/semantic_router/llms/base.py b/semantic_router/llms/base.py index 1ebc7aa5..611f1ac6 100644 --- a/semantic_router/llms/base.py +++ b/semantic_router/llms/base.py @@ -8,6 +8,11 @@ from semantic_router.utils.logger import logger class BaseLLM(BaseModel): + """Base class for LLMs typically used by dynamic routes. + + This class provides a base implementation for LLMs. It defines the common + configuration and methods for all LLM classes. + """ name: str temperature: Optional[float] = 0.0 max_tokens: Optional[int] = None @@ -16,15 +21,39 @@ class BaseLLM(BaseModel): arbitrary_types_allowed = True def __init__(self, name: str, **kwargs): + """Initialize the BaseLLM. + + :param name: The name of the LLM. + :type name: str + :param **kwargs: Additional keyword arguments for the LLM. + :type **kwargs: dict + """ super().__init__(name=name, **kwargs) def __call__(self, messages: List[Message]) -> Optional[str]: + """Call the LLM. + + Must be implemented by subclasses. + + :param messages: The messages to pass to the LLM. + :type messages: List[Message] + :return: The response from the LLM. + :rtype: Optional[str] + """ raise NotImplementedError("Subclasses must implement this method") def _check_for_mandatory_inputs( self, inputs: dict[str, Any], mandatory_params: List[str] ) -> bool: - """Check for mandatory parameters in inputs""" + """Check for mandatory parameters in inputs. + + :param inputs: The inputs to check for mandatory parameters. + :type inputs: dict[str, Any] + :param mandatory_params: The mandatory parameters to check for. + :type mandatory_params: List[str] + :return: True if all mandatory parameters are present, False otherwise. + :rtype: bool + """ for name in mandatory_params: if name not in inputs: logger.error(f"Mandatory input {name} missing from query") @@ -34,7 +63,15 @@ class BaseLLM(BaseModel): def _check_for_extra_inputs( self, inputs: dict[str, Any], all_params: List[str] ) -> bool: - """Check for extra parameters not defined in the signature""" + """Check for extra parameters not defined in the signature. 
+        :param inputs: The inputs to check for extra parameters.
+        :type inputs: dict[str, Any]
+        :param all_params: All parameters accepted by the function signature.
+        :type all_params: List[str]
+        :return: True if no extra parameters are present, False otherwise.
+        :rtype: bool
+        """
         input_keys = set(inputs.keys())
         param_keys = set(all_params)
         if not input_keys.issubset(param_keys):
@@ -49,7 +86,15 @@ class BaseLLM(BaseModel):
         self, inputs: List[Dict[str, Any]], function_schemas: List[Dict[str, Any]]
     ) -> bool:
         """Determine if the functions chosen by the LLM exist within the function_schemas,
-        and if the input arguments are valid for those functions."""
+        and if the input arguments are valid for those functions.
+
+        :param inputs: The inputs to check for validity.
+        :type inputs: List[Dict[str, Any]]
+        :param function_schemas: The function schemas to check against.
+        :type function_schemas: List[Dict[str, Any]]
+        :return: True if the inputs are valid, False otherwise.
+        :rtype: bool
+        """
         try:
             # Currently only supporting single functions for most LLMs in Dynamic Routes.
             if len(inputs) != 1:
@@ -72,7 +117,15 @@ class BaseLLM(BaseModel):
     def _validate_single_function_inputs(
         self, inputs: Dict[str, Any], function_schema: Dict[str, Any]
     ) -> bool:
-        """Validate the extracted inputs against the function schema"""
+        """Validate the extracted inputs against the function schema.
+
+        :param inputs: The inputs to validate.
+        :type inputs: Dict[str, Any]
+        :param function_schema: The function schema to validate against.
+        :type function_schema: Dict[str, Any]
+        :return: True if the inputs are valid, False otherwise.
+        :rtype: bool
+        """
         try:
             # Extract parameter names and determine if they are optional
             signature = function_schema["signature"]
@@ -107,7 +160,13 @@ class BaseLLM(BaseModel):
         return False

     def _extract_parameter_info(self, signature: str) -> tuple[List[str], List[str]]:
-        """Extract parameter names and types from the function signature."""
+        """Extract parameter names and types from the function signature.
+
+        :param signature: The function signature to extract parameter names and types from.
+        :type signature: str
+        :return: A tuple of parameter names and types.
+        :rtype: tuple[List[str], List[str]]
+        """
         param_info = [param.strip() for param in signature[1:-1].split(",")]
         param_names = [info.split(":")[0].strip() for info in param_info]
         param_types = [
@@ -118,6 +177,15 @@ class BaseLLM(BaseModel):
     def extract_function_inputs(
         self, query: str, function_schemas: List[Dict[str, Any]]
     ) -> List[Dict[str, Any]]:
+        """Extract the function inputs from the query.
+
+        :param query: The query to extract the function inputs from.
+        :type query: str
+        :param function_schemas: The function schemas to extract the function inputs from.
+        :type function_schemas: List[Dict[str, Any]]
+        :return: The function inputs.
+        :rtype: List[Dict[str, Any]]
+        """
         logger.info("Extracting function input...")
         prompt = f"""
diff --git a/semantic_router/llms/cohere.py b/semantic_router/llms/cohere.py
index d37f979d..b258ac8f 100644
--- a/semantic_router/llms/cohere.py
+++ b/semantic_router/llms/cohere.py
@@ -8,6 +8,11 @@ from semantic_router.schema import Message
 class CohereLLM(BaseLLM):
+    """LLM for Cohere. Requires a Cohere API key from https://dashboard.cohere.com/api-keys.
+
+    This class provides functionality to interact with the Cohere API for generating text responses.
+    It extends the BaseLLM class and implements the __call__ method.
+ """ _client: Any = PrivateAttr() def __init__( @@ -15,12 +20,27 @@ class CohereLLM(BaseLLM): name: Optional[str] = None, cohere_api_key: Optional[str] = None, ): + """Initialize the CohereLLM. + + :param name: The name of the Cohere model to use can also be set via the + COHERE_CHAT_MODEL_NAME environment variable. + :type name: Optional[str] + :param cohere_api_key: The API key for the Cohere client. Can also be set via the + COHERE_API_KEY environment variable. + :type cohere_api_key: Optional[str] + """ if name is None: name = os.getenv("COHERE_CHAT_MODEL_NAME", "command") super().__init__(name=name) self._client = self._initialize_client(cohere_api_key) def _initialize_client(self, cohere_api_key: Optional[str] = None): + """Initialize the Cohere client. + + :param cohere_api_key: The API key for the Cohere client. Can also be set via the + COHERE_API_KEY environment variable. + :type cohere_api_key: Optional[str] + """ try: import cohere except ImportError: @@ -41,6 +61,13 @@ class CohereLLM(BaseLLM): return client def __call__(self, messages: List[Message]) -> str: + """Call the Cohere client. + + :param messages: The messages to pass to the Cohere client. + :type messages: List[Message] + :return: The response from the Cohere client. + :rtype: str + """ if self._client is None: raise ValueError("Cohere client is not initialized.") try: diff --git a/semantic_router/llms/llamacpp.py b/semantic_router/llms/llamacpp.py index dda11bea..d60f85dd 100644 --- a/semantic_router/llms/llamacpp.py +++ b/semantic_router/llms/llamacpp.py @@ -10,6 +10,9 @@ from semantic_router.utils.logger import logger class LlamaCppLLM(BaseLLM): + """LLM for LlamaCPP. Enables fully local LLM use, helpful for local implementation of + dynamic routes. + """ llm: Any grammar: Optional[Any] = None _llama_cpp: Any = PrivateAttr() @@ -22,6 +25,19 @@ class LlamaCppLLM(BaseLLM): max_tokens: Optional[int] = 200, grammar: Optional[Any] = None, ): + """Initialize the LlamaCPPLLM. + + :param llm: The LLM to use. + :type llm: Any + :param name: The name of the LLM. + :type name: str + :param temperature: The temperature of the LLM. + :type temperature: float + :param max_tokens: The maximum number of tokens to generate. + :type max_tokens: Optional[int] + :param grammar: The grammar to use. + :type grammar: Optional[Any] + """ super().__init__( name=name, llm=llm, @@ -48,6 +64,13 @@ class LlamaCppLLM(BaseLLM): self, messages: List[Message], ) -> str: + """Call the LlamaCPPLLM. + + :param messages: The messages to pass to the LlamaCPPLLM. + :type messages: List[Message] + :return: The response from the LlamaCPPLLM. + :rtype: str + """ try: completion = self.llm.create_chat_completion( messages=[m.to_llamacpp() for m in messages], @@ -68,6 +91,11 @@ class LlamaCppLLM(BaseLLM): @contextmanager def _grammar(self): + """Context manager for the grammar. + + :return: The grammar. + :rtype: Any + """ grammar_path = Path(__file__).parent.joinpath("grammars", "json.gbnf") assert grammar_path.exists(), f"{grammar_path}\ndoes not exist" try: @@ -79,6 +107,15 @@ class LlamaCppLLM(BaseLLM): def extract_function_inputs( self, query: str, function_schemas: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: + """Extract the function inputs from the query. + + :param query: The query to extract the function inputs from. + :type query: str + :param function_schemas: The function schemas to extract the function inputs from. + :type function_schemas: List[Dict[str, Any]] + :return: The function inputs. 
+ :rtype: List[Dict[str, Any]] + """ with self._grammar(): return super().extract_function_inputs( query=query, function_schemas=function_schemas diff --git a/semantic_router/llms/mistral.py b/semantic_router/llms/mistral.py index 370fba65..0d827a58 100644 --- a/semantic_router/llms/mistral.py +++ b/semantic_router/llms/mistral.py @@ -10,6 +10,8 @@ from semantic_router.utils.logger import logger class MistralAILLM(BaseLLM): + """LLM for MistralAI. Requires a MistralAI API key from https://console.mistral.ai/api-keys/ + """ _client: Any = PrivateAttr() _mistralai: Any = PrivateAttr() @@ -20,6 +22,17 @@ class MistralAILLM(BaseLLM): temperature: float = 0.01, max_tokens: int = 200, ): + """Initialize the MistralAILLM. + + :param name: The name of the MistralAI model to use. + :type name: Optional[str] + :param mistralai_api_key: The MistralAI API key. + :type mistralai_api_key: Optional[str] + :param temperature: The temperature of the LLM. + :type temperature: float + :param max_tokens: The maximum number of tokens to generate. + :type max_tokens: int + """ if name is None: name = EncoderDefault.MISTRAL.value["language_model"] super().__init__(name=name) @@ -28,6 +41,13 @@ class MistralAILLM(BaseLLM): self.max_tokens = max_tokens def _initialize_client(self, api_key): + """Initialize the MistralAI client. + + :param api_key: The MistralAI API key. + :type api_key: Optional[str] + :return: The MistralAI client. + :rtype: MistralClient + """ try: import mistralai from mistralai.client import MistralClient @@ -49,9 +69,15 @@ class MistralAILLM(BaseLLM): return client, mistralai def __call__(self, messages: List[Message]) -> str: + """Call the MistralAILLM. + + :param messages: The messages to pass to the MistralAILLM. + :type messages: List[Message] + :return: The response from the MistralAILLM. + :rtype: str + """ if self._client is None: raise ValueError("MistralAI client is not initialized.") - chat_messages = [ self._mistralai.models.chat_completion.ChatMessage( role=m.role, content=m.content diff --git a/semantic_router/llms/ollama.py b/semantic_router/llms/ollama.py index f6e9779e..5a29a360 100644 --- a/semantic_router/llms/ollama.py +++ b/semantic_router/llms/ollama.py @@ -8,6 +8,9 @@ from semantic_router.utils.logger import logger class OllamaLLM(BaseLLM): + """LLM for Ollama. Enables fully local LLM use, helpful for local implementation of + dynamic routes. + """ stream: bool = False def __init__( @@ -17,6 +20,17 @@ class OllamaLLM(BaseLLM): max_tokens: Optional[int] = 200, stream: bool = False, ): + """Initialize the OllamaLLM. + + :param name: The name of the Ollama model to use. + :type name: str + :param temperature: The temperature of the LLM. + :type temperature: float + :param max_tokens: The maximum number of tokens to generate. + :type max_tokens: Optional[int] + :param stream: Whether to stream the response. + :type stream: bool + """ super().__init__(name=name) self.temperature = temperature self.max_tokens = max_tokens @@ -30,6 +44,19 @@ class OllamaLLM(BaseLLM): max_tokens: Optional[int] = None, stream: Optional[bool] = None, ) -> str: + """Call the OllamaLLM. + + :param messages: The messages to pass to the OllamaLLM. + :type messages: List[Message] + :param temperature: The temperature of the LLM. + :type temperature: Optional[float] + :param name: The name of the Ollama model to use. + :type name: Optional[str] + :param max_tokens: The maximum number of tokens to generate. + :type max_tokens: Optional[int] + :param stream: Whether to stream the response. 
+ :type stream: Optional[bool] + """ # Use instance defaults if not overridden temperature = temperature if temperature is not None else self.temperature name = name if name is not None else self.name diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py index 3e991148..71cc3322 100644 --- a/semantic_router/llms/openai.py +++ b/semantic_router/llms/openai.py @@ -22,6 +22,8 @@ from semantic_router.utils.logger import logger class OpenAILLM(BaseLLM): + """LLM for OpenAI. Requires an OpenAI API key from https://platform.openai.com/api-keys. + """ _client: Optional[openai.OpenAI] = PrivateAttr(default=None) _async_client: Optional[openai.AsyncOpenAI] = PrivateAttr(default=None) @@ -32,6 +34,17 @@ class OpenAILLM(BaseLLM): temperature: float = 0.01, max_tokens: int = 200, ): + """Initialize the OpenAILLM. + + :param name: The name of the OpenAI model to use. + :type name: Optional[str] + :param openai_api_key: The OpenAI API key. + :type openai_api_key: Optional[str] + :param temperature: The temperature of the LLM. + :type temperature: float + :param max_tokens: The maximum number of tokens to generate. + :type max_tokens: int + """ if name is None: name = EncoderDefault.OPENAI.value["language_model"] super().__init__(name=name) @@ -51,6 +64,13 @@ class OpenAILLM(BaseLLM): def _extract_tool_calls_info( self, tool_calls: List[ChatCompletionMessageToolCall] ) -> List[Dict[str, Any]]: + """Extract the tool calls information from the tool calls. + + :param tool_calls: The tool calls to extract the information from. + :type tool_calls: List[ChatCompletionMessageToolCall] + :return: The tool calls information. + :rtype: List[Dict[str, Any]] + """ tool_calls_info = [] for tool_call in tool_calls: if tool_call.function.arguments is None: @@ -68,6 +88,13 @@ class OpenAILLM(BaseLLM): async def async_extract_tool_calls_info( self, tool_calls: List[ChatCompletionMessageToolCall] ) -> List[Dict[str, Any]]: + """Extract the tool calls information from the tool calls. + + :param tool_calls: The tool calls to extract the information from. + :type tool_calls: List[ChatCompletionMessageToolCall] + :return: The tool calls information. + :rtype: List[Dict[str, Any]] + """ tool_calls_info = [] for tool_call in tool_calls: if tool_call.function.arguments is None: @@ -87,6 +114,15 @@ class OpenAILLM(BaseLLM): messages: List[Message], function_schemas: Optional[List[Dict[str, Any]]] = None, ) -> str: + """Call the OpenAILLM. + + :param messages: The messages to pass to the OpenAILLM. + :type messages: List[Message] + :param function_schemas: The function schemas to pass to the OpenAILLM. + :type function_schemas: Optional[List[Dict[str, Any]]] + :return: The response from the OpenAILLM. + :rtype: str + """ if self._client is None: raise ValueError("OpenAI client is not initialized.") try: @@ -131,6 +167,15 @@ class OpenAILLM(BaseLLM): messages: List[Message], function_schemas: Optional[List[Dict[str, Any]]] = None, ) -> str: + """Call the OpenAILLM asynchronously. + + :param messages: The messages to pass to the OpenAILLM. + :type messages: List[Message] + :param function_schemas: The function schemas to pass to the OpenAILLM. + :type function_schemas: Optional[List[Dict[str, Any]]] + :return: The response from the OpenAILLM. 
+ :rtype: str + """ if self._async_client is None: raise ValueError("OpenAI async_client is not initialized.") try: @@ -173,6 +218,15 @@ class OpenAILLM(BaseLLM): def extract_function_inputs( self, query: str, function_schemas: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: + """Extract the function inputs from the query. + + :param query: The query to extract the function inputs from. + :type query: str + :param function_schemas: The function schemas to extract the function inputs from. + :type function_schemas: List[Dict[str, Any]] + :return: The function inputs. + :rtype: List[Dict[str, Any]] + """ system_prompt = "You are an intelligent AI. Given a command or request from the user, call the function to complete the request." messages = [ Message(role="system", content=system_prompt), @@ -190,6 +244,15 @@ class OpenAILLM(BaseLLM): async def async_extract_function_inputs( self, query: str, function_schemas: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: + """Extract the function inputs from the query asynchronously. + + :param query: The query to extract the function inputs from. + :type query: str + :param function_schemas: The function schemas to extract the function inputs from. + :type function_schemas: List[Dict[str, Any]] + :return: The function inputs. + :rtype: List[Dict[str, Any]] + """ system_prompt = "You are an intelligent AI. Given a command or request from the user, call the function to complete the request." messages = [ Message(role="system", content=system_prompt), @@ -208,7 +271,15 @@ class OpenAILLM(BaseLLM): self, inputs: List[Dict[str, Any]], function_schemas: List[Dict[str, Any]] ) -> bool: """Determine if the functions chosen by the LLM exist within the function_schemas, - and if the input arguments are valid for those functions.""" + and if the input arguments are valid for those functions. + + :param inputs: The inputs to check for validity. + :type inputs: List[Dict[str, Any]] + :param function_schemas: The function schemas to check against. + :type function_schemas: List[Dict[str, Any]] + :return: True if the inputs are valid, False otherwise. + :rtype: bool + """ try: for input_dict in inputs: # Check if 'function_name' and 'arguments' keys exist in each input dictionary @@ -251,7 +322,14 @@ class OpenAILLM(BaseLLM): def _validate_single_function_inputs( self, inputs: Dict[str, Any], function_schema: Dict[str, Any] ) -> bool: - """Validate the extracted inputs against the function schema""" + """Validate the extracted inputs against the function schema. + + :param inputs: The inputs to validate. + :type inputs: Dict[str, Any] + :param function_schema: The function schema to validate against. + :type function_schema: Dict[str, Any] + :return: True if the inputs are valid, False otherwise. + """ try: # Access the parameters and their properties from the function schema directly parameters = function_schema["parameters"]["properties"] @@ -283,6 +361,13 @@ class OpenAILLM(BaseLLM): def get_schemas_openai(items: List[Callable]) -> List[Dict[str, Any]]: + """Get function schemas for the OpenAI LLM from a list of functions. + + :param items: The functions to get function schemas for. + :type items: List[Callable] + :return: The schemas for the OpenAI LLM. 
+ :rtype: List[Dict[str, Any]] + """ schemas = [] for item in items: if not callable(item): diff --git a/semantic_router/llms/openrouter.py b/semantic_router/llms/openrouter.py index 34dc147d..b1a5a558 100644 --- a/semantic_router/llms/openrouter.py +++ b/semantic_router/llms/openrouter.py @@ -10,6 +10,8 @@ from semantic_router.utils.logger import logger class OpenRouterLLM(BaseLLM): + """LLM for OpenRouter. Requires an OpenRouter API key, see here for more information + https://openrouter.ai/docs/api-reference/authentication#using-an-api-key""" _client: Optional[openai.OpenAI] = PrivateAttr(default=None) _base_url: str = PrivateAttr(default="https://openrouter.ai/api/v1") @@ -21,6 +23,19 @@ class OpenRouterLLM(BaseLLM): temperature: float = 0.01, max_tokens: int = 200, ): + """Initialize the OpenRouterLLM. + + :param name: The name of the OpenRouter model to use. + :type name: Optional[str] + :param openrouter_api_key: The OpenRouter API key. + :type openrouter_api_key: Optional[str] + :param base_url: The base URL for the OpenRouter API. + :type base_url: str + :param temperature: The temperature of the LLM. + :type temperature: float + :param max_tokens: The maximum number of tokens to generate. + :type max_tokens: int + """ if name is None: name = os.getenv( "OPENROUTER_CHAT_MODEL_NAME", "mistralai/mistral-7b-instruct" @@ -40,6 +55,13 @@ class OpenRouterLLM(BaseLLM): self.max_tokens = max_tokens def __call__(self, messages: List[Message]) -> str: + """Call the OpenRouterLLM. + + :param messages: The messages to pass to the OpenRouterLLM. + :type messages: List[Message] + :return: The response from the OpenRouterLLM. + :rtype: str + """ if self._client is None: raise ValueError("OpenRouter client is not initialized.") try: diff --git a/semantic_router/llms/zure.py b/semantic_router/llms/zure.py index fae5149d..3512e89a 100644 --- a/semantic_router/llms/zure.py +++ b/semantic_router/llms/zure.py @@ -11,6 +11,8 @@ from semantic_router.utils.logger import logger class AzureOpenAILLM(BaseLLM): + """LLM for Azure OpenAI. Requires an Azure OpenAI API key. + """ _client: Optional[openai.AzureOpenAI] = PrivateAttr(default=None) def __init__( @@ -22,6 +24,21 @@ class AzureOpenAILLM(BaseLLM): max_tokens: int = 200, api_version="2023-07-01-preview", ): + """Initialize the AzureOpenAILLM. + + :param name: The name of the Azure OpenAI model to use. + :type name: Optional[str] + :param openai_api_key: The Azure OpenAI API key. + :type openai_api_key: Optional[str] + :param azure_endpoint: The Azure OpenAI endpoint. + :type azure_endpoint: Optional[str] + :param temperature: The temperature of the LLM. + :type temperature: float + :param max_tokens: The maximum number of tokens to generate. + :type max_tokens: int + :param api_version: The API version to use. + :type api_version: str + """ if name is None: name = EncoderDefault.AZURE.value["language_model"] super().__init__(name=name) @@ -41,6 +58,13 @@ class AzureOpenAILLM(BaseLLM): self.max_tokens = max_tokens def __call__(self, messages: List[Message]) -> str: + """Call the AzureOpenAILLM. + + :param messages: The messages to pass to the AzureOpenAILLM. + :type messages: List[Message] + :return: The response from the AzureOpenAILLM. 
+ :rtype: str + """ if self._client is None: raise ValueError("AzureOpenAI client is not initialized.") try: diff --git a/semantic_router/route.py b/semantic_router/route.py index b1fc5e8b..4294cc72 100644 --- a/semantic_router/route.py +++ b/semantic_router/route.py @@ -11,6 +11,13 @@ from semantic_router.utils.logger import logger def is_valid(route_config: str) -> bool: + """Check if the route config is valid. + + :param route_config: The route config to check. + :type route_config: str + :return: Whether the route config is valid. + :rtype: bool + """ try: output_json = json.loads(route_config) required_keys = ["name", "utterances"] @@ -39,18 +46,36 @@ def is_valid(route_config: str) -> bool: class Route(BaseModel): - name: str - utterances: Union[List[str], List[Any]] - description: Optional[str] = None - function_schemas: Optional[List[Dict[str, Any]]] = None - llm: Optional[BaseLLM] = None - score_threshold: Optional[float] = None - metadata: Optional[Dict[str, Any]] = {} + """A route for the semantic router. + + :param name: The name of the route. + :type name: str + :param utterances: The utterances of the route. + :type utterances: Union[List[str], List[Any]] + :param description: The description of the route. + :type description: Optional[str] + :param function_schemas: The function schemas of the route. + :type function_schemas: Optional[List[Dict[str, Any]]] + :param llm: The LLM to use. + :type llm: Optional[BaseLLM] + :param score_threshold: The score threshold of the route. + :type score_threshold: Optional[float] + :param metadata: The metadata of the route. + :type metadata: Optional[Dict[str, Any]] + """ class Config: arbitrary_types_allowed = True def __call__(self, query: Optional[str] = None) -> RouteChoice: + """Call the route. If dynamic routes have been provided the query must have been + provided and the llm attribute must be set. + + :param query: The query to pass to the route. + :type query: Optional[str] + :return: The route choice. + :rtype: RouteChoice + """ if self.function_schemas: if not self.llm: raise ValueError( @@ -73,6 +98,14 @@ class Route(BaseModel): return RouteChoice(name=self.name, function_call=func_call) async def acall(self, query: Optional[str] = None) -> RouteChoice: + """Asynchronous call the route. If dynamic routes have been provided the query + must have been provided and the llm attribute must be set. + + :param query: The query to pass to the route. + :type query: Optional[str] + :return: The route choice. + :rtype: RouteChoice + """ if self.function_schemas: if not self.llm: raise ValueError( @@ -95,6 +128,11 @@ class Route(BaseModel): return RouteChoice(name=self.name, function_call=func_call) def to_dict(self) -> Dict[str, Any]: + """Convert the route to a dictionary. + + :return: The dictionary representation of the route. + :rtype: Dict[str, Any] + """ data = self.dict() if self.llm is not None: data["llm"] = { @@ -106,14 +144,27 @@ class Route(BaseModel): @classmethod def from_dict(cls, data: Dict[str, Any]): + """Create a Route object from a dictionary. + + :param data: The dictionary to create the route from. + :type data: Dict[str, Any] + :return: The created route. + :rtype: Route + """ return cls(**data) @classmethod def from_dynamic_route( cls, llm: BaseLLM, entities: List[Union[BaseModel, Callable]], route_name: str ): - """ - Generate a dynamic Route object from a list of functions or Pydantic models using LLM + """Generate a dynamic Route object from a list of functions or Pydantic models + using an LLM. 
+ + :param llm: The LLM to use. + :type llm: BaseLLM + :param entities: The entities to use. + :type entities: List[Union[BaseModel, Callable]] + :param route_name: The name of the route. """ schemas = function_call.get_schema_list(items=entities) dynamic_route = cls._generate_dynamic_route( @@ -124,6 +175,14 @@ class Route(BaseModel): @classmethod def _parse_route_config(cls, config: str) -> str: + """Parse the route config from the LLM output using regex. Expects the output + content to be wrapped in <config></config> tags. + + :param config: The LLM output. + :type config: str + :return: The parsed route config. + :rtype: str + """ # Regular expression to match content inside <config></config> config_pattern = r"<config>(.*?)</config>" match = re.search(config_pattern, config, re.DOTALL) @@ -138,8 +197,14 @@ class Route(BaseModel): def _generate_dynamic_route( cls, llm: BaseLLM, function_schemas: List[Dict[str, Any]], route_name: str ): - logger.info("Generating dynamic route...") + """Generate a dynamic Route object from a list of function schemas using an LLM. + :param llm: The LLM to use. + :type llm: BaseLLM + :param function_schemas: The function schemas to use. + :type function_schemas: List[Dict[str, Any]] + :param route_name: The name of the route. + """ formatted_schemas = "\n".join( [json.dumps(schema, indent=4) for schema in function_schemas] ) @@ -176,8 +241,6 @@ class Route(BaseModel): route_config = cls._parse_route_config(config=output) - logger.info(f"Generated route config:\n{route_config}") - if is_valid(route_config): route_config_dict = json.loads(route_config) route_config_dict["llm"] = llm diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 9b973cb6..6db78b31 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -12,6 +12,9 @@ from aurelio_sdk.schema import SparseEmbedding as BM25SparseEmbedding class EncoderType(Enum): + """The type of encoder. + """ + AURELIO = "aurelio" AZURE = "azure" COHERE = "cohere" @@ -28,40 +31,59 @@ class EncoderType(Enum): class EncoderInfo(BaseModel): + """Information about an encoder. + """ name: str token_limit: int threshold: Optional[float] = None class RouteChoice(BaseModel): + """A route choice typically output by the routers. + """ name: Optional[str] = None function_call: Optional[List[Dict]] = None similarity_score: Optional[float] = None class Message(BaseModel): + """A message in a conversation, includes the role and content fields. + """ role: str content: str def to_openai(self): - if self.role.lower() not in ["user", "assistant", "system"]: - raise ValueError("Role must be either 'user', 'assistant' or 'system'") + """Convert the message to an OpenAI-compatible format. + """ + if self.role.lower() not in ["user", "assistant", "system", "tool"]: + raise ValueError("Role must be either 'user', 'assistant', 'system' or 'tool'") return {"role": self.role, "content": self.content} def to_cohere(self): + """Convert the message to a Cohere-compatible format. + """ return {"role": self.role, "message": self.content} def to_llamacpp(self): + """Convert the message to a LlamaCPP-compatible format. + """ return {"role": self.role, "content": self.content} def to_mistral(self): + """Convert the message to a Mistral-compatible format. + """ return {"role": self.role, "content": self.content} def __str__(self): + """Convert the message to a string. + """ return f"{self.role}: {self.content}" class ConfigParameter(BaseModel): + """A configuration parameter for a route. 
Used for remote router metadata such as + router hashes, sync locks, etc. + """ field: str value: str scope: Optional[str] = None @@ -70,6 +92,15 @@ class ConfigParameter(BaseModel): ) def to_pinecone(self, dimensions: int): + """Convert the configuration parameter to a Pinecone-compatible format. Should + be used when upserting configuration parameters to a separate config namespace + within your Pinecone index. + + :param dimensions: The dimensions of the Pinecone index. + :type dimensions: int + :return: A Pinecone-compatible configuration parameter. + :rtype: dict + """ namespace = self.scope or "" return { "id": f"{self.field}#{namespace}", @@ -84,6 +115,9 @@ class ConfigParameter(BaseModel): class Utterance(BaseModel): + """An utterance in a conversation, includes the route, utterance, function + schemas, metadata, and diff tag. + """ route: str utterance: Union[str, Any] function_schemas: Optional[List[Dict]] = None @@ -129,6 +163,14 @@ class Utterance(BaseModel): ) def to_str(self, include_metadata: bool = False): + """Convert an Utterance object to a string. Used for comparisons during sync + check operations. + + :param include_metadata: Whether to include metadata in the string. + :type include_metadata: bool + :return: A string representation of the Utterance object. + :rtype: str + """ if include_metadata: # we sort the dicts to ensure consistent order as we need this to compare # stringified function schemas accurately @@ -149,8 +191,7 @@ class Utterance(BaseModel): class SyncMode(Enum): - """Synchronization modes for local (route layer) and remote (index) - instances. + """Synchronization modes for local (route layer) and remote (index) instances. """ ERROR = "error" @@ -165,12 +206,22 @@ SYNC_MODES = [x.value for x in SyncMode] class UtteranceDiff(BaseModel): + """A list of Utterance objects that represent the differences between local and + remote utterances. + """ diff: List[Utterance] @classmethod def from_utterances( cls, local_utterances: List[Utterance], remote_utterances: List[Utterance] ): + """Create a UtteranceDiff object from two lists of Utterance objects. + + :param local_utterances: A list of Utterance objects. + :type local_utterances: List[Utterance] + :param remote_utterances: A list of Utterance objects. + :type remote_utterances: List[Utterance] + """ local_utterances_map = { x.to_str(include_metadata=True): x for x in local_utterances } @@ -222,14 +273,18 @@ class UtteranceDiff(BaseModel): This diff tells us that the remote has "route2: utterance3" and "route2: utterance4", which do not exist locally. + + :param include_metadata: Whether to include metadata in the string. + :type include_metadata: bool + :return: A list of diff strings. + :rtype: List[str] """ return [x.to_diff_str(include_metadata=include_metadata) for x in self.diff] def get_tag(self, diff_tag: str) -> List[Utterance]: """Get all utterances with a given diff tag. - :param diff_tag: The diff tag to filter by. Must be one of "+", "-", or - " ". + :param diff_tag: The diff tag to filter by. Must be one of "+", "-", or " ". :type diff_tag: str :return: A list of Utterance objects. :rtype: List[Utterance] @@ -239,8 +294,7 @@ class UtteranceDiff(BaseModel): return [x for x in self.diff if x.diff_tag == diff_tag] def get_sync_strategy(self, sync_mode: str) -> dict: - """Generates the optimal synchronization plan for local and remote - instances. + """Generates the optimal synchronization plan for local and remote instances. :param sync_mode: The mode to sync the routes with the remote index. 
:type sync_mode: str @@ -417,6 +471,8 @@ class UtteranceDiff(BaseModel): class Metric(Enum): + """The metric to use in vector-based similarity search indexes. + """ COSINE = "cosine" DOTPRODUCT = "dotproduct" EUCLIDEAN = "euclidean" @@ -435,6 +491,13 @@ class SparseEmbedding(BaseModel): @classmethod def from_compact_array(cls, array: np.ndarray): + """Create a SparseEmbedding object from a compact array. + + :param array: A compact array. + :type array: np.ndarray + :return: A SparseEmbedding object. + :rtype: SparseEmbedding + """ if array.ndim != 2 or array.shape[1] != 2: raise ValueError( f"Expected a 2D array with 2 columns, got a {array.ndim}D array with {array.shape[1]} columns. " @@ -444,32 +507,69 @@ class SparseEmbedding(BaseModel): @classmethod def from_vector(cls, vector: np.ndarray): - """Consumes an array of sparse vectors containing zero-values.""" + """Consumes an array of sparse vectors containing zero-values. + + :param vector: A sparse vector. + :type vector: np.ndarray + :return: A SparseEmbedding object. + :rtype: SparseEmbedding + """ if vector.ndim != 1: raise ValueError(f"Expected a 1D array, got a {vector.ndim}D array.") return cls.from_compact_array(np.array([np.arange(len(vector)), vector]).T) @classmethod def from_aurelio(cls, embedding: BM25SparseEmbedding): + """Create a SparseEmbedding object from an AurelioSparseEmbedding object. + + :param embedding: An AurelioSparseEmbedding object. + :type embedding: BM25SparseEmbedding + :return: A SparseEmbedding object. + :rtype: SparseEmbedding + """ arr = np.array([embedding.indices, embedding.values]).T return cls.from_compact_array(arr) @classmethod def from_dict(cls, sparse_dict: dict): + """Create a SparseEmbedding object from a dictionary. + + :param sparse_dict: A dictionary of sparse values. + :type sparse_dict: dict + :return: A SparseEmbedding object. + :rtype: SparseEmbedding + """ arr = np.array([list(sparse_dict.keys()), list(sparse_dict.values())]).T return cls.from_compact_array(arr) @classmethod def from_pinecone_dict(cls, sparse_dict: dict): + """Create a SparseEmbedding object from a Pinecone dictionary. + + :param sparse_dict: A Pinecone dictionary. + :type sparse_dict: dict + :return: A SparseEmbedding object. + :rtype: SparseEmbedding + """ arr = np.array([sparse_dict["indices"], sparse_dict["values"]]).T return cls.from_compact_array(arr) def to_dict(self): + """Convert a SparseEmbedding object to a dictionary. + + :return: A dictionary of sparse values. + :rtype: dict + """ return { i: v for i, v in zip(self.embedding[:, 0].astype(int), self.embedding[:, 1]) } def to_pinecone(self): + """Convert a SparseEmbedding object to a Pinecone dictionary. + + :return: A Pinecone dictionary. + :rtype: dict + """ return { "indices": self.embedding[:, 0].astype(int).tolist(), "values": self.embedding[:, 1].tolist(), @@ -477,6 +577,11 @@ class SparseEmbedding(BaseModel): # dictionary interface def items(self): + """Return a list of (index, value) tuples from the SparseEmbedding object. + + :return: A list of (index, value) tuples. + :rtype: list + """ return [ (i, v) for i, v in zip(self.embedding[:, 0].astype(int), self.embedding[:, 1]) diff --git a/semantic_router/utils/defaults.py b/semantic_router/utils/defaults.py index 151a9935..90bfb1b0 100644 --- a/semantic_router/utils/defaults.py +++ b/semantic_router/utils/defaults.py @@ -3,6 +3,9 @@ from enum import Enum class EncoderDefault(Enum): + """Default model names for each encoder type. 
+ """ + FASTEMBED = { "embedding_model": "BAAI/bge-small-en-v1.5", "language_model": "BAAI/bge-small-en-v1.5", diff --git a/semantic_router/utils/function_call.py b/semantic_router/utils/function_call.py index e9c0afd1..8fd450c0 100644 --- a/semantic_router/utils/function_call.py +++ b/semantic_router/utils/function_call.py @@ -9,6 +9,20 @@ from semantic_router.utils.logger import logger class Parameter(BaseModel): + """Parameter for a function. + + :param name: The name of the parameter. + :type name: str + :param description: The description of the parameter. + :type description: Optional[str] + :param type: The type of the parameter. + :type type: str + :param default: The default value of the parameter. + :type default: Any + :param required: Whether the parameter is required. + :type required: bool + """ + class Config: arbitrary_types_allowed = True @@ -21,6 +35,11 @@ class Parameter(BaseModel): required: bool = Field(description="Whether the parameter is required") def to_ollama(self): + """Convert the parameter to a dictionary for an Ollama-compatible function schema. + + :return: The parameter in dictionary format. + :rtype: Dict[str, Any] + """ return { self.name: { "description": self.description, @@ -41,6 +60,11 @@ class FunctionSchema: parameters: List[Parameter] = Field(description="The parameters of the function") def __init__(self, function: Union[Callable, BaseModel]): + """Initialize the FunctionSchema. + + :param function: The function to consume. + :type function: Union[Callable, BaseModel] + """ self.function = function if callable(function): self._process_function(function) @@ -50,6 +74,11 @@ class FunctionSchema: raise TypeError("Function must be a Callable or BaseModel") def _process_function(self, function: Callable): + """Process the function to get the name, description, signature, and output. + + :param function: The function to process. + :type function: Callable + """ self.name = function.__name__ self.description = str(inspect.getdoc(function)) self.signature = str(inspect.signature(function)) @@ -67,6 +96,11 @@ class FunctionSchema: self.parameters = parameters def to_ollama(self): + """Convert the FunctionSchema to an Ollama-compatible function schema dictionary. + + :return: The function schema in dictionary format. + :rtype: Dict[str, Any] + """ schema_dict = { "type": "function", "function": { @@ -94,6 +128,13 @@ class FunctionSchema: return schema_dict def _ollama_type_mapping(self, param_type: str) -> str: + """Map the parameter type to an Ollama-compatible type. + + :param param_type: The type of the parameter. + :type param_type: str + :return: The Ollama-compatible type. + :rtype: str + """ if param_type == "int": return "number" elif param_type == "float": @@ -107,6 +148,13 @@ class FunctionSchema: def get_schema_list(items: List[Union[BaseModel, Callable]]) -> List[Dict[str, Any]]: + """Get a list of function schemas from a list of functions or Pydantic BaseModels. + + :param items: The functions or BaseModels to get the schemas for. + :type items: List[Union[BaseModel, Callable]] + :return: A list of function schemas. + :rtype: List[Dict[str, Any]] + """ schemas = [] for item in items: schema = get_schema(item) @@ -115,6 +163,13 @@ def get_schema_list(items: List[Union[BaseModel, Callable]]) -> List[Dict[str, A def get_schema(item: Union[BaseModel, Callable]) -> Dict[str, Any]: + """Get a function schema from a function or Pydantic BaseModel. + + :param item: The function or BaseModel to get the schema for. 
+ :type item: Union[BaseModel, Callable] + :return: The function schema. + :rtype: Dict[str, Any] + """ if isinstance(item, BaseModel): signature_parts = [] for field_name, field_model in item.__annotations__.items(): @@ -147,6 +202,13 @@ def get_schema(item: Union[BaseModel, Callable]) -> Dict[str, Any]: def convert_python_type_to_json_type(param_type: str) -> str: + """Convert a Python type to a JSON type. + + :param param_type: The type of the parameter. + :type param_type: str + :return: The JSON type. + :rtype: str + """ if param_type == "int": return "number" if param_type == "float": @@ -167,6 +229,19 @@ def convert_python_type_to_json_type(param_type: str) -> str: async def route_and_execute( query: str, llm: BaseLLM, functions: List[Callable], layer ) -> Any: + """Route and execute a function. + + :param query: The query to route and execute. + :type query: str + :param llm: The LLM to use. + :type llm: BaseLLM + :param functions: The functions to execute. + :type functions: List[Callable] + :param layer: The layer to use. + :type layer: Layer + :return: The result of the function. + :rtype: Any + """ route_choice: RouteChoice = layer(query) for function in functions: diff --git a/semantic_router/utils/llm.py b/semantic_router/utils/llm.py deleted file mode 100644 index 5402e47f..00000000 --- a/semantic_router/utils/llm.py +++ /dev/null @@ -1,65 +0,0 @@ -import os -from typing import Optional - -import openai - -from semantic_router.utils.logger import logger - - -def llm(prompt: str) -> Optional[str]: - try: - client = openai.OpenAI( - base_url="https://openrouter.ai/api/v1", - api_key=os.getenv("OPENROUTER_API_KEY"), - ) - - completion = client.chat.completions.create( - model="mistralai/mistral-7b-instruct", - messages=[ - { - "role": "user", - "content": prompt, - }, - ], - temperature=0.01, - max_tokens=200, - ) - - output = completion.choices[0].message.content - - if not output: - raise Exception("No output generated") - return output - except Exception as e: - logger.error(f"LLM error: {e}") - raise Exception(f"LLM error: {e}") from e - - -# TODO integrate async LLM function -# async def allm(prompt: str) -> Optional[str]: -# try: -# client = openai.AsyncOpenAI( -# base_url="https://openrouter.ai/api/v1", -# api_key=os.getenv("OPENROUTER_API_KEY"), -# ) - -# completion = await client.chat.completions.create( -# model="mistralai/mistral-7b-instruct", -# messages=[ -# { -# "role": "user", -# "content": prompt, -# }, -# ], -# temperature=0.01, -# max_tokens=200, -# ) - -# output = completion.choices[0].message.content - -# if not output: -# raise Exception("No output generated") -# return output -# except Exception as e: -# logger.error(f"LLM error: {e}") -# raise Exception(f"LLM error: {e}") from e diff --git a/semantic_router/utils/logger.py b/semantic_router/utils/logger.py index 2c1980d8..634dbc8c 100644 --- a/semantic_router/utils/logger.py +++ b/semantic_router/utils/logger.py @@ -4,6 +4,9 @@ import colorlog class CustomFormatter(colorlog.ColoredFormatter): + """Custom formatter for the logger. + """ + def __init__(self): super().__init__( "%(log_color)s%(asctime)s %(levelname)s %(name)s %(message)s", @@ -21,6 +24,8 @@ class CustomFormatter(colorlog.ColoredFormatter): def add_coloured_handler(logger): + """Add a coloured handler to the logger. + """ formatter = CustomFormatter() console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) @@ -29,6 +34,8 @@ def add_coloured_handler(logger): def setup_custom_logger(name): + """Setup a custom logger. 
+ """ logger = logging.getLogger(name) if not logger.hasHandlers(): -- GitLab