Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
semantic.py 3.16 KiB
from typing import Any, List, Optional

import numpy as np

from semantic_router.encoders import DenseEncoder
from semantic_router.index.base import BaseIndex
from semantic_router.llms import BaseLLM
from semantic_router.route import Route
from semantic_router.routers.base import BaseRouter
from semantic_router.utils.logger import logger


class SemanticRouter(BaseRouter):
    """A router that uses a dense encoder to encode routes and utterances.

    Utterances are embedded with ``self.encoder`` and stored in ``self.index``.
    Behaviour not defined in this class is inherited from :class:`BaseRouter`.
    """

    def __init__(
        self,
        encoder: Optional[DenseEncoder] = None,
        llm: Optional[BaseLLM] = None,
        routes: Optional[List[Route]] = None,
        index: Optional[BaseIndex] = None,  # type: ignore
        top_k: int = 5,
        aggregation: str = "mean",
        auto_sync: Optional[str] = None,
    ):
        """Initialize the router.

        :param encoder: Dense encoder used to embed utterances. It is resolved
            through ``self._get_encoder`` (presumably substituting a default
            encoder when ``None`` — confirm in ``BaseRouter``).
        :type encoder: Optional[DenseEncoder]
        :param llm: Optional LLM, passed through to ``BaseRouter``.
        :type llm: Optional[BaseLLM]
        :param routes: Initial routes. ``None`` (or any falsy value) is
            replaced by a fresh empty list.
        :type routes: Optional[List[Route]]
        :param index: Index that stores the embeddings. It is resolved through
            ``self._get_index``.
        :type index: Optional[BaseIndex]
        :param top_k: Passed through to ``BaseRouter``.
        :type top_k: int
        :param aggregation: Passed through to ``BaseRouter``.
        :type aggregation: str
        :param auto_sync: Passed through to ``BaseRouter``.
        :type auto_sync: Optional[str]
        """
        index = self._get_index(index=index)
        encoder = self._get_encoder(encoder=encoder)
        super().__init__(
            encoder=encoder,
            llm=llm,
            # fresh list per instance rather than a shared mutable default
            routes=routes if routes else [],
            index=index,
            top_k=top_k,
            aggregation=aggregation,
            auto_sync=auto_sync,
        )

    def _encode(self, text: list[str]) -> Any:
        """Encode a batch of texts with the dense encoder.

        :param text: The texts to encode.
        :type text: list[str]
        :return: The encoder output wrapped in a numpy array (presumably one
            row per input text — depends on the encoder).
        :rtype: Any
        """
        # create query vector
        xq = np.array(self.encoder(text))
        return xq

    async def _async_encode(self, text: list[str]) -> Any:
        """Asynchronously encode a batch of texts with the dense encoder.

        :param text: The texts to encode.
        :type text: list[str]
        :return: The result of ``self.encoder.acall`` wrapped in a numpy array.
        :rtype: Any
        """
        # create query vector
        xq = np.array(await self.encoder.acall(docs=text))
        return xq

    def add(self, routes: List[Route] | Route):
        """Add one or more routes to the local SemanticRouter and index.

        The utterances of the new routes are embedded and written to the
        index, and the routes are appended to ``self.routes``. The stored hash
        is rewritten only if the local and remote route layers were aligned
        before the addition, or if the remote hash was empty (index being
        initialized). Otherwise a warning is logged and the hash is left as-is.

        :param routes: A single route or a list of routes to add. An empty
            list is a no-op.
        :type routes: List[Route] | Route
        """
        if isinstance(routes, Route):
            routes = [routes]
        if not routes:
            # Nothing to add: avoid sending an empty batch to the encoder and
            # leave the index and its stored hash untouched.
            return
        current_local_hash = self._get_hash()
        current_remote_hash = self.index._read_hash()
        if current_remote_hash.value == "":
            # if remote hash is empty, the index is to be initialized
            current_remote_hash = current_local_hash
        # create embeddings for all routes
        route_names, all_utterances, all_function_schemas, all_metadata = (
            self._extract_routes_details(routes, include_metadata=True)
        )
        if all_utterances:
            # Routes without utterances have nothing to embed; skip the
            # encoder/index round-trip rather than sending empty batches.
            dense_emb = self._encode(all_utterances)
            self.index.add(
                embeddings=dense_emb.tolist(),
                routes=route_names,
                utterances=all_utterances,
                function_schemas=all_function_schemas,
                metadata_list=all_metadata,
            )

        self.routes.extend(routes)
        if current_local_hash.value == current_remote_hash.value:
            self._write_hash()  # update current hash in index
        else:
            logger.warning(
                "Local and remote route layers were not aligned. Remote hash "
                f"not updated. Use `{self.__class__.__name__}.get_utterance_diff()` "
                "to see details."
            )