Skip to content
Snippets Groups Projects
schema.py 20.8 KiB
Newer Older
  • Learn to ignore specific revisions
  • James Briggs's avatar
    James Briggs committed
    import json
    from datetime import datetime, timezone
    from difflib import Differ
    from enum import Enum
    from typing import Any, Dict, List, Optional, Tuple, Union

    import numpy as np
    from aurelio_sdk.schema import SparseEmbedding as BM25SparseEmbedding
    from pydantic import BaseModel, Field

    from semantic_router.utils.logger import logger
    
    James Briggs's avatar
    James Briggs committed
    class EncoderType(Enum):
        """The type of encoder.

        Each member's value is the lowercase string identifier of the encoder.
        """

        AURELIO = "aurelio"
        AZURE = "azure"
        COHERE = "cohere"
        OPENAI = "openai"
        BM25 = "bm25"
        TFIDF = "tfidf"
        FASTEMBED = "fastembed"
        HUGGINGFACE = "huggingface"
        MISTRAL = "mistral"
        VIT = "vit"
        CLIP = "clip"
        BEDROCK = "bedrock"
    
    James Briggs's avatar
    James Briggs committed
    
    
    Simonas's avatar
    Simonas committed
    
    
    class EncoderInfo(BaseModel):
        """Information about an encoder.

        ``name`` and ``token_limit`` are required; ``threshold`` is optional and
        defaults to ``None``.
        """

        name: str
        token_limit: int
        # NOTE(review): presumably a similarity score threshold - confirm with callers
        threshold: Optional[float] = None
    
    class RouteChoice(BaseModel):
        """A route choice typically output by the routers.

        Every field is optional and defaults to ``None``.
        """

        name: Optional[str] = None
        function_call: Optional[List[Dict]] = None
        similarity_score: Optional[float] = None
    
    Ismail Ashraq's avatar
    Ismail Ashraq committed
    class Message(BaseModel):
        """A message in a conversation, includes the role and content fields.

        The ``to_*`` helpers each return a plain dict holding the role and the
        content under the key names used by the respective provider.
        """

        role: str
        content: str

        def to_openai(self) -> Dict[str, str]:
            """Convert the message to an OpenAI-compatible format.

            :return: A dict with ``role`` and ``content`` keys.
            :rtype: dict
            :raises ValueError: If the lower-cased role is not one of ``user``,
                ``assistant``, ``system`` or ``tool``.
            """
            # the role check is case-insensitive, but the role is emitted as stored
            if self.role.lower() not in ["user", "assistant", "system", "tool"]:
                raise ValueError(
                    "Role must be either 'user', 'assistant', 'system' or 'tool'"
                )
            return {"role": self.role, "content": self.content}

        def to_cohere(self) -> Dict[str, str]:
            """Convert the message to a Cohere-compatible format.

            :return: A dict with ``role`` and ``message`` keys.
            :rtype: dict
            """
            return {"role": self.role, "message": self.content}

        def to_llamacpp(self) -> Dict[str, str]:
            """Convert the message to a LlamaCPP-compatible format.

            :return: A dict with ``role`` and ``content`` keys.
            :rtype: dict
            """
            return {"role": self.role, "content": self.content}

        def to_mistral(self) -> Dict[str, str]:
            """Convert the message to a Mistral-compatible format.

            :return: A dict with ``role`` and ``content`` keys.
            :rtype: dict
            """
            return {"role": self.role, "content": self.content}

        def __str__(self) -> str:
            """Convert the message to a string of the form ``role: content``."""
            return f"{self.role}: {self.content}"
    
    Ismail Ashraq's avatar
    Ismail Ashraq committed
    
    
    James Briggs's avatar
    James Briggs committed
    class ConfigParameter(BaseModel):
        """A configuration parameter for a route. Used for remote router metadata such as
        router hashes, sync locks, etc.
        """

        field: str
        value: str
        scope: Optional[str] = None
        # ISO 8601 timestamp in UTC, generated when the instance is created
        created_at: str = Field(
            default_factory=lambda: datetime.now(timezone.utc).isoformat()
        )

        def to_pinecone(self, dimensions: int):
            """Convert the configuration parameter to a Pinecone-compatible format. Should
            be used when upserting configuration parameters to a separate config namespace
            within your Pinecone index.

            :param dimensions: The dimensions of the Pinecone index.
            :type dimensions: int
            :return: A Pinecone-compatible configuration parameter.
            :rtype: dict
            """
            # a missing scope is represented by an empty namespace string
            namespace = self.scope or ""
            return {
                "id": f"{self.field}#{namespace}",
                # constant filler vector of the requested length; the parameter
                # itself is carried entirely in the metadata
                "values": [0.1] * dimensions,
                "metadata": {
                    "value": self.value,
                    "created_at": self.created_at,
                    "namespace": namespace,
                    "field": self.field,
                },
            }
    
    James Briggs's avatar
    James Briggs committed
    
    
    class Utterance(BaseModel):
        """An utterance in a conversation, includes the route, utterance, function
        schemas, metadata, and diff tag.
        """

        route: str
        utterance: Union[str, Any]
        function_schemas: Optional[List[Dict]] = None
        metadata: dict = {}
        # single-character diff marker, " " unless changed after construction
        diff_tag: str = " "

        @classmethod
        def from_tuple(cls, tuple_obj: Tuple):
            """Create an Utterance object from a tuple. The tuple must contain
            route and utterance as the first two elements. Then optionally
            function schemas and metadata as the third and fourth elements
            respectively. If this order is not followed an invalid Utterance
            object will be returned.

            :param tuple_obj: A tuple containing route, utterance, function schemas and metadata.
            :type tuple_obj: Tuple
            :return: An Utterance object.
            :rtype: Utterance
            """
            route, utterance = tuple_obj[0], tuple_obj[1]
            function_schemas = tuple_obj[2] if len(tuple_obj) > 2 else None
            # a single schema given as a bare dict is wrapped into a list
            if isinstance(function_schemas, dict):
                function_schemas = [function_schemas]
            metadata = tuple_obj[3] if len(tuple_obj) > 3 else {}
            return cls(
                route=route,
                utterance=utterance,
                function_schemas=function_schemas,
                metadata=metadata,
            )

        def to_tuple(self):
            """Convert an Utterance object to a tuple.

            :return: A tuple containing (route, utterance, function schemas, metadata).
            :rtype: Tuple
            """
            return (
                self.route,
                self.utterance,
                self.function_schemas,
                self.metadata,
            )

        def to_str(self, include_metadata: bool = False):
            """Convert an Utterance object to a string. Used for comparisons during sync
            check operations.

            :param include_metadata: Whether to include metadata in the string.
            :type include_metadata: bool
            :return: A string representation of the Utterance object.
            :rtype: str
            """
            if not include_metadata:
                return f"{self.route}: {self.utterance}"

            # we sort the dicts to ensure consistent order as we need this to compare
            # stringified function schemas accurately
            function_schemas_sorted: Optional[List[str]] = None
            if self.function_schemas is not None:
                function_schemas_sorted = [
                    json.dumps(schema, sort_keys=True)
                    for schema in self.function_schemas
                ]
            # we must do the same for metadata
            metadata_sorted = json.dumps(self.metadata, sort_keys=True)
            return f"{self.route}: {self.utterance} | {function_schemas_sorted} | {metadata_sorted}"

        def to_diff_str(self, include_metadata: bool = False):
            """Convert an Utterance object to a diff string: the diff tag, a space,
            then the output of ``to_str``.

            :param include_metadata: Whether to include metadata in the string.
            :type include_metadata: bool
            :return: The diff-tagged string representation.
            :rtype: str
            """
            return f"{self.diff_tag} {self.to_str(include_metadata=include_metadata)}"
    
    class SyncMode(Enum):
        """Synchronization modes for local (route layer) and remote (index) instances."""

        ERROR = "error"
        REMOTE = "remote"
        LOCAL = "local"
        MERGE_FORCE_REMOTE = "merge-force-remote"
        MERGE_FORCE_LOCAL = "merge-force-local"
        MERGE = "merge"


    # string values of every SyncMode member, in declaration order
    SYNC_MODES = [x.value for x in SyncMode]
    
    
    
    class UtteranceDiff(BaseModel):
        """A list of Utterance objects that represent the differences between local and
        remote utterances.
        """

        # utterances whose diff_tag marks how each one differs between the two sides
        diff: List[Utterance]
    
        @classmethod
        def from_utterances(
            cls, local_utterances: List[Utterance], remote_utterances: List[Utterance]
        ):
            """Create a UtteranceDiff object from two lists of Utterance objects.

            Utterances are compared through their metadata-inclusive string form.
            Note that the ``diff_tag`` of the utterances passed in is overwritten
            in place with the tag computed here.

            :param local_utterances: A list of Utterance objects.
            :type local_utterances: List[Utterance]
            :param remote_utterances: A list of Utterance objects.
            :type remote_utterances: List[Utterance]
            :return: A UtteranceDiff holding the tagged utterances.
            :rtype: UtteranceDiff
            """
            # key each utterance by its string form so diff lines map back to objects
            local_utterances_map = {
                x.to_str(include_metadata=True): x for x in local_utterances
            }
            remote_utterances_map = {
                x.to_str(include_metadata=True): x for x in remote_utterances
            }
            # sort local and remote utterances
            local_utterances_str = sorted(local_utterances_map)
            remote_utterances_str = sorted(remote_utterances_map)
            # get diff
            differ = Differ()
            diff_obj = list(differ.compare(local_utterances_str, remote_utterances_str))
            # create UtteranceDiff list
            utterance_diffs = []
            for line in diff_obj:
                # each diff line is a one-character tag, a space, then the string
                utterance_str = line[2:]
                utterance_diff_tag = line[0]
                if utterance_diff_tag == "?":
                    # this is a new line from diff string, we can ignore
                    continue
                # "+" lines exist only remotely; "-" and " " lines resolve locally
                utterance = (
                    remote_utterances_map[utterance_str]
                    if utterance_diff_tag == "+"
                    else local_utterances_map[utterance_str]
                )
                utterance.diff_tag = utterance_diff_tag
                utterance_diffs.append(utterance)
            return UtteranceDiff(diff=utterance_diffs)
    
    
        def to_utterance_str(self, include_metadata: bool = False) -> List[str]:
            """Outputs the utterance diff as a list of diff strings. Returns a list
            of strings showing what is different in the remote when compared to the
            local. For example:

            ["  route1: utterance1",
             "  route1: utterance2",
             "- route2: utterance3",
             "- route2: utterance4"]

            Tells us that the remote is missing "route2: utterance3" and "route2:
            utterance4", which do exist locally. If we see:

            ["  route1: utterance1",
             "  route1: utterance2",
             "+ route2: utterance3",
             "+ route2: utterance4"]

            This diff tells us that the remote has "route2: utterance3" and
            "route2: utterance4", which do not exist locally.

            :param include_metadata: Whether to include metadata in the string.
            :type include_metadata: bool
            :return: A list of diff strings.
            :rtype: List[str]
            """
            return [x.to_diff_str(include_metadata=include_metadata) for x in self.diff]
    
    
        def get_tag(self, diff_tag: str) -> List[Utterance]:
            """Return every utterance in the diff carrying the requested tag.

            :param diff_tag: The tag to select, one of "+", "-", or " ".
            :type diff_tag: str
            :return: The matching Utterance objects, in diff order.
            :rtype: List[Utterance]
            :raises ValueError: If ``diff_tag`` is not a recognised tag.
            """
            allowed_tags = ("+", "-", " ")
            if diff_tag in allowed_tags:
                matching = []
                for utt in self.diff:
                    if utt.diff_tag == diff_tag:
                        matching.append(utt)
                return matching
            raise ValueError("diff_tag must be one of '+', '-', or ' '")
    
        def get_sync_strategy(self, sync_mode: str) -> dict:
            """Generates the optimal synchronization plan for local and remote instances.

            The returned plan always has the shape::

                {
                    "remote": {"upsert": [...], "delete": [...]},
                    "local": {"upsert": [...], "delete": [...]},
                }

            where each list holds Utterance objects.

            :param sync_mode: The mode to sync the routes with the remote index.
            :type sync_mode: str
            :return: A dictionary describing the synchronization strategy.
            :rtype: dict
            :raises ValueError: If ``sync_mode`` is not one of ``SYNC_MODES``, or if
                it is "error" and the local and remote utterances differ.
            """
            if sync_mode not in SYNC_MODES:
                raise ValueError(f"sync_mode must be one of {SYNC_MODES}")
            local_only = self.get_tag("-")
            # route name -> (function_schemas, metadata); when a route has several
            # utterances the last one seen provides the pair
            local_only_mapper = {
                utt.route: (utt.function_schemas, utt.metadata) for utt in local_only
            }
            remote_only = self.get_tag("+")
            remote_only_mapper = {
                utt.route: (utt.function_schemas, utt.metadata) for utt in remote_only
            }
            local_and_remote = self.get_tag(" ")
            if sync_mode == "error":
                if local_only or remote_only:
                    raise ValueError(
                        "There are utterances that exist in the local or remote "
                        "instance that do not exist in the other instance. Please "
                        "sync the routes before running this command."
                    )
                return {
                    "remote": {"upsert": [], "delete": []},
                    "local": {"upsert": [], "delete": []},
                }
            elif sync_mode == "local":
                return {
                    "remote": {
                        "upsert": local_only,
                        "delete": remote_only,
                    },
                    "local": {"upsert": [], "delete": []},
                }
            elif sync_mode == "remote":
                return {
                    "remote": {"upsert": [], "delete": []},
                    "local": {"upsert": remote_only, "delete": local_only},
                }
            elif sync_mode == "merge-force-local":  # merge-to-local merge-join-local
                # PRIORITIZE LOCAL
                # get set of route names that exist in local (we keep these if
                # they are in remote)
                local_route_names = {utt.route for utt in local_only}
                # if we see route: utterance exists in local, we do not pull it in
                # from remote
                local_route_utt_strs = {utt.to_str() for utt in local_only}
                # get remote utterances that are in local
                remote_to_keep = [
                    utt
                    for utt in remote_only
                    if (
                        utt.route in local_route_names
                        and utt.to_str() not in local_route_utt_strs
                    )
                ]
                # overwrite remote routes with local metadata and function schemas
                logger.info(f"local_only_mapper: {local_only_mapper}")
                remote_to_update = [
                    Utterance(
                        route=utt.route,
                        utterance=utt.utterance,
                        metadata=local_only_mapper[utt.route][1],
                        function_schemas=local_only_mapper[utt.route][0],
                    )
                    for utt in remote_only
                    if (
                        utt.route in local_only_mapper
                        and (
                            utt.metadata != local_only_mapper[utt.route][1]
                            or utt.function_schemas != local_only_mapper[utt.route][0]
                        )
                    )
                ]
                # computed once so the filter below is a set lookup per utterance
                remote_to_update_strs = {x.to_str() for x in remote_to_update}
                remote_to_keep = [
                    Utterance(
                        route=utt.route,
                        utterance=utt.utterance,
                        metadata=local_only_mapper[utt.route][1],
                        function_schemas=local_only_mapper[utt.route][0],
                    )
                    for utt in remote_to_keep
                    if utt.to_str() not in remote_to_update_strs
                ]
                # get remote utterances that are NOT in local
                remote_to_delete = [
                    utt for utt in remote_only if utt.route not in local_route_names
                ]
                return {
                    "remote": {
                        "upsert": local_only + remote_to_update,
                        "delete": remote_to_delete,
                    },
                    "local": {"upsert": remote_to_keep, "delete": []},
                }
            elif sync_mode == "merge-force-remote":  # merge-to-remote merge-join-remote
                # get set of route names that exist in remote (we keep these if
                # they are in local)
                remote_route_names = {utt.route for utt in remote_only}
                # if we see route: utterance exists in remote, we do not pull it in
                # from local
                remote_route_utt_strs = {utt.to_str() for utt in remote_only}
                # get local utterances that are in remote
                local_to_keep = [
                    utt
                    for utt in local_only
                    if (
                        utt.route in remote_route_names
                        and utt.to_str() not in remote_route_utt_strs
                    )
                ]
                # overwrite local routes with remote metadata and function schemas
                local_to_keep = [
                    Utterance(
                        route=utt.route,
                        utterance=utt.utterance,
                        metadata=remote_only_mapper[utt.route][1],
                        function_schemas=remote_only_mapper[utt.route][0],
                    )
                    for utt in local_to_keep
                ]
                # get local utterances that are NOT in remote
                local_to_delete = [
                    utt for utt in local_only if utt.route not in remote_route_names
                ]
                return {
                    "remote": {"upsert": local_to_keep, "delete": []},
                    "local": {"upsert": remote_only, "delete": local_to_delete},
                }
            elif sync_mode == "merge":
                # overwrite remote routes with local metadata and function schemas
                remote_only_updated = [
                    (
                        Utterance(
                            route=utt.route,
                            utterance=utt.utterance,
                            metadata=local_only_mapper[utt.route][1],
                            function_schemas=local_only_mapper[utt.route][0],
                        )
                        if utt.route in local_only_mapper
                        else utt
                    )
                    for utt in remote_only
                ]
                # propagate same to shared routes
                shared_updated = [
                    Utterance(
                        route=utt.route,
                        utterance=utt.utterance,
                        metadata=local_only_mapper[utt.route][1],
                        function_schemas=local_only_mapper[utt.route][0],
                    )
                    for utt in local_and_remote
                    if (
                        utt.route in local_only_mapper
                        and (
                            utt.metadata != local_only_mapper[utt.route][1]
                            or utt.function_schemas != local_only_mapper[utt.route][0]
                        )
                    )
                ]
                return {
                    "remote": {
                        "upsert": local_only + shared_updated + remote_only_updated,
                        "delete": [],
                    },
                    "local": {"upsert": remote_only_updated + shared_updated, "delete": []},
                }
            else:
                # defensive: only reached if SYNC_MODES gains a value not handled above
                raise ValueError(f"sync_mode must be one of {SYNC_MODES}")
    
    Anush008's avatar
    Anush008 committed
    class Metric(Enum):
        """The metric to use in vector-based similarity search indexes."""

        COSINE = "cosine"
        DOTPRODUCT = "dotproduct"
        EUCLIDEAN = "euclidean"
        MANHATTAN = "manhattan"
    
    
    
    class SparseEmbedding(BaseModel):
        """Sparse embedding interface. Primarily uses numpy operations for faster
        operations.
        """

        # compact 2-column array: column 0 holds index positions, column 1 the values
        embedding: np.ndarray

        class Config:
            # np.ndarray is not a type pydantic knows how to validate
            arbitrary_types_allowed = True
    
        @classmethod
        def from_compact_array(cls, array: np.ndarray):
            """Create a SparseEmbedding object from a compact array.

            :param array: A compact array.
            :type array: np.ndarray
            :return: A SparseEmbedding object.
            :rtype: SparseEmbedding
            :raises ValueError: If ``array`` is not 2D with exactly two columns.
            """
            if array.ndim != 2 or array.shape[1] != 2:
                # shape[1] only exists for arrays with at least two dimensions;
                # reading it unconditionally would raise IndexError for 0D/1D input
                columns = array.shape[1] if array.ndim >= 2 else "no"
                raise ValueError(
                    f"Expected a 2D array with 2 columns, got a {array.ndim}D array with {columns} columns. "
                    "Column 0 should contain index positions, and column 1 should contain respective values."
                )
            return cls(embedding=array)
    
    jamescalam's avatar
    jamescalam committed
    
    
        @classmethod
        def from_vector(cls, vector: np.ndarray):
            """Consumes an array of sparse vectors containing zero-values.

            :param vector: A sparse vector.
            :type vector: np.ndarray
            :return: A SparseEmbedding object.
            :rtype: SparseEmbedding
            :raises ValueError: If ``vector`` is not one-dimensional.
            """
            if vector.ndim != 1:
                raise ValueError(f"Expected a 1D array, got a {vector.ndim}D array.")
            # pair every position (0..len-1) with its value, one (index, value) per row
            return cls.from_compact_array(np.array([np.arange(len(vector)), vector]).T)
    
    jamescalam's avatar
    jamescalam committed
    
    
        @classmethod
        def from_aurelio(cls, embedding: BM25SparseEmbedding):
            """Create a SparseEmbedding object from an AurelioSparseEmbedding object.

            :param embedding: An AurelioSparseEmbedding object.
            :type embedding: BM25SparseEmbedding
            :return: A SparseEmbedding object.
            :rtype: SparseEmbedding
            """
            # one (index, value) row per entry of the source embedding
            arr = np.array([embedding.indices, embedding.values]).T
            return cls.from_compact_array(arr)
    
    jamescalam's avatar
    jamescalam committed
    
    
        @classmethod
        def from_dict(cls, sparse_dict: dict):
            """Build a SparseEmbedding from a mapping of index positions to values.

            :param sparse_dict: A dictionary of sparse values.
            :type sparse_dict: dict
            :return: A SparseEmbedding object.
            :rtype: SparseEmbedding
            """
            positions = list(sparse_dict.keys())
            weights = list(sparse_dict.values())
            # stack as two rows, then transpose into (index, value) rows
            stacked = np.array([positions, weights])
            return cls.from_compact_array(stacked.T)
    
    jamescalam's avatar
    jamescalam committed
    
    
    jamescalam's avatar
    jamescalam committed
        @classmethod
        def from_pinecone_dict(cls, sparse_dict: dict):
            """Create a SparseEmbedding object from a Pinecone dictionary.

            :param sparse_dict: A Pinecone dictionary; its ``indices`` and
                ``values`` entries are read.
            :type sparse_dict: dict
            :return: A SparseEmbedding object.
            :rtype: SparseEmbedding
            """
            # one (index, value) row per position of the two parallel sequences
            arr = np.array([sparse_dict["indices"], sparse_dict["values"]]).T
            return cls.from_compact_array(arr)
    
    
        def to_dict(self):
            """Convert a SparseEmbedding object to a dictionary.

            Keys are the entries of column 0 cast to integers, values are the
            matching entries of column 1.

            :return: A dictionary of sparse values.
            :rtype: dict
            """
            return {
                i: v for i, v in zip(self.embedding[:, 0].astype(int), self.embedding[:, 1])
            }
    
    jamescalam's avatar
    jamescalam committed
    
    
        def to_pinecone(self):
            """Express the embedding as a Pinecone-style sparse dictionary.

            :return: A dict with ``indices`` (column 0 cast to integers) and
                ``values`` (column 1), both as plain lists.
            :rtype: dict
            """
            index_column = self.embedding[:, 0]
            value_column = self.embedding[:, 1]
            payload = {}
            payload["indices"] = index_column.astype(int).tolist()
            payload["values"] = value_column.tolist()
            return payload
    
    jamescalam's avatar
    jamescalam committed
    
    
        # dictionary interface
        def items(self):
    
            """Return a list of (index, value) tuples from the SparseEmbedding object.
    
            :return: A list of (index, value) tuples.
            :rtype: list
            """
    
    jamescalam's avatar
    jamescalam committed
                (i, v)
                for i, v in zip(self.embedding[:, 0].astype(int), self.embedding[:, 1])