-
James Briggs authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
semantic.py 3.16 KiB
from typing import Any, List, Optional
import numpy as np
from semantic_router.encoders import DenseEncoder
from semantic_router.index.base import BaseIndex
from semantic_router.llms import BaseLLM
from semantic_router.route import Route
from semantic_router.routers.base import BaseRouter
from semantic_router.utils.logger import logger
class SemanticRouter(BaseRouter):
    """A router that uses a dense encoder to encode routes and utterances."""

    def __init__(
        self,
        encoder: Optional[DenseEncoder] = None,
        llm: Optional[BaseLLM] = None,
        routes: Optional[List[Route]] = None,
        index: Optional[BaseIndex] = None,  # type: ignore
        top_k: int = 5,
        aggregation: str = "mean",
        auto_sync: Optional[str] = None,
    ):
        """Initialize the router and hand all settings to ``BaseRouter``.

        :param encoder: The dense encoder to use. It is passed through
            ``self._get_encoder`` before being handed to the base class.
        :type encoder: Optional[DenseEncoder]
        :param llm: Forwarded unchanged to ``BaseRouter``.
        :type llm: Optional[BaseLLM]
        :param routes: The initial routes. ``None`` (or any falsy value) is
            replaced by a fresh empty list.
        :type routes: Optional[List[Route]]
        :param index: The index to use. It is passed through
            ``self._get_index`` before being handed to the base class.
        :type index: Optional[BaseIndex]
        :param top_k: Forwarded unchanged to ``BaseRouter``.
        :type top_k: int
        :param aggregation: Forwarded unchanged to ``BaseRouter``.
        :type aggregation: str
        :param auto_sync: Forwarded unchanged to ``BaseRouter``.
        :type auto_sync: Optional[str]
        """
        # NOTE(review): `_get_index` / `_get_encoder` are not defined in this
        # class; presumably inherited from BaseRouter and responsible for
        # supplying defaults when None is given -- confirm against the base.
        index = self._get_index(index=index)
        encoder = self._get_encoder(encoder=encoder)
        super().__init__(
            encoder=encoder,
            llm=llm,
            # a new list per instance; never share a mutable default
            routes=routes if routes else [],
            index=index,
            top_k=top_k,
            aggregation=aggregation,
            auto_sync=auto_sync,
        )

    def _encode(self, text: list[str]) -> Any:
        """Given some text, encode it.

        :param text: The text to encode.
        :type text: list[str]
        :return: The encoder output wrapped in a NumPy array.
        :rtype: Any
        """
        # create query vector
        return np.array(self.encoder(text))

    async def _async_encode(self, text: list[str]) -> Any:
        """Given some text, encode it asynchronously.

        :param text: The text to encode.
        :type text: list[str]
        :return: The encoder output wrapped in a NumPy array.
        :rtype: Any
        """
        # create query vector
        return np.array(await self.encoder.acall(docs=text))

    def add(self, routes: List[Route] | Route) -> None:
        """Add one or more routes to the local SemanticRouter and index.

        The utterances of the given routes are encoded and written to the
        index, and the routes are appended to ``self.routes``. The hash
        stored in the index is only rewritten when the local and remote
        hashes matched *before* this call (an empty remote hash counts as a
        match); otherwise a warning is logged and the remote hash is left
        as it was.

        :param routes: A single route, or a list of routes, to add.
        :type routes: List[Route] | Route
        """
        # Snapshot both hashes before anything is mutated so that the
        # comparison below reflects the state prior to this call.
        current_local_hash = self._get_hash()
        current_remote_hash = self.index._read_hash()
        if current_remote_hash.value == "":
            # if remote hash is empty, the index is to be initialized
            current_remote_hash = current_local_hash

        # Normalize to a list without rebinding the caller-visible parameter.
        new_routes = [routes] if isinstance(routes, Route) else routes

        # create embeddings for all routes
        route_names, all_utterances, all_function_schemas, all_metadata = (
            self._extract_routes_details(new_routes, include_metadata=True)
        )
        # Nothing to embed (empty list, or routes without utterances): do not
        # send an empty batch to the encoder or the index.
        if all_utterances:
            dense_emb = self._encode(all_utterances)
            self.index.add(
                embeddings=dense_emb.tolist(),
                routes=route_names,
                utterances=all_utterances,
                function_schemas=all_function_schemas,
                metadata_list=all_metadata,
            )

        self.routes.extend(new_routes)
        if current_local_hash.value == current_remote_hash.value:
            self._write_hash()  # update current hash in index
        else:
            logger.warning(
                "Local and remote route layers were not aligned. Remote hash "
                f"not updated. Use `{self.__class__.__name__}.get_utterance_diff()` "
                "to see details."
            )