Commit bcaa22ec authored by James Briggs

feat: further docstrings and cleanup

parent f16f620f
Showing with 844 additions and 136 deletions
......@@ -8,6 +8,9 @@ from semantic_router.utils.defaults import EncoderDefault
class CohereEncoder(DenseEncoder):
"""Dense encoder that uses Cohere API to embed documents. Supports text only. Requires
a Cohere API key from https://dashboard.cohere.com/api-keys.
"""
_client: Any = PrivateAttr()
_embed_type: Any = PrivateAttr()
type: str = "cohere"
......@@ -20,6 +23,18 @@ class CohereEncoder(DenseEncoder):
score_threshold: float = 0.3,
input_type: Optional[str] = "search_query",
):
"""Initialize the Cohere encoder.
:param name: The name of the embedding model to use.
:type name: str
:param cohere_api_key: The API key for the Cohere client, can also
be set via the COHERE_API_KEY environment variable.
:type cohere_api_key: str
:param score_threshold: The threshold for the score of the embedding.
:type score_threshold: float
:param input_type: The type of input to embed.
:type input_type: str
"""
if name is None:
name = EncoderDefault.COHERE.value["embedding_model"]
super().__init__(
......@@ -34,9 +49,10 @@ class CohereEncoder(DenseEncoder):
"""Initializes the Cohere client.
:param cohere_api_key: The API key for the Cohere client, can also
be set via the COHERE_API_KEY environment variable.
:type cohere_api_key: str
:return: An instance of the Cohere client.
:rtype: cohere.Client
"""
try:
import cohere
......@@ -61,6 +77,13 @@ class CohereEncoder(DenseEncoder):
return client
def __call__(self, docs: List[str]) -> List[List[float]]:
"""Embed a list of documents. Supports text only.
:param docs: The documents to embed.
:type docs: List[str]
:return: The vector embeddings of the documents.
:rtype: List[List[float]]
"""
if self._client is None:
raise ValueError("Cohere client is not initialized.")
try:
......
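For reference, a minimal usage sketch of the Cohere encoder documented above. It assumes `CohereEncoder` is exported from `semantic_router.encoders` and that a valid Cohere API key is available via the `COHERE_API_KEY` environment variable:

```python
from semantic_router.encoders import CohereEncoder

# assumes COHERE_API_KEY is set in the environment
encoder = CohereEncoder()  # falls back to the default Cohere embedding model
embeddings = encoder(["document1", "document2"])  # -> List[List[float]]
```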
......@@ -7,6 +7,14 @@ from semantic_router.encoders import DenseEncoder
class FastEmbedEncoder(DenseEncoder):
"""Dense encoder that uses local FastEmbed to embed documents. Supports text only.
Requires the fastembed package which can be installed with `pip install 'semantic-router[fastembed]'`
:param name: The name of the embedding model to use.
:param max_length: The maximum length of the input text.
:param cache_dir: The directory to cache the embedding model.
:param threads: The number of threads to use for the embedding.
"""
type: str = "fastembed"
name: str = "BAAI/bge-small-en-v1.5"
max_length: int = 512
......@@ -16,11 +24,19 @@ class FastEmbedEncoder(DenseEncoder):
def __init__(
self, score_threshold: float = 0.5, **data
): # TODO default score_threshold not thoroughly tested, should optimize
):
"""Initialize the FastEmbed encoder.
:param score_threshold: The threshold for the score of the embedding.
:type score_threshold: float
"""
# TODO default score_threshold not thoroughly tested, should optimize
super().__init__(score_threshold=score_threshold, **data)
self._client = self._initialize_client()
def _initialize_client(self):
"""Initialize the FastEmbed library. Requires the fastembed package.
"""
try:
from fastembed import TextEmbedding
except ImportError:
......@@ -43,6 +59,14 @@ class FastEmbedEncoder(DenseEncoder):
return embedding
def __call__(self, docs: List[str]) -> List[List[float]]:
"""Embed a list of documents. Supports text only.
:param docs: The documents to embed.
:type docs: List[str]
:raise ValueError: If the embedding fails.
:return: The vector embeddings of the documents.
:rtype: List[List[float]]
"""
try:
embeds: List[np.ndarray] = list(self._client.embed(docs))
embeddings: List[List[float]] = [e.tolist() for e in embeds]
......
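A minimal local-embedding sketch for the FastEmbed encoder above, assuming the optional dependency has been installed with `pip install 'semantic-router[fastembed]'`:

```python
from semantic_router.encoders import FastEmbedEncoder

encoder = FastEmbedEncoder(name="BAAI/bge-small-en-v1.5")  # the default model shown above
embeddings = encoder(["document1", "document2"])  # -> List[List[float]]
```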
"""
This module provides the GoogleEncoder class for generating embeddings using Google's AI Platform.
The GoogleEncoder class is a subclass of DenseEncoder and utilizes the TextEmbeddingModel from the
Google AI Platform to generate embeddings for given documents. It requires a Google Cloud project ID
and supports customization of the pre-trained model, score threshold, location, and API endpoint.
Example usage:
from semantic_router.encoders.google_encoder import GoogleEncoder
encoder = GoogleEncoder(project_id="your-project-id")
embeddings = encoder(["document1", "document2"])
Classes:
GoogleEncoder: A class for generating embeddings using Google's AI Platform.
"""
import os
from typing import Any, List, Optional
......@@ -26,6 +8,19 @@ from semantic_router.utils.defaults import EncoderDefault
class GoogleEncoder(DenseEncoder):
"""GoogleEncoder class for generating embeddings using Google's AI Platform.
The GoogleEncoder class is a subclass of DenseEncoder and utilizes the TextEmbeddingModel from the
Google AI Platform to generate embeddings for given documents. It requires a Google Cloud project ID
and supports customization of the pre-trained model, score threshold, location, and API endpoint.
Example usage:
```python
from semantic_router.encoders.google_encoder import GoogleEncoder
encoder = GoogleEncoder(project_id="your-project-id")
embeddings = encoder(["document1", "document2"])
```
Attributes:
client: An instance of the TextEmbeddingModel client.
type: The type of the encoder, which is "google".
......@@ -44,23 +39,25 @@ class GoogleEncoder(DenseEncoder):
):
"""Initializes the GoogleEncoder.
Args:
model_name: The name of the pre-trained model to use for embedding.
If not provided, the default model specified in EncoderDefault will
be used.
score_threshold: The threshold for similarity scores.
project_id: The Google Cloud project ID.
If not provided, it will be retrieved from the GOOGLE_PROJECT_ID
environment variable.
location: The location of the AI Platform resources.
If not provided, it will be retrieved from the GOOGLE_LOCATION
environment variable, defaulting to "us-central1".
api_endpoint: The API endpoint for the AI Platform.
If not provided, it will be retrieved from the GOOGLE_API_ENDPOINT
environment variable.
Raises:
ValueError: If the Google Project ID is not provided or if the AI Platform
:param model_name: The name of the pre-trained model to use for embedding.
If not provided, the default model specified in EncoderDefault will
be used.
:type model_name: str
:param score_threshold: The threshold for similarity scores.
:type score_threshold: float
:param project_id: The Google Cloud project ID.
If not provided, it will be retrieved from the GOOGLE_PROJECT_ID
environment variable.
:type project_id: str
:param location: The location of the AI Platform resources.
If not provided, it will be retrieved from the GOOGLE_LOCATION
environment variable, defaulting to "us-central1".
:type location: str
:param api_endpoint: The API endpoint for the AI Platform.
If not provided, it will be retrieved from the GOOGLE_API_ENDPOINT
environment variable.
:type api_endpoint: str
:raise ValueError: If the Google Project ID is not provided or if the AI Platform
client fails to initialize.
"""
if name is None:
......@@ -73,18 +70,17 @@ class GoogleEncoder(DenseEncoder):
def _initialize_client(self, project_id, location, api_endpoint):
"""Initializes the Google AI Platform client.
Args:
project_id: The Google Cloud project ID.
location: The location of the AI Platform resources.
api_endpoint: The API endpoint for the AI Platform.
Returns:
An instance of the TextEmbeddingModel client.
Raises:
ImportError: If the required Google Cloud or Vertex AI libraries are not
:param project_id: The Google Cloud project ID.
:type project_id: str
:param location: The location of the AI Platform resources.
:type location: str
:param api_endpoint: The API endpoint for the AI Platform.
:type api_endpoint: str
:return: An instance of the TextEmbeddingModel client.
:rtype: TextEmbeddingModel
:raise ImportError: If the required Google Cloud or Vertex AI libraries are not
installed.
ValueError: If the Google Project ID is not provided or if the AI Platform
:raise ValueError: If the Google Project ID is not provided or if the AI Platform
client fails to initialize.
"""
try:
......@@ -119,15 +115,12 @@ class GoogleEncoder(DenseEncoder):
def __call__(self, docs: List[str]) -> List[List[float]]:
"""Generates embeddings for the given documents.
Args:
docs: A list of strings representing the documents to embed.
Returns:
A list of lists, where each inner list contains the embedding values for a
:param docs: A list of strings representing the documents to embed.
:type docs: List[str]
:return: A list of lists, where each inner list contains the embedding values for a
document.
Raises:
ValueError: If the Google AI Platform client is not initialized or if the
:rtype: List[List[float]]
:raise ValueError: If the Google AI Platform client is not initialized or if the
API call fails.
"""
if self.client is None:
......
......@@ -31,7 +31,26 @@ from semantic_router.encoders import DenseEncoder
from semantic_router.utils.logger import logger
# TODO: this should support local models, and we should have another class for remote
# inference endpoint models
class HuggingFaceEncoder(DenseEncoder):
"""HuggingFace encoder class for local embedding models. Models can be trained and
loaded from private repositories, or from the Huggingface Hub. The class supports
customization of the score threshold for filtering or processing the embeddings.
Example usage:
```python
from semantic_router.encoders import HuggingFaceEncoder
encoder = HuggingFaceEncoder(
name="sentence-transformers/all-MiniLM-L6-v2",
device="cuda"
)
embeddings = encoder(["document1", "document2"])
```
"""
name: str = "sentence-transformers/all-MiniLM-L6-v2"
type: str = "huggingface"
tokenizer_kwargs: Dict = {}
......@@ -92,6 +111,12 @@ class HuggingFaceEncoder(DenseEncoder):
normalize_embeddings: bool = True,
pooling_strategy: str = "mean",
) -> List[List[float]]:
"""Encode a list of documents into embeddings using the local Hugging Face model.
:param docs: A list of documents to encode.
:type docs: List[str]
:param batch_size: The batch size for encoding.
"""
all_embeddings = []
for i in range(0, len(docs), batch_size):
batch_docs = docs[i : i + batch_size]
......@@ -124,6 +149,12 @@ class HuggingFaceEncoder(DenseEncoder):
return all_embeddings
def _mean_pooling(self, model_output, attention_mask):
"""Perform mean pooling on the token embeddings.
:param model_output: The output of the model.
:type model_output: torch.Tensor
:param attention_mask: The attention mask.
"""
token_embeddings = model_output[0]
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
......@@ -133,6 +164,12 @@ class HuggingFaceEncoder(DenseEncoder):
) / self._torch.clamp(input_mask_expanded.sum(1), min=1e-9)
def _max_pooling(self, model_output, attention_mask):
"""Perform max pooling on the token embeddings.
:param model_output: The output of the model.
:type model_output: torch.Tensor
:param attention_mask: The attention mask.
"""
token_embeddings = model_output[0]
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
......@@ -142,13 +179,24 @@ class HuggingFaceEncoder(DenseEncoder):
class HFEndpointEncoder(DenseEncoder):
"""
A class to encode documents using a Hugging Face transformer model endpoint.
"""HFEndpointEncoder class to embeddings models using Huggingface's inference endpoints.
The HFEndpointEncoder class is a subclass of DenseEncoder and utilizes a specified
Huggingface endpoint to generate embeddings for given documents. It requires the URL
of the Huggingface API endpoint and an API key for authentication. The class supports
customization of the score threshold for filtering or processing the embeddings.
Attributes:
huggingface_url (str): The URL of the Hugging Face API endpoint.
huggingface_api_key (str): The API key for authenticating with the Hugging Face API.
score_threshold (float): A threshold value used for filtering or processing the embeddings.
Example usage:
```python
from semantic_router.encoders import HFEndpointEncoder
encoder = HFEndpointEncoder(
huggingface_url="https://api-inference.huggingface.co/models/BAAI/bge-large-en-v1.5",
huggingface_api_key="your-hugging-face-api-key"
)
embeddings = encoder(["document1", "document2"])
```
"""
name: str = "hugging_face_custom_endpoint"
......@@ -162,21 +210,17 @@ class HFEndpointEncoder(DenseEncoder):
huggingface_api_key: Optional[str] = None,
score_threshold: float = 0.8,
):
"""
Initializes the HFEndpointEncoder with the specified parameters.
Args:
name (str, optional): The name of the encoder. Defaults to
"hugging_face_custom_endpoint".
huggingface_url (str, optional): The URL of the Hugging Face API endpoint.
Cannot be None.
huggingface_api_key (str, optional): The API key for the Hugging Face API.
Cannot be None.
score_threshold (float, optional): A threshold for processing the embeddings.
Defaults to 0.8.
Raises:
ValueError: If either `huggingface_url` or `huggingface_api_key` is None.
"""Initializes the HFEndpointEncoder with the specified parameters.
:param name: The name of the encoder.
:type name: str
:param huggingface_url: The URL of the Hugging Face API endpoint.
:type huggingface_url: str
:param huggingface_api_key: The API key for the Hugging Face API.
:type huggingface_api_key: str
:param score_threshold: A threshold for processing the embeddings.
:type score_threshold: float
:raise ValueError: If either `huggingface_url` or `huggingface_api_key` is None.
"""
huggingface_url = huggingface_url or os.getenv("HF_API_URL")
huggingface_api_key = huggingface_api_key or os.getenv("HF_API_KEY")
......@@ -201,17 +245,13 @@ class HFEndpointEncoder(DenseEncoder):
) from e
def __call__(self, docs: List[str]) -> List[List[float]]:
"""
Encodes a list of documents into embeddings using the Hugging Face API.
Args:
docs (List[str]): A list of documents to encode.
Returns:
List[List[float]]: A list of embeddings for the given documents.
"""Encodes a list of documents into embeddings using the Hugging Face API.
Raises:
ValueError: If no embeddings are returned for a document.
:param docs: A list of documents to encode.
:type docs: List[str]
:return: A list of embeddings for the given documents.
:rtype: List[List[float]]
:raise ValueError: If no embeddings are returned for a document.
"""
embeddings = []
for d in docs:
......@@ -228,17 +268,13 @@ class HFEndpointEncoder(DenseEncoder):
return embeddings
def query(self, payload, max_retries=3, retry_interval=5):
"""
Sends a query to the Hugging Face API and returns the response.
Args:
payload (dict): The payload to send in the request.
Returns:
dict: The response from the Hugging Face API.
"""Sends a query to the Hugging Face API and returns the response.
Raises:
ValueError: If the query fails or the response status is not 200.
:param payload: The payload to send in the request.
:type payload: dict
:return: The response from the Hugging Face API.
:rtype: dict
:raise ValueError: If the query fails or the response status is not 200.
"""
headers = {
"Accept": "application/json",
......
......@@ -11,7 +11,8 @@ from semantic_router.utils.defaults import EncoderDefault
class MistralEncoder(DenseEncoder):
"""Class to encode text using MistralAI"""
"""Class to encode text using MistralAI. Requires a MistralAI API key from
https://console.mistral.ai/api-keys/"""
_client: Any = PrivateAttr()
_mistralai: Any = PrivateAttr()
......@@ -23,12 +24,27 @@ class MistralEncoder(DenseEncoder):
mistralai_api_key: Optional[str] = None,
score_threshold: float = 0.82,
):
"""Initialize the MistralEncoder.
:param name: The name of the embedding model to use.
:type name: str
:param mistralai_api_key: The MistralAI API key.
:type mistralai_api_key: str
:param score_threshold: The score threshold for the embeddings.
"""
if name is None:
name = EncoderDefault.MISTRAL.value["embedding_model"]
super().__init__(name=name, score_threshold=score_threshold)
self._client, self._mistralai = self._initialize_client(mistralai_api_key)
def _initialize_client(self, api_key):
"""Initialize the MistralAI client.
:param api_key: The MistralAI API key.
:type api_key: str
:return: The MistralAI client.
:rtype: MistralClient
"""
try:
import mistralai
from mistralai.client import MistralClient
......@@ -49,6 +65,13 @@ class MistralEncoder(DenseEncoder):
return client, mistralai
def __call__(self, docs: List[str]) -> List[List[float]]:
"""Encode a list of documents into embeddings using MistralAI.
:param docs: The documents to encode.
:type docs: List[str]
:return: The embeddings for the documents.
:rtype: List[List[float]]
"""
if self._client is None:
raise ValueError("Mistral client not initialized")
embeds = None
......
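A minimal usage sketch for the Mistral encoder above; the API key is a placeholder and `MistralEncoder` is assumed to be exported from `semantic_router.encoders`:

```python
from semantic_router.encoders import MistralEncoder

encoder = MistralEncoder(mistralai_api_key="<your-mistral-api-key>")
embeddings = encoder(["document1", "document2"])  # -> List[List[float]]
```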
......@@ -35,6 +35,12 @@ model_configs = {
class OpenAIEncoder(DenseEncoder):
"""OpenAI encoder class for generating embeddings using OpenAI API.
The OpenAIEncoder class is a subclass of DenseEncoder and utilizes the OpenAI API
to generate embeddings for given documents. It requires an OpenAI API key and
supports customization of the score threshold for filtering or processing the embeddings.
"""
_client: Optional[openai.Client] = PrivateAttr(default=None)
_async_client: Optional[openai.AsyncClient] = PrivateAttr(default=None)
dimensions: Union[int, NotGiven] = NotGiven()
......@@ -53,6 +59,23 @@ class OpenAIEncoder(DenseEncoder):
dimensions: Union[int, NotGiven] = NotGiven(),
max_retries: int = 3,
):
"""Initialize the OpenAIEncoder.
:param name: The name of the embedding model to use.
:type name: str
:param openai_base_url: The base URL for the OpenAI API.
:type openai_base_url: str
:param openai_api_key: The OpenAI API key.
:type openai_api_key: str
:param openai_org_id: The OpenAI organization ID.
:type openai_org_id: str
:param score_threshold: The score threshold for the embeddings.
:type score_threshold: float
:param dimensions: The dimensions of the embeddings.
:type dimensions: int
:param max_retries: The maximum number of retries for the OpenAI API call.
:type max_retries: int
"""
if name is None:
name = EncoderDefault.OPENAI.value["embedding_model"]
if score_threshold is None and name in model_configs:
......@@ -146,6 +169,13 @@ class OpenAIEncoder(DenseEncoder):
return embeddings
def _truncate(self, text: str) -> str:
"""Truncate a document to the token limit.
:param text: The document to truncate.
:type text: str
:return: The truncated document.
:rtype: str
"""
# we use encode_ordinary as faster equivalent to encode(text, disallowed_special=())
tokens = self._token_encoder.encode_ordinary(text)
if len(tokens) > self.token_limit:
......@@ -159,6 +189,13 @@ class OpenAIEncoder(DenseEncoder):
return text
async def acall(self, docs: List[str], truncate: bool = True) -> List[List[float]]:
"""Encode a list of text documents into embeddings using OpenAI API asynchronously.
:param docs: List of text documents to encode.
:param truncate: Whether to truncate the documents to token limit. If
False and a document exceeds the token limit, an error will be
raised.
:return: List of embeddings for each document."""
if self._async_client is None:
raise ValueError("OpenAI async client is not initialized.")
embeds = None
......
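A minimal sketch covering both the synchronous `__call__` and asynchronous `acall` paths documented above; it assumes `OPENAI_API_KEY` is set in the environment, and the model name and dimensions are illustrative:

```python
import asyncio

from semantic_router.encoders import OpenAIEncoder

encoder = OpenAIEncoder(name="text-embedding-3-small", dimensions=256)

# synchronous embedding
embeddings = encoder(["document1", "document2"])

# asynchronous embedding; documents are truncated to the token limit by default
embeddings = asyncio.run(encoder.acall(["document1", "document2"], truncate=True))
```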
......@@ -6,6 +6,12 @@ from semantic_router.encoders import DenseEncoder
class VitEncoder(DenseEncoder):
"""Encoder for Vision Transformer models.
This class provides functionality to encode images using a Vision Transformer
model via Hugging Face. It supports various image processing and model initialization
options.
"""
name: str = "google/vit-base-patch16-224"
type: str = "huggingface"
processor_kwargs: Dict = {}
......@@ -18,12 +24,22 @@ class VitEncoder(DenseEncoder):
_Image: Any = PrivateAttr()
def __init__(self, **data):
"""Initialize the VitEncoder.
:param **data: Additional keyword arguments for the encoder.
:type **data: dict
"""
if data.get("score_threshold") is None:
data["score_threshold"] = 0.5
super().__init__(**data)
self._processor, self._model = self._initialize_hf_model()
def _initialize_hf_model(self):
"""Initialize the Hugging Face model.
:return: The processor and model.
:rtype: tuple
"""
try:
from transformers import ViTImageProcessor, ViTModel
except ImportError:
......@@ -68,6 +84,11 @@ class VitEncoder(DenseEncoder):
return processor, model
def _get_device(self) -> str:
"""Get the device to use for the model.
:return: The device to use for the model.
:rtype: str
"""
if self.device:
device = self.device
elif self._torch.cuda.is_available():
......@@ -79,12 +100,26 @@ class VitEncoder(DenseEncoder):
return device
def _process_images(self, images: List[Any]):
"""Process the images for the model.
:param images: The images to process.
:type images: List[Any]
:return: The processed images.
:rtype: Any
"""
rgb_images = [self._ensure_rgb(img) for img in images]
processed_images = self._processor(images=rgb_images, return_tensors="pt")
processed_images = processed_images.to(self.device)
return processed_images
def _ensure_rgb(self, img: Any):
"""Ensure the image is in RGB format.
:param img: The image to ensure is in RGB format.
:type img: Any
:return: The image in RGB format.
:rtype: Any
"""
rgbimg = self._Image.new("RGB", img.size)
rgbimg.paste(img)
return rgbimg
......@@ -94,6 +129,15 @@ class VitEncoder(DenseEncoder):
imgs: List[Any],
batch_size: int = 32,
) -> List[List[float]]:
"""Encode a list of images into embeddings using the Vision Transformer model.
:param imgs: The images to encode.
:type imgs: List[Any]
:param batch_size: The batch size for encoding.
:type batch_size: int
:return: The embeddings for the images.
:rtype: List[List[float]]
"""
all_embeddings = []
for i in range(0, len(imgs), batch_size):
batch_imgs = imgs[i : i + batch_size]
......
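A minimal image-embedding sketch for the ViT encoder above, assuming `VitEncoder` is exported from `semantic_router.encoders`, Pillow and the transformers dependencies are installed, and the image paths are placeholders:

```python
from PIL import Image

from semantic_router.encoders import VitEncoder

encoder = VitEncoder()  # defaults to google/vit-base-patch16-224
images = [Image.open("cat.jpg"), Image.open("dog.jpg")]
embeddings = encoder(images, batch_size=32)  # -> List[List[float]]
```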
......@@ -14,6 +14,11 @@ from semantic_router.utils.logger import logger
class AzureOpenAIEncoder(DenseEncoder):
"""Encoder for Azure OpenAI API.
This class provides functionality to encode text documents using the Azure OpenAI API.
It supports customization of the score threshold for filtering or processing the embeddings.
"""
client: Optional[openai.AzureOpenAI] = None
async_client: Optional[openai.AsyncAzureOpenAI] = None
dimensions: Union[int, NotGiven] = NotGiven()
......@@ -36,6 +41,25 @@ class AzureOpenAIEncoder(DenseEncoder):
dimensions: Union[int, NotGiven] = NotGiven(),
max_retries: int = 3,
):
"""Initialize the AzureOpenAIEncoder.
:param api_key: The API key for the Azure OpenAI API.
:type api_key: str
:param deployment_name: The name of the deployment to use.
:type deployment_name: str
:param azure_endpoint: The endpoint for the Azure OpenAI API.
:type azure_endpoint: str
:param api_version: The version of the API to use.
:type api_version: str
:param model: The model to use.
:type model: str
:param score_threshold: The score threshold for the embeddings.
:type score_threshold: float
:param dimensions: The dimensions of the embeddings.
:type dimensions: int
:param max_retries: The maximum number of retries for the API call.
:type max_retries: int
"""
name = deployment_name
if name is None:
name = EncoderDefault.AZURE.value["embedding_model"]
......@@ -98,6 +122,13 @@ class AzureOpenAIEncoder(DenseEncoder):
) from e
def __call__(self, docs: List[str]) -> List[List[float]]:
"""Encode a list of documents into embeddings using the Azure OpenAI API.
:param docs: The documents to encode.
:type docs: List[str]
:return: The embeddings for the documents.
:rtype: List[List[float]]
"""
if self.client is None:
raise ValueError("Azure OpenAI client is not initialized.")
embeds = None
......@@ -136,10 +167,16 @@ class AzureOpenAIEncoder(DenseEncoder):
return embeddings
async def acall(self, docs: List[str]) -> List[List[float]]:
"""Encode a list of documents into embeddings using the Azure OpenAI API asynchronously.
:param docs: The documents to encode.
:type docs: List[str]
:return: The embeddings for the documents.
:rtype: List[List[float]]
"""
if self.async_client is None:
raise ValueError("Azure OpenAI async client is not initialized.")
embeds = None
# Exponential backoff
for j in range(self.max_retries + 1):
try:
......
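A minimal sketch for the Azure OpenAI encoder above; the endpoint, deployment, model, and API version are placeholders for your own Azure resource:

```python
from semantic_router.encoders import AzureOpenAIEncoder

encoder = AzureOpenAIEncoder(
    api_key="<your-azure-openai-key>",
    deployment_name="<your-embedding-deployment>",
    azure_endpoint="https://<your-resource>.openai.azure.com/",
    api_version="2023-07-01-preview",  # illustrative version string
    model="<your-embedding-model>",
)
embeddings = encoder(["document1", "document2"])
```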
......@@ -7,12 +7,10 @@ from numpy.linalg import norm
def similarity_matrix(xq: np.ndarray, index: np.ndarray) -> np.ndarray:
"""Compute the similarity scores between a query vector and a set of vectors.
Args:
xq: A query vector (1d ndarray)
index: A set of vectors.
Returns:
The similarity between the query vector and the set of vectors.
:param xq: A query vector (1d ndarray)
:param index: A set of vectors.
:return: The similarity between the query vector and the set of vectors.
:rtype: np.ndarray
"""
index_norm = norm(index, axis=1)
......@@ -22,7 +20,13 @@ def similarity_matrix(xq: np.ndarray, index: np.ndarray) -> np.ndarray:
def top_scores(sim: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, np.ndarray]:
# get indices of top_k records
"""Get the top scores and indices from a similarity matrix.
:param sim: A similarity matrix.
:param top_k: The number of top scores to get.
:return: The top scores and indices.
:rtype: Tuple[np.ndarray, np.ndarray]
"""
top_k = min(top_k, sim.shape[0])
idx = np.argpartition(sim, -top_k)[-top_k:]
scores = sim[idx]
......
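A small numeric sketch of the two helpers above; the import path is assumed to be `semantic_router.linear`:

```python
import numpy as np

from semantic_router.linear import similarity_matrix, top_scores  # module path assumed

xq = np.array([1.0, 0.0])      # query vector
index = np.array([
    [1.0, 0.0],                # same direction -> similarity ~1.0
    [0.0, 1.0],                # orthogonal     -> similarity ~0.0
    [1.0, 1.0],                # 45 degrees     -> similarity ~0.707
])

sim = similarity_matrix(xq, index)      # cosine similarity against each row
scores, idx = top_scores(sim, top_k=2)  # the two best scores and their row indices
```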
......@@ -8,6 +8,11 @@ from semantic_router.utils.logger import logger
class BaseLLM(BaseModel):
"""Base class for LLMs typically used by dynamic routes.
This class provides a base implementation for LLMs. It defines the common
configuration and methods for all LLM classes.
"""
name: str
temperature: Optional[float] = 0.0
max_tokens: Optional[int] = None
......@@ -16,15 +21,39 @@ class BaseLLM(BaseModel):
arbitrary_types_allowed = True
def __init__(self, name: str, **kwargs):
"""Initialize the BaseLLM.
:param name: The name of the LLM.
:type name: str
:param **kwargs: Additional keyword arguments for the LLM.
:type **kwargs: dict
"""
super().__init__(name=name, **kwargs)
def __call__(self, messages: List[Message]) -> Optional[str]:
"""Call the LLM.
Must be implemented by subclasses.
:param messages: The messages to pass to the LLM.
:type messages: List[Message]
:return: The response from the LLM.
:rtype: Optional[str]
"""
raise NotImplementedError("Subclasses must implement this method")
def _check_for_mandatory_inputs(
self, inputs: dict[str, Any], mandatory_params: List[str]
) -> bool:
"""Check for mandatory parameters in inputs"""
"""Check for mandatory parameters in inputs.
:param inputs: The inputs to check for mandatory parameters.
:type inputs: dict[str, Any]
:param mandatory_params: The mandatory parameters to check for.
:type mandatory_params: List[str]
:return: True if all mandatory parameters are present, False otherwise.
:rtype: bool
"""
for name in mandatory_params:
if name not in inputs:
logger.error(f"Mandatory input {name} missing from query")
......@@ -34,7 +63,15 @@ class BaseLLM(BaseModel):
def _check_for_extra_inputs(
self, inputs: dict[str, Any], all_params: List[str]
) -> bool:
"""Check for extra parameters not defined in the signature"""
"""Check for extra parameters not defined in the signature.
:param inputs: The inputs to check for extra parameters.
:type inputs: dict[str, Any]
:param all_params: The complete list of parameters defined in the signature.
:type all_params: List[str]
:return: True if no extra parameters are present, False otherwise.
:rtype: bool
"""
input_keys = set(inputs.keys())
param_keys = set(all_params)
if not input_keys.issubset(param_keys):
......@@ -49,7 +86,15 @@ class BaseLLM(BaseModel):
self, inputs: List[Dict[str, Any]], function_schemas: List[Dict[str, Any]]
) -> bool:
"""Determine if the functions chosen by the LLM exist within the function_schemas,
and if the input arguments are valid for those functions."""
and if the input arguments are valid for those functions.
:param inputs: The inputs to check for validity.
:type inputs: List[Dict[str, Any]]
:param function_schemas: The function schemas to check against.
:type function_schemas: List[Dict[str, Any]]
:return: True if the inputs are valid, False otherwise.
:rtype: bool
"""
try:
# Currently only supporting single functions for most LLMs in Dynamic Routes.
if len(inputs) != 1:
......@@ -72,7 +117,15 @@ class BaseLLM(BaseModel):
def _validate_single_function_inputs(
self, inputs: Dict[str, Any], function_schema: Dict[str, Any]
) -> bool:
"""Validate the extracted inputs against the function schema"""
"""Validate the extracted inputs against the function schema.
:param inputs: The inputs to validate.
:type inputs: Dict[str, Any]
:param function_schema: The function schema to validate against.
:type function_schema: Dict[str, Any]
:return: True if the inputs are valid, False otherwise.
:rtype: bool
"""
try:
# Extract parameter names and determine if they are optional
signature = function_schema["signature"]
......@@ -107,7 +160,13 @@ class BaseLLM(BaseModel):
return False
def _extract_parameter_info(self, signature: str) -> tuple[List[str], List[str]]:
"""Extract parameter names and types from the function signature."""
"""Extract parameter names and types from the function signature.
:param signature: The function signature to extract parameter names and types from.
:type signature: str
:return: A tuple of parameter names and types.
:rtype: tuple[List[str], List[str]]
"""
param_info = [param.strip() for param in signature[1:-1].split(",")]
param_names = [info.split(":")[0].strip() for info in param_info]
param_types = [
......@@ -118,6 +177,15 @@ class BaseLLM(BaseModel):
def extract_function_inputs(
self, query: str, function_schemas: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Extract the function inputs from the query.
:param query: The query to extract the function inputs from.
:type query: str
:param function_schemas: The function schemas to extract the function inputs from.
:type function_schemas: List[Dict[str, Any]]
:return: The function inputs.
:rtype: List[Dict[str, Any]]
"""
logger.info("Extracting function input...")
prompt = f"""
......
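The base class above defines the contract that concrete LLM wrappers fill in. A toy subclass makes the required `__call__` override concrete; `EchoLLM` is purely illustrative and not part of the library:

```python
from typing import List, Optional

from semantic_router.llms import BaseLLM
from semantic_router.schema import Message


class EchoLLM(BaseLLM):
    """Illustrative subclass: a real implementation would call a model API here."""

    def __call__(self, messages: List[Message]) -> Optional[str]:
        # simply echo the content of the last message back to the caller
        return messages[-1].content


llm = EchoLLM(name="echo")
print(llm([Message(role="user", content="hello")]))  # -> "hello"
```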
......@@ -8,6 +8,11 @@ from semantic_router.schema import Message
class CohereLLM(BaseLLM):
"""LLM for Cohere. Requires a Cohere API key from https://dashboard.cohere.com/api-keys.
This class provides functionality to interact with the Cohere API for generating text responses.
It extends the BaseLLM class and implements the __call__ method to generate text responses.
"""
_client: Any = PrivateAttr()
def __init__(
......@@ -15,12 +20,27 @@ class CohereLLM(BaseLLM):
name: Optional[str] = None,
cohere_api_key: Optional[str] = None,
):
"""Initialize the CohereLLM.
:param name: The name of the Cohere model to use can also be set via the
COHERE_CHAT_MODEL_NAME environment variable.
:type name: Optional[str]
:param cohere_api_key: The API key for the Cohere client. Can also be set via the
COHERE_API_KEY environment variable.
:type cohere_api_key: Optional[str]
"""
if name is None:
name = os.getenv("COHERE_CHAT_MODEL_NAME", "command")
super().__init__(name=name)
self._client = self._initialize_client(cohere_api_key)
def _initialize_client(self, cohere_api_key: Optional[str] = None):
"""Initialize the Cohere client.
:param cohere_api_key: The API key for the Cohere client. Can also be set via the
COHERE_API_KEY environment variable.
:type cohere_api_key: Optional[str]
"""
try:
import cohere
except ImportError:
......@@ -41,6 +61,13 @@ class CohereLLM(BaseLLM):
return client
def __call__(self, messages: List[Message]) -> str:
"""Call the Cohere client.
:param messages: The messages to pass to the Cohere client.
:type messages: List[Message]
:return: The response from the Cohere client.
:rtype: str
"""
if self._client is None:
raise ValueError("Cohere client is not initialized.")
try:
......
......@@ -10,6 +10,9 @@ from semantic_router.utils.logger import logger
class LlamaCppLLM(BaseLLM):
"""LLM for LlamaCPP. Enables fully local LLM use, helpful for local implementation of
dynamic routes.
"""
llm: Any
grammar: Optional[Any] = None
_llama_cpp: Any = PrivateAttr()
......@@ -22,6 +25,19 @@ class LlamaCppLLM(BaseLLM):
max_tokens: Optional[int] = 200,
grammar: Optional[Any] = None,
):
"""Initialize the LlamaCPPLLM.
:param llm: The LLM to use.
:type llm: Any
:param name: The name of the LLM.
:type name: str
:param temperature: The temperature of the LLM.
:type temperature: float
:param max_tokens: The maximum number of tokens to generate.
:type max_tokens: Optional[int]
:param grammar: The grammar to use.
:type grammar: Optional[Any]
"""
super().__init__(
name=name,
llm=llm,
......@@ -48,6 +64,13 @@ class LlamaCppLLM(BaseLLM):
self,
messages: List[Message],
) -> str:
"""Call the LlamaCPPLLM.
:param messages: The messages to pass to the LlamaCPPLLM.
:type messages: List[Message]
:return: The response from the LlamaCPPLLM.
:rtype: str
"""
try:
completion = self.llm.create_chat_completion(
messages=[m.to_llamacpp() for m in messages],
......@@ -68,6 +91,11 @@ class LlamaCppLLM(BaseLLM):
@contextmanager
def _grammar(self):
"""Context manager for the grammar.
:return: The grammar.
:rtype: Any
"""
grammar_path = Path(__file__).parent.joinpath("grammars", "json.gbnf")
assert grammar_path.exists(), f"{grammar_path}\ndoes not exist"
try:
......@@ -79,6 +107,15 @@ class LlamaCppLLM(BaseLLM):
def extract_function_inputs(
self, query: str, function_schemas: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Extract the function inputs from the query.
:param query: The query to extract the function inputs from.
:type query: str
:param function_schemas: The function schemas to extract the function inputs from.
:type function_schemas: List[Dict[str, Any]]
:return: The function inputs.
:rtype: List[Dict[str, Any]]
"""
with self._grammar():
return super().extract_function_inputs(
query=query, function_schemas=function_schemas
......
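A minimal local-inference sketch for the LlamaCPP LLM above; the model path and name are placeholders, and the `llama-cpp-python` package plus a local GGUF model file are required:

```python
from llama_cpp import Llama

from semantic_router.llms import LlamaCppLLM
from semantic_router.schema import Message

_llm = Llama(model_path="./mistral-7b-instruct.Q4_K_M.gguf", n_ctx=2048, verbose=False)
llm = LlamaCppLLM(llm=_llm, name="mistral-7b-instruct")

response = llm([Message(role="user", content="Hello!")])
```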
......@@ -10,6 +10,8 @@ from semantic_router.utils.logger import logger
class MistralAILLM(BaseLLM):
"""LLM for MistralAI. Requires a MistralAI API key from https://console.mistral.ai/api-keys/
"""
_client: Any = PrivateAttr()
_mistralai: Any = PrivateAttr()
......@@ -20,6 +22,17 @@ class MistralAILLM(BaseLLM):
temperature: float = 0.01,
max_tokens: int = 200,
):
"""Initialize the MistralAILLM.
:param name: The name of the MistralAI model to use.
:type name: Optional[str]
:param mistralai_api_key: The MistralAI API key.
:type mistralai_api_key: Optional[str]
:param temperature: The temperature of the LLM.
:type temperature: float
:param max_tokens: The maximum number of tokens to generate.
:type max_tokens: int
"""
if name is None:
name = EncoderDefault.MISTRAL.value["language_model"]
super().__init__(name=name)
......@@ -28,6 +41,13 @@ class MistralAILLM(BaseLLM):
self.max_tokens = max_tokens
def _initialize_client(self, api_key):
"""Initialize the MistralAI client.
:param api_key: The MistralAI API key.
:type api_key: Optional[str]
:return: The MistralAI client.
:rtype: MistralClient
"""
try:
import mistralai
from mistralai.client import MistralClient
......@@ -49,9 +69,15 @@ class MistralAILLM(BaseLLM):
return client, mistralai
def __call__(self, messages: List[Message]) -> str:
"""Call the MistralAILLM.
:param messages: The messages to pass to the MistralAILLM.
:type messages: List[Message]
:return: The response from the MistralAILLM.
:rtype: str
"""
if self._client is None:
raise ValueError("MistralAI client is not initialized.")
chat_messages = [
self._mistralai.models.chat_completion.ChatMessage(
role=m.role, content=m.content
......
......@@ -8,6 +8,9 @@ from semantic_router.utils.logger import logger
class OllamaLLM(BaseLLM):
"""LLM for Ollama. Enables fully local LLM use, helpful for local implementation of
dynamic routes.
"""
stream: bool = False
def __init__(
......@@ -17,6 +20,17 @@ class OllamaLLM(BaseLLM):
max_tokens: Optional[int] = 200,
stream: bool = False,
):
"""Initialize the OllamaLLM.
:param name: The name of the Ollama model to use.
:type name: str
:param temperature: The temperature of the LLM.
:type temperature: float
:param max_tokens: The maximum number of tokens to generate.
:type max_tokens: Optional[int]
:param stream: Whether to stream the response.
:type stream: bool
"""
super().__init__(name=name)
self.temperature = temperature
self.max_tokens = max_tokens
......@@ -30,6 +44,19 @@ class OllamaLLM(BaseLLM):
max_tokens: Optional[int] = None,
stream: Optional[bool] = None,
) -> str:
"""Call the OllamaLLM.
:param messages: The messages to pass to the OllamaLLM.
:type messages: List[Message]
:param temperature: The temperature of the LLM.
:type temperature: Optional[float]
:param name: The name of the Ollama model to use.
:type name: Optional[str]
:param max_tokens: The maximum number of tokens to generate.
:type max_tokens: Optional[int]
:param stream: Whether to stream the response.
:type stream: Optional[bool]
"""
# Use instance defaults if not overridden
temperature = temperature if temperature is not None else self.temperature
name = name if name is not None else self.name
......
......@@ -22,6 +22,8 @@ from semantic_router.utils.logger import logger
class OpenAILLM(BaseLLM):
"""LLM for OpenAI. Requires an OpenAI API key from https://platform.openai.com/api-keys.
"""
_client: Optional[openai.OpenAI] = PrivateAttr(default=None)
_async_client: Optional[openai.AsyncOpenAI] = PrivateAttr(default=None)
......@@ -32,6 +34,17 @@ class OpenAILLM(BaseLLM):
temperature: float = 0.01,
max_tokens: int = 200,
):
"""Initialize the OpenAILLM.
:param name: The name of the OpenAI model to use.
:type name: Optional[str]
:param openai_api_key: The OpenAI API key.
:type openai_api_key: Optional[str]
:param temperature: The temperature of the LLM.
:type temperature: float
:param max_tokens: The maximum number of tokens to generate.
:type max_tokens: int
"""
if name is None:
name = EncoderDefault.OPENAI.value["language_model"]
super().__init__(name=name)
......@@ -51,6 +64,13 @@ class OpenAILLM(BaseLLM):
def _extract_tool_calls_info(
self, tool_calls: List[ChatCompletionMessageToolCall]
) -> List[Dict[str, Any]]:
"""Extract the tool calls information from the tool calls.
:param tool_calls: The tool calls to extract the information from.
:type tool_calls: List[ChatCompletionMessageToolCall]
:return: The tool calls information.
:rtype: List[Dict[str, Any]]
"""
tool_calls_info = []
for tool_call in tool_calls:
if tool_call.function.arguments is None:
......@@ -68,6 +88,13 @@ class OpenAILLM(BaseLLM):
async def async_extract_tool_calls_info(
self, tool_calls: List[ChatCompletionMessageToolCall]
) -> List[Dict[str, Any]]:
"""Extract the tool calls information from the tool calls.
:param tool_calls: The tool calls to extract the information from.
:type tool_calls: List[ChatCompletionMessageToolCall]
:return: The tool calls information.
:rtype: List[Dict[str, Any]]
"""
tool_calls_info = []
for tool_call in tool_calls:
if tool_call.function.arguments is None:
......@@ -87,6 +114,15 @@ class OpenAILLM(BaseLLM):
messages: List[Message],
function_schemas: Optional[List[Dict[str, Any]]] = None,
) -> str:
"""Call the OpenAILLM.
:param messages: The messages to pass to the OpenAILLM.
:type messages: List[Message]
:param function_schemas: The function schemas to pass to the OpenAILLM.
:type function_schemas: Optional[List[Dict[str, Any]]]
:return: The response from the OpenAILLM.
:rtype: str
"""
if self._client is None:
raise ValueError("OpenAI client is not initialized.")
try:
......@@ -131,6 +167,15 @@ class OpenAILLM(BaseLLM):
messages: List[Message],
function_schemas: Optional[List[Dict[str, Any]]] = None,
) -> str:
"""Call the OpenAILLM asynchronously.
:param messages: The messages to pass to the OpenAILLM.
:type messages: List[Message]
:param function_schemas: The function schemas to pass to the OpenAILLM.
:type function_schemas: Optional[List[Dict[str, Any]]]
:return: The response from the OpenAILLM.
:rtype: str
"""
if self._async_client is None:
raise ValueError("OpenAI async_client is not initialized.")
try:
......@@ -173,6 +218,15 @@ class OpenAILLM(BaseLLM):
def extract_function_inputs(
self, query: str, function_schemas: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Extract the function inputs from the query.
:param query: The query to extract the function inputs from.
:type query: str
:param function_schemas: The function schemas to extract the function inputs from.
:type function_schemas: List[Dict[str, Any]]
:return: The function inputs.
:rtype: List[Dict[str, Any]]
"""
system_prompt = "You are an intelligent AI. Given a command or request from the user, call the function to complete the request."
messages = [
Message(role="system", content=system_prompt),
......@@ -190,6 +244,15 @@ class OpenAILLM(BaseLLM):
async def async_extract_function_inputs(
self, query: str, function_schemas: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Extract the function inputs from the query asynchronously.
:param query: The query to extract the function inputs from.
:type query: str
:param function_schemas: The function schemas to extract the function inputs from.
:type function_schemas: List[Dict[str, Any]]
:return: The function inputs.
:rtype: List[Dict[str, Any]]
"""
system_prompt = "You are an intelligent AI. Given a command or request from the user, call the function to complete the request."
messages = [
Message(role="system", content=system_prompt),
......@@ -208,7 +271,15 @@ class OpenAILLM(BaseLLM):
self, inputs: List[Dict[str, Any]], function_schemas: List[Dict[str, Any]]
) -> bool:
"""Determine if the functions chosen by the LLM exist within the function_schemas,
and if the input arguments are valid for those functions."""
and if the input arguments are valid for those functions.
:param inputs: The inputs to check for validity.
:type inputs: List[Dict[str, Any]]
:param function_schemas: The function schemas to check against.
:type function_schemas: List[Dict[str, Any]]
:return: True if the inputs are valid, False otherwise.
:rtype: bool
"""
try:
for input_dict in inputs:
# Check if 'function_name' and 'arguments' keys exist in each input dictionary
......@@ -251,7 +322,14 @@ class OpenAILLM(BaseLLM):
def _validate_single_function_inputs(
self, inputs: Dict[str, Any], function_schema: Dict[str, Any]
) -> bool:
"""Validate the extracted inputs against the function schema"""
"""Validate the extracted inputs against the function schema.
:param inputs: The inputs to validate.
:type inputs: Dict[str, Any]
:param function_schema: The function schema to validate against.
:type function_schema: Dict[str, Any]
:return: True if the inputs are valid, False otherwise.
"""
try:
# Access the parameters and their properties from the function schema directly
parameters = function_schema["parameters"]["properties"]
......@@ -283,6 +361,13 @@ class OpenAILLM(BaseLLM):
def get_schemas_openai(items: List[Callable]) -> List[Dict[str, Any]]:
"""Get function schemas for the OpenAI LLM from a list of functions.
:param items: The functions to get function schemas for.
:type items: List[Callable]
:return: The schemas for the OpenAI LLM.
:rtype: List[Dict[str, Any]]
"""
schemas = []
for item in items:
if not callable(item):
......
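A short sketch of `get_schemas_openai` documented above, turning a plain Python function into an OpenAI function-calling schema; the import path is assumed to be `semantic_router.llms.openai` and the example function is hypothetical:

```python
from datetime import datetime
from zoneinfo import ZoneInfo

from semantic_router.llms.openai import get_schemas_openai  # import path assumed


def get_time(timezone: str) -> str:
    """Return the current time in the given IANA timezone, e.g. 'Europe/London'."""
    return datetime.now(ZoneInfo(timezone)).strftime("%H:%M")


schemas = get_schemas_openai([get_time])
# each entry describes the function's name, description, and typed parameters in
# the tool format that OpenAILLM accepts via its function_schemas argument
```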
......@@ -10,6 +10,8 @@ from semantic_router.utils.logger import logger
class OpenRouterLLM(BaseLLM):
"""LLM for OpenRouter. Requires an OpenRouter API key, see here for more information
https://openrouter.ai/docs/api-reference/authentication#using-an-api-key"""
_client: Optional[openai.OpenAI] = PrivateAttr(default=None)
_base_url: str = PrivateAttr(default="https://openrouter.ai/api/v1")
......@@ -21,6 +23,19 @@ class OpenRouterLLM(BaseLLM):
temperature: float = 0.01,
max_tokens: int = 200,
):
"""Initialize the OpenRouterLLM.
:param name: The name of the OpenRouter model to use.
:type name: Optional[str]
:param openrouter_api_key: The OpenRouter API key.
:type openrouter_api_key: Optional[str]
:param base_url: The base URL for the OpenRouter API.
:type base_url: str
:param temperature: The temperature of the LLM.
:type temperature: float
:param max_tokens: The maximum number of tokens to generate.
:type max_tokens: int
"""
if name is None:
name = os.getenv(
"OPENROUTER_CHAT_MODEL_NAME", "mistralai/mistral-7b-instruct"
......@@ -40,6 +55,13 @@ class OpenRouterLLM(BaseLLM):
self.max_tokens = max_tokens
def __call__(self, messages: List[Message]) -> str:
"""Call the OpenRouterLLM.
:param messages: The messages to pass to the OpenRouterLLM.
:type messages: List[Message]
:return: The response from the OpenRouterLLM.
:rtype: str
"""
if self._client is None:
raise ValueError("OpenRouter client is not initialized.")
try:
......
......@@ -11,6 +11,8 @@ from semantic_router.utils.logger import logger
class AzureOpenAILLM(BaseLLM):
"""LLM for Azure OpenAI. Requires an Azure OpenAI API key.
"""
_client: Optional[openai.AzureOpenAI] = PrivateAttr(default=None)
def __init__(
......@@ -22,6 +24,21 @@ class AzureOpenAILLM(BaseLLM):
max_tokens: int = 200,
api_version="2023-07-01-preview",
):
"""Initialize the AzureOpenAILLM.
:param name: The name of the Azure OpenAI model to use.
:type name: Optional[str]
:param openai_api_key: The Azure OpenAI API key.
:type openai_api_key: Optional[str]
:param azure_endpoint: The Azure OpenAI endpoint.
:type azure_endpoint: Optional[str]
:param temperature: The temperature of the LLM.
:type temperature: float
:param max_tokens: The maximum number of tokens to generate.
:type max_tokens: int
:param api_version: The API version to use.
:type api_version: str
"""
if name is None:
name = EncoderDefault.AZURE.value["language_model"]
super().__init__(name=name)
......@@ -41,6 +58,13 @@ class AzureOpenAILLM(BaseLLM):
self.max_tokens = max_tokens
def __call__(self, messages: List[Message]) -> str:
"""Call the AzureOpenAILLM.
:param messages: The messages to pass to the AzureOpenAILLM.
:type messages: List[Message]
:return: The response from the AzureOpenAILLM.
:rtype: str
"""
if self._client is None:
raise ValueError("AzureOpenAI client is not initialized.")
try:
......
......@@ -11,6 +11,13 @@ from semantic_router.utils.logger import logger
def is_valid(route_config: str) -> bool:
"""Check if the route config is valid.
:param route_config: The route config to check.
:type route_config: str
:return: Whether the route config is valid.
:rtype: bool
"""
try:
output_json = json.loads(route_config)
required_keys = ["name", "utterances"]
......@@ -39,18 +46,36 @@ def is_valid(route_config: str) -> bool:
class Route(BaseModel):
name: str
utterances: Union[List[str], List[Any]]
description: Optional[str] = None
function_schemas: Optional[List[Dict[str, Any]]] = None
llm: Optional[BaseLLM] = None
score_threshold: Optional[float] = None
metadata: Optional[Dict[str, Any]] = {}
"""A route for the semantic router.
:param name: The name of the route.
:type name: str
:param utterances: The utterances of the route.
:type utterances: Union[List[str], List[Any]]
:param description: The description of the route.
:type description: Optional[str]
:param function_schemas: The function schemas of the route.
:type function_schemas: Optional[List[Dict[str, Any]]]
:param llm: The LLM to use.
:type llm: Optional[BaseLLM]
:param score_threshold: The score threshold of the route.
:type score_threshold: Optional[float]
:param metadata: The metadata of the route.
:type metadata: Optional[Dict[str, Any]]
"""
class Config:
arbitrary_types_allowed = True
def __call__(self, query: Optional[str] = None) -> RouteChoice:
"""Call the route. If dynamic routes have been provided the query must have been
provided and the llm attribute must be set.
:param query: The query to pass to the route.
:type query: Optional[str]
:return: The route choice.
:rtype: RouteChoice
"""
if self.function_schemas:
if not self.llm:
raise ValueError(
......@@ -73,6 +98,14 @@ class Route(BaseModel):
return RouteChoice(name=self.name, function_call=func_call)
async def acall(self, query: Optional[str] = None) -> RouteChoice:
"""Asynchronous call the route. If dynamic routes have been provided the query
must have been provided and the llm attribute must be set.
:param query: The query to pass to the route.
:type query: Optional[str]
:return: The route choice.
:rtype: RouteChoice
"""
if self.function_schemas:
if not self.llm:
raise ValueError(
......@@ -95,6 +128,11 @@ class Route(BaseModel):
return RouteChoice(name=self.name, function_call=func_call)
def to_dict(self) -> Dict[str, Any]:
"""Convert the route to a dictionary.
:return: The dictionary representation of the route.
:rtype: Dict[str, Any]
"""
data = self.dict()
if self.llm is not None:
data["llm"] = {
......@@ -106,14 +144,27 @@ class Route(BaseModel):
@classmethod
def from_dict(cls, data: Dict[str, Any]):
"""Create a Route object from a dictionary.
:param data: The dictionary to create the route from.
:type data: Dict[str, Any]
:return: The created route.
:rtype: Route
"""
return cls(**data)
@classmethod
def from_dynamic_route(
cls, llm: BaseLLM, entities: List[Union[BaseModel, Callable]], route_name: str
):
"""
Generate a dynamic Route object from a list of functions or Pydantic models using LLM
"""Generate a dynamic Route object from a list of functions or Pydantic models
using an LLM.
:param llm: The LLM to use.
:type llm: BaseLLM
:param entities: The entities to use.
:type entities: List[Union[BaseModel, Callable]]
:param route_name: The name of the route.
"""
schemas = function_call.get_schema_list(items=entities)
dynamic_route = cls._generate_dynamic_route(
......@@ -124,6 +175,14 @@ class Route(BaseModel):
@classmethod
def _parse_route_config(cls, config: str) -> str:
"""Parse the route config from the LLM output using regex. Expects the output
content to be wrapped in <config></config> tags.
:param config: The LLM output.
:type config: str
:return: The parsed route config.
:rtype: str
"""
# Regular expression to match content inside <config></config>
config_pattern = r"<config>(.*?)</config>"
match = re.search(config_pattern, config, re.DOTALL)
......@@ -138,8 +197,14 @@ class Route(BaseModel):
def _generate_dynamic_route(
cls, llm: BaseLLM, function_schemas: List[Dict[str, Any]], route_name: str
):
logger.info("Generating dynamic route...")
"""Generate a dynamic Route object from a list of function schemas using an LLM.
:param llm: The LLM to use.
:type llm: BaseLLM
:param function_schemas: The function schemas to use.
:type function_schemas: List[Dict[str, Any]]
:param route_name: The name of the route.
"""
formatted_schemas = "\n".join(
[json.dumps(schema, indent=4) for schema in function_schemas]
)
......@@ -176,8 +241,6 @@ class Route(BaseModel):
route_config = cls._parse_route_config(config=output)
logger.info(f"Generated route config:\n{route_config}")
if is_valid(route_config):
route_config_dict = json.loads(route_config)
route_config_dict["llm"] = llm
......
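For context, a minimal static route built with the class documented above; the utterances are illustrative. Dynamic routes additionally set `function_schemas` and an `llm` so that `__call__` can extract function inputs:

```python
from semantic_router import Route

chitchat = Route(
    name="chitchat",
    utterances=[
        "how's the weather today?",
        "how are things going?",
        "lovely day isn't it",
    ],
)

route_dict = chitchat.to_dict()           # serialisable representation
same_route = Route.from_dict(route_dict)  # round-trip back to a Route object
```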
......@@ -12,6 +12,9 @@ from aurelio_sdk.schema import SparseEmbedding as BM25SparseEmbedding
class EncoderType(Enum):
"""The type of encoder.
"""
AURELIO = "aurelio"
AZURE = "azure"
COHERE = "cohere"
......@@ -28,40 +31,59 @@ class EncoderType(Enum):
class EncoderInfo(BaseModel):
"""Information about an encoder.
"""
name: str
token_limit: int
threshold: Optional[float] = None
class RouteChoice(BaseModel):
"""A route choice typically output by the routers.
"""
name: Optional[str] = None
function_call: Optional[List[Dict]] = None
similarity_score: Optional[float] = None
class Message(BaseModel):
"""A message in a conversation, includes the role and content fields.
"""
role: str
content: str
def to_openai(self):
if self.role.lower() not in ["user", "assistant", "system"]:
raise ValueError("Role must be either 'user', 'assistant' or 'system'")
"""Convert the message to an OpenAI-compatible format.
"""
if self.role.lower() not in ["user", "assistant", "system", "tool"]:
raise ValueError("Role must be either 'user', 'assistant', 'system' or 'tool'")
return {"role": self.role, "content": self.content}
def to_cohere(self):
"""Convert the message to a Cohere-compatible format.
"""
return {"role": self.role, "message": self.content}
def to_llamacpp(self):
"""Convert the message to a LlamaCPP-compatible format.
"""
return {"role": self.role, "content": self.content}
def to_mistral(self):
"""Convert the message to a Mistral-compatible format.
"""
return {"role": self.role, "content": self.content}
def __str__(self):
"""Convert the message to a string.
"""
return f"{self.role}: {self.content}"
class ConfigParameter(BaseModel):
"""A configuration parameter for a route. Used for remote router metadata such as
router hashes, sync locks, etc.
"""
field: str
value: str
scope: Optional[str] = None
......@@ -70,6 +92,15 @@ class ConfigParameter(BaseModel):
)
def to_pinecone(self, dimensions: int):
"""Convert the configuration parameter to a Pinecone-compatible format. Should
be used when upserting configuration parameters to a separate config namespace
within your Pinecone index.
:param dimensions: The dimensions of the Pinecone index.
:type dimensions: int
:return: A Pinecone-compatible configuration parameter.
:rtype: dict
"""
namespace = self.scope or ""
return {
"id": f"{self.field}#{namespace}",
......@@ -84,6 +115,9 @@ class ConfigParameter(BaseModel):
class Utterance(BaseModel):
"""An utterance in a conversation, includes the route, utterance, function
schemas, metadata, and diff tag.
"""
route: str
utterance: Union[str, Any]
function_schemas: Optional[List[Dict]] = None
......@@ -129,6 +163,14 @@ class Utterance(BaseModel):
)
def to_str(self, include_metadata: bool = False):
"""Convert an Utterance object to a string. Used for comparisons during sync
check operations.
:param include_metadata: Whether to include metadata in the string.
:type include_metadata: bool
:return: A string representation of the Utterance object.
:rtype: str
"""
if include_metadata:
# we sort the dicts to ensure consistent order as we need this to compare
# stringified function schemas accurately
......@@ -149,8 +191,7 @@ class Utterance(BaseModel):
class SyncMode(Enum):
"""Synchronization modes for local (route layer) and remote (index)
instances.
"""Synchronization modes for local (route layer) and remote (index) instances.
"""
ERROR = "error"
......@@ -165,12 +206,22 @@ SYNC_MODES = [x.value for x in SyncMode]
class UtteranceDiff(BaseModel):
"""A list of Utterance objects that represent the differences between local and
remote utterances.
"""
diff: List[Utterance]
@classmethod
def from_utterances(
cls, local_utterances: List[Utterance], remote_utterances: List[Utterance]
):
"""Create a UtteranceDiff object from two lists of Utterance objects.
:param local_utterances: A list of Utterance objects.
:type local_utterances: List[Utterance]
:param remote_utterances: A list of Utterance objects.
:type remote_utterances: List[Utterance]
"""
local_utterances_map = {
x.to_str(include_metadata=True): x for x in local_utterances
}
......@@ -222,14 +273,18 @@ class UtteranceDiff(BaseModel):
This diff tells us that the remote has "route2: utterance3" and
"route2: utterance4", which do not exist locally.
:param include_metadata: Whether to include metadata in the string.
:type include_metadata: bool
:return: A list of diff strings.
:rtype: List[str]
"""
return [x.to_diff_str(include_metadata=include_metadata) for x in self.diff]
def get_tag(self, diff_tag: str) -> List[Utterance]:
"""Get all utterances with a given diff tag.
:param diff_tag: The diff tag to filter by. Must be one of "+", "-", or
" ".
:param diff_tag: The diff tag to filter by. Must be one of "+", "-", or " ".
:type diff_tag: str
:return: A list of Utterance objects.
:rtype: List[Utterance]
......@@ -239,8 +294,7 @@ class UtteranceDiff(BaseModel):
return [x for x in self.diff if x.diff_tag == diff_tag]
def get_sync_strategy(self, sync_mode: str) -> dict:
"""Generates the optimal synchronization plan for local and remote
instances.
"""Generates the optimal synchronization plan for local and remote instances.
:param sync_mode: The mode to sync the routes with the remote index.
:type sync_mode: str
......@@ -417,6 +471,8 @@ class UtteranceDiff(BaseModel):
class Metric(Enum):
"""The metric to use in vector-based similarity search indexes.
"""
COSINE = "cosine"
DOTPRODUCT = "dotproduct"
EUCLIDEAN = "euclidean"
......@@ -435,6 +491,13 @@ class SparseEmbedding(BaseModel):
@classmethod
def from_compact_array(cls, array: np.ndarray):
"""Create a SparseEmbedding object from a compact array.
:param array: A compact array.
:type array: np.ndarray
:return: A SparseEmbedding object.
:rtype: SparseEmbedding
"""
if array.ndim != 2 or array.shape[1] != 2:
raise ValueError(
f"Expected a 2D array with 2 columns, got a {array.ndim}D array with {array.shape[1]} columns. "
......@@ -444,32 +507,69 @@ class SparseEmbedding(BaseModel):
@classmethod
def from_vector(cls, vector: np.ndarray):
"""Consumes an array of sparse vectors containing zero-values."""
"""Consumes an array of sparse vectors containing zero-values.
:param vector: A sparse vector.
:type vector: np.ndarray
:return: A SparseEmbedding object.
:rtype: SparseEmbedding
"""
if vector.ndim != 1:
raise ValueError(f"Expected a 1D array, got a {vector.ndim}D array.")
return cls.from_compact_array(np.array([np.arange(len(vector)), vector]).T)
@classmethod
def from_aurelio(cls, embedding: BM25SparseEmbedding):
"""Create a SparseEmbedding object from an AurelioSparseEmbedding object.
:param embedding: An AurelioSparseEmbedding object.
:type embedding: BM25SparseEmbedding
:return: A SparseEmbedding object.
:rtype: SparseEmbedding
"""
arr = np.array([embedding.indices, embedding.values]).T
return cls.from_compact_array(arr)
@classmethod
def from_dict(cls, sparse_dict: dict):
"""Create a SparseEmbedding object from a dictionary.
:param sparse_dict: A dictionary of sparse values.
:type sparse_dict: dict
:return: A SparseEmbedding object.
:rtype: SparseEmbedding
"""
arr = np.array([list(sparse_dict.keys()), list(sparse_dict.values())]).T
return cls.from_compact_array(arr)
@classmethod
def from_pinecone_dict(cls, sparse_dict: dict):
"""Create a SparseEmbedding object from a Pinecone dictionary.
:param sparse_dict: A Pinecone dictionary.
:type sparse_dict: dict
:return: A SparseEmbedding object.
:rtype: SparseEmbedding
"""
arr = np.array([sparse_dict["indices"], sparse_dict["values"]]).T
return cls.from_compact_array(arr)
def to_dict(self):
"""Convert a SparseEmbedding object to a dictionary.
:return: A dictionary of sparse values.
:rtype: dict
"""
return {
i: v for i, v in zip(self.embedding[:, 0].astype(int), self.embedding[:, 1])
}
def to_pinecone(self):
"""Convert a SparseEmbedding object to a Pinecone dictionary.
:return: A Pinecone dictionary.
:rtype: dict
"""
return {
"indices": self.embedding[:, 0].astype(int).tolist(),
"values": self.embedding[:, 1].tolist(),
......@@ -477,6 +577,11 @@ class SparseEmbedding(BaseModel):
# dictionary interface
def items(self):
"""Return a list of (index, value) tuples from the SparseEmbedding object.
:return: A list of (index, value) tuples.
:rtype: list
"""
return [
(i, v)
for i, v in zip(self.embedding[:, 0].astype(int), self.embedding[:, 1])
......
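A small sketch of the `SparseEmbedding` constructors and converters documented above:

```python
import numpy as np

from semantic_router.schema import SparseEmbedding

# build from a dense 1D vector that contains zero values
se = SparseEmbedding.from_vector(np.array([0.0, 0.3, 0.0, 0.7]))

as_dict = se.to_dict()          # {index: value, ...}
as_pinecone = se.to_pinecone()  # {"indices": [...], "values": [...]}

# round-trip back from the dictionary form
se2 = SparseEmbedding.from_dict(as_dict)
```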
......@@ -3,6 +3,9 @@ from enum import Enum
class EncoderDefault(Enum):
"""Default model names for each encoder type.
"""
FASTEMBED = {
"embedding_model": "BAAI/bge-small-en-v1.5",
"language_model": "BAAI/bge-small-en-v1.5",
......