diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index cad75daf5f0a010f384e7ee38575f25b7d1ceea1..d86e5350e124b93b7415c8faf707859ff722857e 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -47,6 +47,17 @@ class BaseIndex(BaseModel): ): """Add embeddings to the index. This method should be implemented by subclasses. + + :param embeddings: List of embeddings to add to the index. + :type embeddings: List[List[float]] + :param routes: List of routes to add to the index. + :type routes: List[str] + :param utterances: List of utterances to add to the index. + :type utterances: List[str] + :param function_schemas: List of function schemas to add to the index. + :type function_schemas: Optional[List[Dict[str, Any]]] + :param metadata_list: List of metadata to add to the index. + :type metadata_list: List[Dict[str, Any]] """ raise NotImplementedError("This method should be implemented by subclasses.") @@ -61,6 +72,17 @@ class BaseIndex(BaseModel): ): """Add vectors to the index asynchronously. This method should be implemented by subclasses. + + :param embeddings: List of embeddings to add to the index. + :type embeddings: List[List[float]] + :param routes: List of routes to add to the index. + :type routes: List[str] + :param utterances: List of utterances to add to the index. + :type utterances: List[str] + :param function_schemas: List of function schemas to add to the index. + :type function_schemas: Optional[List[Dict[str, Any]]] + :param metadata_list: List of metadata to add to the index. + :type metadata_list: List[Dict[str, Any]] """ logger.warning("Async method not implemented.") return self.add( @@ -143,6 +165,9 @@ class BaseIndex(BaseModel): """ Remove embeddings in a routes syncing process from the index. This method should be implemented by subclasses. + + :param routes_to_delete: Dictionary of routes to delete. + :type routes_to_delete: dict """ raise NotImplementedError("This method should be implemented by subclasses.") @@ -150,29 +175,38 @@ class BaseIndex(BaseModel): """ Remove embeddings in a routes syncing process from the index asynchronously. This method should be implemented by subclasses. + + :param routes_to_delete: Dictionary of routes to delete. + :type routes_to_delete: dict """ logger.warning("Async method not implemented.") return self._remove_and_sync(routes_to_delete=routes_to_delete) def delete(self, route_name: str): - """ - Deletes route by route name. + """Deletes route by route name. This method should be implemented by subclasses. + + :param route_name: Name of the route to delete. + :type route_name: str """ raise NotImplementedError("This method should be implemented by subclasses.") def describe(self) -> IndexConfig: - """ - Returns an IndexConfig object with index details such as type, dimensions, and - total vector count. + """Returns an IndexConfig object with index details such as type, dimensions, + and total vector count. This method should be implemented by subclasses. + + :return: An IndexConfig object. + :rtype: IndexConfig """ raise NotImplementedError("This method should be implemented by subclasses.") def is_ready(self) -> bool: - """ - Checks if the index is ready to be used. + """Checks if the index is ready to be used. This method should be implemented by subclasses. + + :return: True if the index is ready, False otherwise. 
+ :rtype: bool """ raise NotImplementedError("This method should be implemented by subclasses.") @@ -183,9 +217,19 @@ class BaseIndex(BaseModel): route_filter: Optional[List[str]] = None, sparse_vector: dict[int, float] | SparseEmbedding | None = None, ) -> Tuple[np.ndarray, List[str]]: - """ - Search the index for the query_vector and return top_k results. + """Search the index for the query_vector and return top_k results. This method should be implemented by subclasses. + + :param vector: The vector to search for. + :type vector: np.ndarray + :param top_k: The number of results to return. + :type top_k: int + :param route_filter: The routes to filter the search by. + :type route_filter: Optional[List[str]] + :param sparse_vector: The sparse vector to search for. + :type sparse_vector: dict[int, float] | SparseEmbedding | None + :return: A tuple containing the query vector and a list of route names. + :rtype: Tuple[np.ndarray, List[str]] """ raise NotImplementedError("This method should be implemented by subclasses.") @@ -196,9 +240,19 @@ class BaseIndex(BaseModel): route_filter: Optional[List[str]] = None, sparse_vector: dict[int, float] | SparseEmbedding | None = None, ) -> Tuple[np.ndarray, List[str]]: - """ - Search the index for the query_vector and return top_k results. + """Search the index for the query_vector and return top_k results. This method should be implemented by subclasses. + + :param vector: The vector to search for. + :type vector: np.ndarray + :param top_k: The number of results to return. + :type top_k: int + :param route_filter: The routes to filter the search by. + :type route_filter: Optional[List[str]] + :param sparse_vector: The sparse vector to search for. + :type sparse_vector: dict[int, float] | SparseEmbedding | None + :return: A tuple containing the query vector and a list of route names. + :rtype: Tuple[np.ndarray, List[str]] """ raise NotImplementedError("This method should be implemented by subclasses.") @@ -214,8 +268,10 @@ class BaseIndex(BaseModel): raise NotImplementedError("This method should be implemented by subclasses.") def delete_all(self): - """ - Deletes all records from the index. + """Deletes all records from the index. + This method should be implemented by subclasses. + + :raises NotImplementedError: If the method is not implemented by the subclass. """ logger.warning("This method should be implemented by subclasses.") self.index = None @@ -223,9 +279,10 @@ class BaseIndex(BaseModel): self.utterances = None def delete_index(self): - """ - Deletes or resets the index. + """Deletes or resets the index. This method should be implemented by subclasses. + + :raises NotImplementedError: If the method is not implemented by the subclass. """ logger.warning("This method should be implemented by subclasses.") self.index = None @@ -404,9 +461,7 @@ class BaseIndex(BaseModel): return lock_param def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False): - """ - Retrieves all vector IDs from the index. - + """Retrieves all vector IDs from the index. This method should be implemented by subclasses. :param prefix: The prefix to filter the vectors by. @@ -422,7 +477,6 @@ class BaseIndex(BaseModel): self, prefix: Optional[str] = None, include_metadata: bool = False ) -> tuple[list[str], list[dict]]: """Retrieves all vector IDs from the index asynchronously. - This method should be implemented by subclasses. :param prefix: The prefix to filter the vectors by. 
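The docstrings added to BaseIndex above define the contract that the concrete indexes patched below implement: add() takes parallel lists of embeddings, routes, and utterances; query() returns a (scores, route_names) pair; describe() and is_ready() report index state; delete() removes a route's records. A minimal sketch of that contract using the in-memory LocalIndex (whose patch appears later in this diff) might look like the following — the import path is taken from the diff itself, while the toy embeddings, route names, and utterances are illustrative assumptions only, not library data.

import numpy as np

# Import path as used elsewhere in this patch (semantic_router/index/local.py).
from semantic_router.index.local import LocalIndex

# Toy three-dimensional embeddings for two routes; purely illustrative values.
index = LocalIndex()
index.add(
    embeddings=[[0.1, 0.2, 0.3], [0.9, 0.8, 0.7]],
    routes=["greeting", "farewell"],
    utterances=["hello there", "see you later"],
)

assert index.is_ready()      # True once index and routes are populated
print(index.describe())      # IndexConfig reporting type, dimensions, vector count

# query() returns (scores, route_names) for the top_k closest utterances.
scores, route_names = index.query(vector=np.array([0.1, 0.2, 0.25]), top_k=1)
print(route_names)

# delete() drops every record belonging to the named route.
index.delete(route_name="farewell")

The hybrid and remote indexes documented below follow the same signatures, with additional parameters such as sparse_embeddings (HybridLocalIndex) or batch_size and namespace handling (PineconeIndex).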
diff --git a/semantic_router/index/hybrid_local.py b/semantic_router/index/hybrid_local.py index c473a56c35a3436279090879cec4cd508f230da1..2a1441563853632fb135516162b7009997386aab 100644 --- a/semantic_router/index/hybrid_local.py +++ b/semantic_router/index/hybrid_local.py @@ -26,6 +26,21 @@ class HybridLocalIndex(LocalIndex): sparse_embeddings: Optional[List[SparseEmbedding]] = None, **kwargs, ): + """Add embeddings to the index. + + :param embeddings: List of embeddings to add to the index. + :type embeddings: List[List[float]] + :param routes: List of routes to add to the index. + :type routes: List[str] + :param utterances: List of utterances to add to the index. + :type utterances: List[str] + :param function_schemas: List of function schemas to add to the index. + :type function_schemas: Optional[List[Dict[str, Any]]] + :param metadata_list: List of metadata to add to the index. + :type metadata_list: List[Dict[str, Any]] + :param sparse_embeddings: List of sparse embeddings to add to the index. + :type sparse_embeddings: Optional[List[SparseEmbedding]] + """ if sparse_embeddings is None: raise ValueError("Sparse embeddings are required for HybridLocalIndex.") if function_schemas is not None: @@ -73,12 +88,28 @@ class HybridLocalIndex(LocalIndex): def _sparse_dot_product( self, vec_a: dict[int, float], vec_b: dict[int, float] ) -> float: + """Calculate the dot product of two sparse vectors. + + :param vec_a: The first sparse vector. + :type vec_a: dict[int, float] + :param vec_b: The second sparse vector. + :type vec_b: dict[int, float] + :return: The dot product of the two sparse vectors. + :rtype: float + """ # switch vecs to ensure first is smallest for more efficiency if len(vec_a) > len(vec_b): vec_a, vec_b = vec_b, vec_a return sum(vec_a[i] * vec_b.get(i, 0) for i in vec_a) def _sparse_index_dot_product(self, vec_a: dict[int, float]) -> list[float]: + """Calculate the dot product of a sparse vector and a list of sparse vectors. + + :param vec_a: The sparse vector. + :type vec_a: dict[int, float] + :return: A list of dot products. + :rtype: list[float] + """ if self.sparse_index is None: raise ValueError("self.sparse_index is not populated.") dot_products = [ @@ -163,14 +194,26 @@ class HybridLocalIndex(LocalIndex): ) def aget_routes(self): + """Get all routes from the index. + + :return: A list of routes. + :rtype: List[str] + """ logger.error(f"Sync remove is not implemented for {self.__class__.__name__}.") def _write_config(self, config: ConfigParameter): + """Write the config to the index. + + :param config: The config to write to the index. + :type config: ConfigParameter + """ logger.warning(f"No config is written for {self.__class__.__name__}.") def delete(self, route_name: str): - """ - Delete all records of a specific route from the index. + """Delete all records of a specific route from the index. + + :param route_name: The name of the route to delete. + :type route_name: str """ if ( self.index is not None @@ -188,15 +231,23 @@ class HybridLocalIndex(LocalIndex): ) def delete_index(self): - """ - Deletes the index, effectively clearing it and setting it to None. + """Deletes the index, effectively clearing it and setting it to None. + + :return: None + :rtype: None """ self.index = None self.routes = None self.utterances = None def _get_indices_for_route(self, route_name: str): - """Gets an array of indices for a specific route.""" + """Gets an array of indices for a specific route. + + :param route_name: The name of the route to get indices for. 
+ :type route_name: str + :return: An array of indices for the route. + :rtype: np.ndarray + """ if self.routes is None: raise ValueError("Routes are not populated.") idx = [i for i, route in enumerate(self.routes) if route == route_name] diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py index d66f6159e57f918c772f31fd9c8ae5f86311063b..1c70daae29897142495bfe93c503d4ed39213bb1 100644 --- a/semantic_router/index/local.py +++ b/semantic_router/index/local.py @@ -27,6 +27,19 @@ class LocalIndex(BaseIndex): metadata_list: List[Dict[str, Any]] = [], **kwargs, ): + """Add embeddings to the index. + + :param embeddings: List of embeddings to add to the index. + :type embeddings: List[List[float]] + :param routes: List of routes to add to the index. + :type routes: List[str] + :param utterances: List of utterances to add to the index. + :type utterances: List[str] + :param function_schemas: List of function schemas to add to the index. + :type function_schemas: Optional[List[Dict[str, Any]]] + :param metadata_list: List of metadata to add to the index. + :type metadata_list: List[Dict[str, Any]] + """ embeds = np.array(embeddings) # type: ignore routes_arr = np.array(routes) if isinstance(utterances[0], str): @@ -43,6 +56,13 @@ class LocalIndex(BaseIndex): self.utterances = np.concatenate([self.utterances, utterances_arr]) def _remove_and_sync(self, routes_to_delete: dict) -> np.ndarray: + """Remove and sync the index. + + :param routes_to_delete: Dictionary of routes to delete. + :type routes_to_delete: dict + :return: A numpy array of the removed route utterances. + :rtype: np.ndarray + """ if self.index is None or self.routes is None or self.utterances is None: raise ValueError("Index, routes, or utterances are not populated.") # TODO JB: implement routes and utterances as a numpy array @@ -77,6 +97,11 @@ class LocalIndex(BaseIndex): return [Utterance.from_tuple(x) for x in zip(self.routes, self.utterances)] def describe(self) -> IndexConfig: + """Describe the index. + + :return: An IndexConfig object. + :rtype: IndexConfig + """ return IndexConfig( type=self.type, dimensions=self.index.shape[1] if self.index is not None else 0, @@ -84,8 +109,10 @@ class LocalIndex(BaseIndex): ) def is_ready(self) -> bool: - """ - Checks if the index is ready to be used. + """Checks if the index is ready to be used. + + :return: True if the index is ready, False otherwise. + :rtype: bool """ return self.index is not None and self.routes is not None @@ -96,8 +123,18 @@ class LocalIndex(BaseIndex): route_filter: Optional[List[str]] = None, sparse_vector: dict[int, float] | SparseEmbedding | None = None, ) -> Tuple[np.ndarray, List[str]]: - """ - Search the index for the query and return top_k results. + """Search the index for the query and return top_k results. + + :param vector: The vector to search for. + :type vector: np.ndarray + :param top_k: The number of results to return. + :type top_k: int + :param route_filter: The routes to filter the search by. + :type route_filter: Optional[List[str]] + :param sparse_vector: The sparse vector to search for. + :type sparse_vector: dict[int, float] | SparseEmbedding | None + :return: A tuple containing the query vector and a list of route names. 
+ :rtype: Tuple[np.ndarray, List[str]] """ if self.index is None or self.routes is None: raise ValueError("Index or routes are not populated.") @@ -126,8 +163,18 @@ class LocalIndex(BaseIndex): route_filter: Optional[List[str]] = None, sparse_vector: dict[int, float] | SparseEmbedding | None = None, ) -> Tuple[np.ndarray, List[str]]: - """ - Search the index for the query and return top_k results. + """Search the index for the query and return top_k results. + + :param vector: The vector to search for. + :type vector: np.ndarray + :param top_k: The number of results to return. + :type top_k: int + :param route_filter: The routes to filter the search by. + :type route_filter: Optional[List[str]] + :param sparse_vector: The sparse vector to search for. + :type sparse_vector: dict[int, float] | SparseEmbedding | None + :return: A tuple containing the query vector and a list of route names. + :rtype: Tuple[np.ndarray, List[str]] """ if self.index is None or self.routes is None: raise ValueError("Index or routes are not populated.") @@ -150,14 +197,26 @@ class LocalIndex(BaseIndex): return scores, route_names def aget_routes(self): + """Get all routes from the index. + + :return: A list of routes. + :rtype: List[str] + """ logger.error("Sync remove is not implemented for LocalIndex.") def _write_config(self, config: ConfigParameter): + """Write the config to the index. + + :param config: The config to write to the index. + :type config: ConfigParameter + """ logger.warning("No config is written for LocalIndex.") def delete(self, route_name: str): - """ - Delete all records of a specific route from the index. + """Delete all records of a specific route from the index. + + :param route_name: The name of the route to delete. + :type route_name: str """ if ( self.index is not None @@ -175,15 +234,23 @@ class LocalIndex(BaseIndex): ) def delete_index(self): - """ - Deletes the index, effectively clearing it and setting it to None. + """Deletes the index, effectively clearing it and setting it to None. + + :return: None + :rtype: None """ self.index = None self.routes = None self.utterances = None def _get_indices_for_route(self, route_name: str): - """Gets an array of indices for a specific route.""" + """Gets an array of indices for a specific route. + + :param route_name: The name of the route to get indices for. + :type route_name: str + :return: An array of indices for the route. + :rtype: np.ndarray + """ if self.routes is None: raise ValueError("Routes are not populated.") idx = [i for i, route in enumerate(self.routes) if route == route_name] diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 62f5ac2212177c79407d55c95cbaabaa0d4544a0..e26bec5cdd52c0da4d33f6516afbf1eedbe0783d 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -35,6 +35,23 @@ def build_records( metadata_list: List[Dict[str, Any]] = [], sparse_embeddings: Optional[Optional[List[SparseEmbedding]]] = None, ) -> List[Dict]: + """Build records for Pinecone upsert. + + :param embeddings: List of embeddings to upsert. + :type embeddings: List[List[float]] + :param routes: List of routes to upsert. + :type routes: List[str] + :param utterances: List of utterances to upsert. + :type utterances: List[str] + :param function_schemas: List of function schemas to upsert. + :type function_schemas: Optional[List[Dict[str, Any]]] + :param metadata_list: List of metadata to upsert. 
+ :type metadata_list: List[Dict[str, Any]] + :param sparse_embeddings: List of sparse embeddings to upsert. + :type sparse_embeddings: Optional[List[SparseEmbedding]] + :return: List of records to upsert. + :rtype: List[Dict] + """ if function_schemas is None: function_schemas = [{}] * len(embeddings) if sparse_embeddings is None: @@ -86,6 +103,11 @@ class PineconeRecord(BaseModel): metadata: Dict[str, Any] = {} # Additional metadata dictionary def __init__(self, **data): + """Initialize PineconeRecord. + + :param **data: Keyword arguments to pass to the BaseModel constructor. + :type **data: dict + """ super().__init__(**data) clean_route = clean_route_name(self.route) # Use SHA-256 for a more secure hash @@ -100,6 +122,11 @@ class PineconeRecord(BaseModel): ) def to_dict(self): + """Convert PineconeRecord to a dictionary. + + :return: Dictionary representation of the PineconeRecord. + :rtype: dict + """ d = { "id": self.id, "values": self.values, @@ -140,6 +167,29 @@ class PineconeIndex(BaseIndex): base_url: Optional[str] = "https://api.pinecone.io", init_async_index: bool = False, ): + """Initialize PineconeIndex. + + :param api_key: Pinecone API key. + :type api_key: Optional[str] + :param index_name: Name of the index. + :type index_name: str + :param dimensions: Dimensions of the index. + :type dimensions: Optional[int] + :param metric: Metric of the index. + :type metric: str + :param cloud: Cloud provider of the index. + :type cloud: str + :param region: Region of the index. + :type region: str + :param host: Host of the index. + :type host: str + :param namespace: Namespace of the index. + :type namespace: Optional[str] + :param base_url: Base URL of the Pinecone API. + :type base_url: Optional[str] + :param init_async_index: Whether to initialize the index asynchronously. + :type init_async_index: bool + """ super().__init__() self.api_key = api_key or os.getenv("PINECONE_API_KEY") if not self.api_key: @@ -182,6 +232,13 @@ class PineconeIndex(BaseIndex): self.index = self._init_index() def _initialize_client(self, api_key: Optional[str] = None): + """Initialize the Pinecone client. + + :param api_key: Pinecone API key. + :type api_key: Optional[str] + :return: Pinecone client. + :rtype: Pinecone + """ try: from pinecone import Pinecone, ServerlessSpec @@ -203,6 +260,12 @@ class PineconeIndex(BaseIndex): return Pinecone(**pinecone_args) def _calculate_index_host(self): + """Calculate the index host. Used to differentiate between normal + Pinecone and Pinecone Local instance. + + :return: None + :rtype: None + """ if self.index_host and self.base_url: if "api.pinecone.io" in self.base_url: if not self.index_host.startswith("http"): @@ -285,6 +348,20 @@ class PineconeIndex(BaseIndex): return index async def _init_async_index(self, force_create: bool = False): + """Initializing the index can be done after the object has been created + to allow for the user to set the dimensions and other parameters. + + If the index doesn't exist and the dimensions are given, the index will + be created. If the index exists, it will be returned. If the index doesn't + exist and the dimensions are not given, the index will not be created and + None will be returned. + + This method is used to initialize the index asynchronously. + + :param force_create: If True, the index will be created even if the + dimensions are not given (which will raise an error). 
+ :type force_create: bool, optional + """ index_stats = None indexes = await self._async_list_indexes() index_names = [i["name"] for i in indexes["indexes"]] @@ -344,7 +421,23 @@ class PineconeIndex(BaseIndex): sparse_embeddings: Optional[Optional[List[SparseEmbedding]]] = None, **kwargs, ): - """Add vectors to Pinecone in batches.""" + """Add vectors to Pinecone in batches. + + :param embeddings: List of embeddings to upsert. + :type embeddings: List[List[float]] + :param routes: List of routes to upsert. + :type routes: List[str] + :param utterances: List of utterances to upsert. + :type utterances: List[str] + :param function_schemas: List of function schemas to upsert. + :type function_schemas: Optional[List[Dict[str, Any]]] + :param metadata_list: List of metadata to upsert. + :type metadata_list: List[Dict[str, Any]] + :param batch_size: Number of vectors to upsert in a single batch. + :type batch_size: int, optional + :param sparse_embeddings: List of sparse embeddings to upsert. + :type sparse_embeddings: Optional[List[SparseEmbedding]] + """ if self.index is None: self.dimensions = self.dimensions or len(embeddings[0]) self.index = self._init_index(force_create=True) @@ -371,7 +464,23 @@ class PineconeIndex(BaseIndex): sparse_embeddings: Optional[Optional[List[SparseEmbedding]]] = None, **kwargs, ): - """Add vectors to Pinecone in batches.""" + """Add vectors to Pinecone in batches. + + :param embeddings: List of embeddings to upsert. + :type embeddings: List[List[float]] + :param routes: List of routes to upsert. + :type routes: List[str] + :param utterances: List of utterances to upsert. + :type utterances: List[str] + :param function_schemas: List of function schemas to upsert. + :type function_schemas: Optional[List[Dict[str, Any]]] + :param metadata_list: List of metadata to upsert. + :type metadata_list: List[Dict[str, Any]] + :param batch_size: Number of vectors to upsert in a single batch. + :type batch_size: int, optional + :param sparse_embeddings: List of sparse embeddings to upsert. + :type sparse_embeddings: Optional[List[SparseEmbedding]] + """ if self.index is None: self.dimensions = self.dimensions or len(embeddings[0]) self.index = await self._init_async_index(force_create=True) @@ -392,6 +501,11 @@ class PineconeIndex(BaseIndex): ) def _remove_and_sync(self, routes_to_delete: dict): + """Remove specified routes from index if they exist. + + :param routes_to_delete: Routes to delete. + :type routes_to_delete: dict + """ for route, utterances in routes_to_delete.items(): remote_routes = self._get_routes_with_ids(route_name=route) ids_to_delete = [ @@ -404,6 +518,13 @@ class PineconeIndex(BaseIndex): self.index.delete(ids=ids_to_delete, namespace=self.namespace) async def _async_remove_and_sync(self, routes_to_delete: dict): + """Remove specified routes from index if they exist. + + This method is asyncronous. + + :param routes_to_delete: Routes to delete. + :type routes_to_delete: dict + """ for route, utterances in routes_to_delete.items(): remote_routes = await self._async_get_routes_with_ids(route_name=route) ids_to_delete = [ @@ -418,16 +539,37 @@ class PineconeIndex(BaseIndex): ) def _get_route_ids(self, route_name: str): + """Get the IDs of the routes in the index. + + :param route_name: Name of the route to get the IDs for. + :type route_name: str + :return: List of IDs of the routes. 
+ :rtype: list[str] + """ clean_route = clean_route_name(route_name) ids, _ = self._get_all(prefix=f"{clean_route}#") return ids async def _async_get_route_ids(self, route_name: str): + """Get the IDs of the routes in the index. + + :param route_name: Name of the route to get the IDs for. + :type route_name: str + :return: List of IDs of the routes. + :rtype: list[str] + """ clean_route = clean_route_name(route_name) ids, _ = await self._async_get_all(prefix=f"{clean_route}#") return ids def _get_routes_with_ids(self, route_name: str): + """Get the routes with their IDs from the index. + + :param route_name: Name of the route to get the routes with their IDs for. + :type route_name: str + :return: List of routes with their IDs. + :rtype: list[dict] + """ clean_route = clean_route_name(route_name) ids, metadata = self._get_all(prefix=f"{clean_route}#", include_metadata=True) route_tuples = [] @@ -442,6 +584,13 @@ class PineconeIndex(BaseIndex): return route_tuples async def _async_get_routes_with_ids(self, route_name: str): + """Get the routes with their IDs from the index. + + :param route_name: Name of the route to get the routes with their IDs for. + :type route_name: str + :return: List of routes with their IDs. + :rtype: list[dict] + """ clean_route = clean_route_name(route_name) ids, metadata = await self._async_get_all( prefix=f"{clean_route}#", include_metadata=True @@ -454,8 +603,7 @@ class PineconeIndex(BaseIndex): return route_tuples def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False): - """ - Retrieves all vector IDs from the Pinecone index using pagination. + """Retrieves all vector IDs from the Pinecone index using pagination. :param prefix: The prefix to filter the vectors by. :type prefix: Optional[str] @@ -486,6 +634,11 @@ class PineconeIndex(BaseIndex): return all_vector_ids, metadata def delete(self, route_name: str): + """Delete specified route from index if it exists. + + :param route_name: Name of the route to delete. + :type route_name: str + """ route_vec_ids = self._get_route_ids(route_name=route_name) if self.index is not None: logger.info("index is not None, deleting...") @@ -515,12 +668,22 @@ class PineconeIndex(BaseIndex): raise ValueError("Index is None, could not delete.") def delete_all(self): + """Delete all routes from index if it exists. + + :return: None + :rtype: None + """ if self.index is not None: self.index.delete(delete_all=True, namespace=self.namespace) else: raise ValueError("Index is None, could not delete.") def describe(self) -> IndexConfig: + """Describe the index. + + :return: IndexConfig + :rtype: IndexConfig + """ if self.index is not None: stats = self.index.describe_index_stats() return IndexConfig( @@ -536,8 +699,10 @@ class PineconeIndex(BaseIndex): ) def is_ready(self) -> bool: - """ - Checks if the index is ready to be used. + """Checks if the index is ready to be used. + + :return: True if the index is ready, False otherwise. + :rtype: bool """ return self.index is not None @@ -602,6 +767,15 @@ class PineconeIndex(BaseIndex): return np.array(scores), route_names def _read_config(self, field: str, scope: str | None = None) -> ConfigParameter: + """Read a config parameter from the index. + + :param field: The field to read. + :type field: str + :param scope: The scope to read. + :type scope: str | None + :return: The config parameter that was read. 
+ :rtype: ConfigParameter + """ scope = scope or self.namespace if self.index is None: return ConfigParameter( @@ -716,8 +890,7 @@ class PineconeIndex(BaseIndex): route_filter: Optional[List[str]] = None, sparse_vector: dict[int, float] | SparseEmbedding | None = None, ) -> Tuple[np.ndarray, List[str]]: - """ - Asynchronously search the index for the query vector and return the top_k results. + """Asynchronously search the index for the query vector and return the top_k results. :param vector: The query vector to search for. :type vector: np.ndarray @@ -773,6 +946,11 @@ class PineconeIndex(BaseIndex): return await self._async_get_routes() def delete_index(self): + """Delete the index. + + :return: None + :rtype: None + """ self.client.delete_index(self.index_name) self.index = None @@ -786,6 +964,21 @@ class PineconeIndex(BaseIndex): top_k: int = 5, include_metadata: bool = False, ): + """Asynchronously query the index for the query vector and return the top_k results. + + :param vector: The query vector to search for. + :type vector: list[float] + :param sparse_vector: The sparse vector to search for. + :type sparse_vector: dict[str, Any] | None + :param namespace: The namespace to search for. + :type namespace: str + :param filter: The filter to search for. + :type filter: Optional[dict] + :param top_k: The number of top results to return, defaults to 5. + :type top_k: int, optional + :param include_metadata: Whether to include metadata in the results, defaults to False. + :type include_metadata: bool, optional + """ params = { "vector": vector, "sparse_vector": sparse_vector, @@ -824,6 +1017,11 @@ class PineconeIndex(BaseIndex): return {} async def _async_list_indexes(self): + """Asynchronously lists all indexes within the current Pinecone project. + + :return: List of indexes. + :rtype: list[dict] + """ async with aiohttp.ClientSession() as session: async with session.get( f"{self.base_url}/indexes", @@ -836,6 +1034,13 @@ class PineconeIndex(BaseIndex): vectors: list[dict], namespace: str = "", ): + """Asynchronously upserts vectors into the index. + + :param vectors: The vectors to upsert. + :type vectors: list[dict] + :param namespace: The namespace to upsert the vectors into. + :type namespace: str + """ params = { "vectors": vectors, "namespace": namespace, @@ -865,6 +1070,19 @@ class PineconeIndex(BaseIndex): region: str, metric: str = "dotproduct", ): + """Asynchronously creates a new index in Pinecone. + + :param name: The name of the index to create. + :type name: str + :param dimension: The dimension of the index. + :type dimension: int + :param cloud: The cloud provider to create the index on. + :type cloud: str + :param region: The region to create the index in. + :type region: str + :param metric: The metric to use for the index, defaults to "dotproduct". + :type metric: str, optional + """ params = { "name": name, "dimension": dimension, @@ -880,6 +1098,13 @@ class PineconeIndex(BaseIndex): return await response.json(content_type=None) async def _async_delete(self, ids: list[str], namespace: str = ""): + """Asynchronously deletes vectors from the index. + + :param ids: The IDs of the vectors to delete. + :type ids: list[str] + :param namespace: The namespace to delete the vectors from. + :type namespace: str + """ params = { "ids": ids, "namespace": namespace, @@ -900,6 +1125,11 @@ class PineconeIndex(BaseIndex): return await response.json(content_type=None) async def _async_describe_index(self, name: str): + """Asynchronously describes the index. 
+ + :param name: The name of the index to describe. + :type name: str + """ async with aiohttp.ClientSession() as session: async with session.get( f"{self.base_url}/indexes/{name}", diff --git a/semantic_router/index/postgres.py b/semantic_router/index/postgres.py index 362741efc1bc2f48f0039f31b5f7b449bbf672c4..5700e9a40b13f341e6fa2d8066e0b0b00c67b71a 100644 --- a/semantic_router/index/postgres.py +++ b/semantic_router/index/postgres.py @@ -15,8 +15,7 @@ if TYPE_CHECKING: class MetricPgVecOperatorMap(Enum): - """ - Enum to map the metric to PostgreSQL vector operators. + """Enum to map the metric to PostgreSQL vector operators. """ cosine = "<=>" @@ -26,8 +25,7 @@ class MetricPgVecOperatorMap(Enum): def parse_vector(vector_str: Union[str, Any]) -> List[float]: - """ - Parses a vector from a string or other representation. + """Parses a vector from a string or other representation. :param vector_str: The string or object representation of a vector. :type vector_str: Union[str, Any] @@ -42,8 +40,7 @@ def parse_vector(vector_str: Union[str, Any]) -> List[float]: def clean_route_name(route_name: str) -> str: - """ - Cleans and formats the route name by stripping spaces and replacing them with hyphens. + """Cleans and formats the route name by stripping spaces and replacing them with hyphens. :param route_name: The original route name. :type route_name: str @@ -54,8 +51,7 @@ def clean_route_name(route_name: str) -> str: class PostgresIndexRecord(BaseModel): - """ - Model to represent a record in the Postgres index. + """Model to represent a record in the Postgres index. """ id: str = "" @@ -64,8 +60,7 @@ class PostgresIndexRecord(BaseModel): vector: List[float] def __init__(self, **data) -> None: - """ - Initializes a new Postgres index record with given data. + """Initializes a new Postgres index record with given data. :param data: Field values for the record. :type data: dict @@ -81,8 +76,7 @@ class PostgresIndexRecord(BaseModel): self.id = clean_route + "#" + str(hashed_uuid) def to_dict(self) -> Dict: - """ - Converts the record to a dictionary. + """Converts the record to a dictionary. :return: A dictionary representation of the record. :rtype: Dict @@ -96,8 +90,7 @@ class PostgresIndexRecord(BaseModel): class PostgresIndex(BaseIndex): - """ - Postgres implementation of Index. + """Postgres implementation of Index. """ connection_string: Optional[str] = None @@ -118,8 +111,7 @@ class PostgresIndex(BaseIndex): namespace: Optional[str] = "", dimensions: int | None = None, ): - """ - Initializes the Postgres index with the specified parameters. + """Initializes the Postgres index with the specified parameters. :param connection_string: The connection string for the PostgreSQL database. :type connection_string: Optional[str] @@ -170,8 +162,7 @@ class PostgresIndex(BaseIndex): return f"{self.index_prefix}{self.index_name}" def _get_metric_operator(self) -> str: - """ - Returns the PostgreSQL operator for the specified metric. + """Returns the PostgreSQL operator for the specified metric. :return: The PostgreSQL operator. :rtype: str @@ -179,8 +170,7 @@ class PostgresIndex(BaseIndex): return MetricPgVecOperatorMap[self.metric.value].value def _get_score_query(self, embeddings_str: str) -> str: - """ - Creates the select statement required to return the embeddings distance. + """Creates the select statement required to return the embeddings distance. :param embeddings_str: The string representation of the embeddings. 
:type embeddings_str: str @@ -200,8 +190,7 @@ class PostgresIndex(BaseIndex): raise ValueError(f"Unsupported metric: {self.metric}") def setup_index(self) -> None: - """ - Sets up the index by creating the table and vector extension if they do not exist. + """Sets up the index by creating the table and vector extension if they do not exist. :raises ValueError: If the existing table's vector dimensions do not match the expected dimensions. :raises TypeError: If the database connection is not established. @@ -229,8 +218,8 @@ class PostgresIndex(BaseIndex): self.conn.commit() def _check_embeddings_dimensions(self) -> bool: - """ - Checks if the length of the vector embeddings in the table matches the expected dimensions, or if no table exists. + """Checks if the length of the vector embeddings in the table matches the expected + dimensions, or if no table exists. :return: True if the dimensions match or the table does not exist, False otherwise. :rtype: bool @@ -275,8 +264,7 @@ class PostgresIndex(BaseIndex): metadata_list: List[Dict[str, Any]] = [], **kwargs, ) -> None: - """ - Adds vectors to the index. + """Adds records to the index. :param embeddings: A list of vector embeddings to add. :type embeddings: List[List[float]] @@ -284,6 +272,10 @@ class PostgresIndex(BaseIndex): :type routes: List[str] :param utterances: A list of utterances corresponding to the embeddings. :type utterances: List[Any] + :param function_schemas: A list of function schemas corresponding to the embeddings. + :type function_schemas: Optional[List[Dict[str, Any]]] + :param metadata_list: A list of metadata corresponding to the embeddings. + :type metadata_list: List[Dict[str, Any]] :raises ValueError: If the vector embeddings being added do not match the expected dimensions. :raises TypeError: If the database connection is not established. """ @@ -310,8 +302,7 @@ class PostgresIndex(BaseIndex): self.conn.commit() def delete(self, route_name: str) -> None: - """ - Deletes records with the specified route name. + """Deletes records with the specified route name. :param route_name: The name of the route to delete records for. :type route_name: str @@ -325,8 +316,7 @@ class PostgresIndex(BaseIndex): self.conn.commit() def describe(self) -> IndexConfig: - """ - Describes the index by returning its type, dimensions, and total vector count. + """Describes the index by returning its type, dimensions, and total vector count. :return: An IndexConfig object containing the index's type, dimensions, and total vector count. :rtype: IndexConfig @@ -353,8 +343,10 @@ class PostgresIndex(BaseIndex): ) def is_ready(self) -> bool: - """ - Checks if the index is ready to be used. + """Checks if the index is ready to be used. + + :return: True if the index is ready, False otherwise. + :rtype: bool """ return isinstance(self.conn, psycopg2.extensions.connection) @@ -365,8 +357,7 @@ class PostgresIndex(BaseIndex): route_filter: Optional[List[str]] = None, sparse_vector: dict[int, float] | SparseEmbedding | None = None, ) -> Tuple[np.ndarray, List[str]]: - """ - Searches the index for the query vector and returns the top_k results. + """Searches the index for the query vector and returns the top_k results. :param vector: The query vector. :type vector: np.ndarray @@ -374,6 +365,8 @@ class PostgresIndex(BaseIndex): :type top_k: int :param route_filter: Optional list of routes to filter the results by. :type route_filter: Optional[List[str]] + :param sparse_vector: Optional sparse vector to filter the results by. 
+ :type sparse_vector: dict[int, float] | SparseEmbedding | None :return: A tuple containing the scores and routes of the top_k results. :rtype: Tuple[np.ndarray, List[str]] :raises TypeError: If the database connection is not established. @@ -396,8 +389,7 @@ class PostgresIndex(BaseIndex): ] def _get_route_ids(self, route_name: str): - """ - Retrieves all vector IDs for a specific route. + """Retrieves all vector IDs for a specific route. :param route_name: The name of the route to retrieve IDs for. :type route_name: str @@ -411,8 +403,7 @@ class PostgresIndex(BaseIndex): def _get_all( self, route_name: Optional[str] = None, include_metadata: bool = False ): - """ - Retrieves all vector IDs and optionally metadata from the Postgres index. + """Retrieves all vector IDs and optionally metadata from the Postgres index. :param route_name: Optional route name to filter the results by. :type route_name: Optional[str] @@ -448,8 +439,7 @@ class PostgresIndex(BaseIndex): return all_vector_ids, metadata def delete_all(self): - """ - Deletes all records from the Postgres index. + """Deletes all records from the Postgres index. :raises TypeError: If the database connection is not established. """ @@ -461,8 +451,7 @@ class PostgresIndex(BaseIndex): self.conn.commit() def delete_index(self) -> None: - """ - Deletes the entire table for the index. + """Deletes the entire table for the index. :raises TypeError: If the database connection is not established. """ @@ -474,14 +463,25 @@ class PostgresIndex(BaseIndex): self.conn.commit() def aget_routes(self): + """Asynchronously get all routes from the index. + + Not yet implemented for PostgresIndex. + + :return: A list of routes. + :rtype: List[str] + """ raise NotImplementedError("Async get is not implemented for PostgresIndex.") def _write_config(self, config: ConfigParameter): + """Write the config to the index. + + :param config: The config to write to the index. + :type config: ConfigParameter + """ logger.warning("No config is written for PostgresIndex.") def __len__(self): - """ - Returns the total number of vectors in the index. + """Returns the total number of vectors in the index. :return: The total number of vectors. :rtype: int @@ -498,8 +498,7 @@ class PostgresIndex(BaseIndex): return count[0] class Config: - """ - Configuration for the Pydantic BaseModel. + """Configuration for the Pydantic BaseModel. """ arbitrary_types_allowed = True diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py index 42ef48c0111b260d0d9a238a6b105da76806ce04..7e6f0748668fef90d28135a625ef697b152a0fe9 100644 --- a/semantic_router/index/qdrant.py +++ b/semantic_router/index/qdrant.py @@ -95,6 +95,11 @@ class QdrantIndex(BaseIndex): self.client, self.aclient = self._initialize_clients() def _initialize_clients(self): + """Initialize the clients for the Qdrant index. + + :return: A tuple of the sync and async clients. + :rtype: Tuple[QdrantClient, Optional[AsyncQdrantClient]] + """ try: from qdrant_client import AsyncQdrantClient, QdrantClient @@ -142,6 +147,11 @@ class QdrantIndex(BaseIndex): ) from e def _init_collection(self) -> None: + """Initialize the collection for the Qdrant index. + + :return: None + :rtype: None + """ from qdrant_client import QdrantClient, models self.client: QdrantClient @@ -160,6 +170,11 @@ class QdrantIndex(BaseIndex): ) def _remove_and_sync(self, routes_to_delete: dict): + """Remove and sync the index. + + :param routes_to_delete: The routes to delete. 
+ :type routes_to_delete: dict + """ logger.error("Sync remove is not implemented for QdrantIndex.") def add( @@ -172,6 +187,21 @@ class QdrantIndex(BaseIndex): batch_size: int = DEFAULT_UPLOAD_BATCH_SIZE, **kwargs, ): + """Add records to the index. + + :param embeddings: The embeddings to add. + :type embeddings: List[List[float]] + :param routes: The routes to add. + :type routes: List[str] + :param utterances: The utterances to add. + :type utterances: List[str] + :param function_schemas: The function schemas to add. + :type function_schemas: Optional[List[Dict[str, Any]]] + :param metadata_list: The metadata to add. + :type metadata_list: List[Dict[str, Any]] + :param batch_size: The batch size to use for the upload. + :type batch_size: int + """ self.dimensions = self.dimensions or len(embeddings[0]) self._init_collection() @@ -239,6 +269,11 @@ class QdrantIndex(BaseIndex): return utterances def delete(self, route_name: str): + """Delete records from the index. + + :param route_name: The name of the route to delete. + :type route_name: str + """ from qdrant_client import models self.client.delete( @@ -254,6 +289,11 @@ class QdrantIndex(BaseIndex): ) def describe(self) -> IndexConfig: + """Describe the index. + + :return: The index configuration. + :rtype: IndexConfig + """ collection_info = self.client.get_collection(self.index_name) return IndexConfig( @@ -263,8 +303,10 @@ class QdrantIndex(BaseIndex): ) def is_ready(self) -> bool: - """ - Checks if the index is ready to be used. + """Checks if the index is ready to be used. + + :return: True if the index is ready, False otherwise. + :rtype: bool """ return self.client.collection_exists(self.index_name) @@ -275,6 +317,19 @@ class QdrantIndex(BaseIndex): route_filter: Optional[List[str]] = None, sparse_vector: dict[int, float] | SparseEmbedding | None = None, ) -> Tuple[np.ndarray, List[str]]: + """Query the index. + + :param vector: The vector to query. + :type vector: np.ndarray + :param top_k: The number of results to return. + :type top_k: int + :param route_filter: The route filter to apply. + :type route_filter: Optional[List[str]] + :param sparse_vector: The sparse vector to query. + :type sparse_vector: dict[int, float] | SparseEmbedding | None + :return: A tuple of the scores and route names. + :rtype: Tuple[np.ndarray, List[str]] + """ from qdrant_client import QdrantClient, models self.client: QdrantClient @@ -309,6 +364,19 @@ class QdrantIndex(BaseIndex): route_filter: Optional[List[str]] = None, sparse_vector: dict[int, float] | SparseEmbedding | None = None, ) -> Tuple[np.ndarray, List[str]]: + """Asynchronously query the index. + + :param vector: The vector to query. + :type vector: np.ndarray + :param top_k: The number of results to return. + :type top_k: int + :param route_filter: The route filter to apply. + :type route_filter: Optional[List[str]] + :param sparse_vector: The sparse vector to query. + :type sparse_vector: dict[int, float] | SparseEmbedding | None + :return: A tuple of the scores and route names. + :rtype: Tuple[np.ndarray, List[str]] + """ from qdrant_client import AsyncQdrantClient, models self.aclient: Optional[AsyncQdrantClient] @@ -341,12 +409,29 @@ class QdrantIndex(BaseIndex): return np.array(scores), route_names def aget_routes(self): + """Asynchronously get all routes from the index. + + :return: A list of routes. + :rtype: List[str] + """ logger.error("Sync remove is not implemented for QdrantIndex.") def delete_index(self): + """Delete the index. 
+ + :return: None + :rtype: None + """ self.client.delete_collection(self.index_name) def convert_metric(self, metric: Metric): + """Convert the metric to a Qdrant distance metric. + + :param metric: The metric to convert. + :type metric: Metric + :return: The converted metric. + :rtype: Distance + """ from qdrant_client.models import Distance mapping = { @@ -362,6 +447,11 @@ class QdrantIndex(BaseIndex): return mapping[metric] def _write_config(self, config: ConfigParameter): + """Write the config to the index. + + :param config: The config to write to the index. + :type config: ConfigParameter + """ logger.warning("No config is written for QdrantIndex.") def __len__(self): diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py index ca3545fecc95857bf5d1a2fe076c15f3fe5776bd..af43913317cb0b14b45378919ee5ef01cee57a40 100644 --- a/semantic_router/routers/base.py +++ b/semantic_router/routers/base.py @@ -65,9 +65,7 @@ def is_valid(layer_config: str) -> bool: class RouterConfig: - """ - Generates a RouterConfig object that can be used for initializing a - Routers. + """Generates a RouterConfig object that can be used for initializing routers. """ routes: List[Route] = Field(default_factory=list) @@ -81,6 +79,15 @@ class RouterConfig: encoder_type: str = "openai", encoder_name: Optional[str] = None, ): + """Initialize a RouterConfig object. + + :param routes: A list of routes. + :type routes: List[Route] + :param encoder_type: The type of encoder to use. + :type encoder_type: str + :param encoder_name: The name of the encoder to use. + :type encoder_name: Optional[str] + """ self.encoder_type = encoder_type if encoder_name is None: for encode_type in EncoderType: @@ -99,6 +106,12 @@ class RouterConfig: @classmethod def from_file(cls, path: str) -> "RouterConfig": + """Initialize a RouterConfig from a file. Expects a JSON or YAML file with file + extension .json, .yaml, or .yml. + + :param path: The path to the file to load the RouterConfig from. + :type path: str + """ logger.info(f"Loading route config from {path}") _, ext = os.path.splitext(path) with open(path, "r") as f: @@ -206,6 +219,11 @@ class RouterConfig: ) def to_dict(self) -> Dict[str, Any]: + """Convert the RouterConfig to a dictionary. + + :return: A dictionary representation of the RouterConfig. + :rtype: Dict[str, Any] + """ return { "encoder_type": self.encoder_type, "encoder_name": self.encoder_name, @@ -213,7 +231,11 @@ class RouterConfig: } def to_file(self, path: str): - """Save the routes to a file in JSON or YAML format""" + """Save the routes to a file in JSON or YAML format. + + :param path: The path to save the RouterConfig to. + :type path: str + """ logger.info(f"Saving route config to {path}") _, ext = os.path.splitext(path) @@ -266,6 +288,13 @@ class RouterConfig: logger.info(f"Added route `{route.name}`") def get(self, name: str) -> Optional[Route]: + """Get a route from the RouterConfig by name. + + :param name: The name of the route to get. + :type name: str + :return: The route if found, otherwise None. + :rtype: Optional[Route] + """ for route in self.routes: if route.name == name: return route @@ -273,6 +302,11 @@ class RouterConfig: return None def remove(self, name: str): + """Remove a route from the RouterConfig by name. + + :param name: The name of the route to remove. 
+ :type name: str + """ if name not in [route.name for route in self.routes]: logger.error(f"Route `{name}` not found") else: @@ -280,6 +314,11 @@ class RouterConfig: logger.info(f"Removed route `{name}`") def get_hash(self) -> ConfigParameter: + """Get the hash of the RouterConfig. Used for syncing. + + :return: The hash of the RouterConfig. + :rtype: ConfigParameter + """ layer = self.to_dict() return ConfigParameter( field="sr_hash", @@ -288,6 +327,13 @@ class RouterConfig: def xq_reshape(xq: List[float] | np.ndarray) -> np.ndarray: + """Reshape the query vector to be a 2D numpy array. + + :param xq: The query vector. + :type xq: List[float] | np.ndarray + :return: The reshaped query vector. + :rtype: np.ndarray + """ # convert to numpy array if not already if not isinstance(xq, np.ndarray): xq = np.array(xq) @@ -302,6 +348,9 @@ def xq_reshape(xq: List[float] | np.ndarray) -> np.ndarray: class BaseRouter(BaseModel): + """Base class for all routers. + """ + encoder: DenseEncoder = Field(default_factory=OpenAIEncoder) sparse_encoder: Optional[SparseEncoder] = Field(default=None) index: BaseIndex = Field(default_factory=BaseIndex) @@ -327,6 +376,26 @@ class BaseRouter(BaseModel): aggregation: str = "mean", auto_sync: Optional[str] = None, ): + """Initialize a BaseRouter object. Expected to be used as a base class only, + not directly instantiated. + + :param encoder: The encoder to use. + :type encoder: Optional[DenseEncoder] + :param sparse_encoder: The sparse encoder to use. + :type sparse_encoder: Optional[SparseEncoder] + :param llm: The LLM to use. + :type llm: Optional[BaseLLM] + :param routes: The routes to use. + :type routes: Optional[List[Route]] + :param index: The index to use. + :type index: Optional[BaseIndex] + :param top_k: The number of routes to return. + :type top_k: int + :param aggregation: The aggregation method to use. + :type aggregation: str + :param auto_sync: The auto sync mode to use. + :type auto_sync: Optional[str] + """ routes = routes.copy() if routes else [] super().__init__( encoder=encoder, @@ -365,6 +434,13 @@ class BaseRouter(BaseModel): self._init_index_state() def _get_index(self, index: Optional[BaseIndex]) -> BaseIndex: + """Get the index to use. + + :param index: The index to use. + :type index: Optional[BaseIndex] + :return: The index to use. + :rtype: BaseIndex + """ if index is None: logger.warning("No index provided. Using default LocalIndex.") index = LocalIndex() @@ -373,6 +449,13 @@ class BaseRouter(BaseModel): return index def _get_encoder(self, encoder: Optional[DenseEncoder]) -> DenseEncoder: + """Get the dense encoder to be used for creating dense vector embeddings. + + :param encoder: The encoder to use. + :type encoder: Optional[DenseEncoder] + :return: The encoder to use. + :rtype: DenseEncoder + """ if encoder is None: logger.warning("No encoder provided. Using default OpenAIEncoder.") encoder = OpenAIEncoder() @@ -383,6 +466,13 @@ class BaseRouter(BaseModel): def _get_sparse_encoder( self, sparse_encoder: Optional[SparseEncoder] ) -> Optional[SparseEncoder]: + """Get the sparse encoder to be used for creating sparse vector embeddings. + + :param sparse_encoder: The sparse encoder to use. + :type sparse_encoder: Optional[SparseEncoder] + :return: The sparse encoder to use. 
+ :rtype: Optional[SparseEncoder] + """ if sparse_encoder is None: return None raise NotImplementedError( @@ -390,7 +480,8 @@ class BaseRouter(BaseModel): ) def _init_index_state(self): - """Initializes an index (where required) and runs auto_sync if active.""" + """Initializes an index (where required) and runs auto_sync if active. + """ # initialize index now, check if we need dimensions if self.index.dimensions is None: dims = len(self.encoder(["test"])[0]) @@ -429,6 +520,13 @@ class BaseRouter(BaseModel): ) def check_for_matching_routes(self, top_class: str) -> Optional[Route]: + """Check for a matching route in the routes list. + + :param top_class: The top class to check for. + :type top_class: str + :return: The matching route if found, otherwise None. + :rtype: Optional[Route] + """ matching_route = next( (route for route in self.routes if route.name == top_class), None ) @@ -447,6 +545,19 @@ class BaseRouter(BaseModel): simulate_static: bool = False, route_filter: Optional[List[str]] = None, ) -> RouteChoice: + """Call the router to get a route choice. + + :param text: The text to route. + :type text: Optional[str] + :param vector: The vector to route. + :type vector: Optional[List[float] | np.ndarray] + :param simulate_static: Whether to simulate a static route. + :type simulate_static: bool + :param route_filter: The route filter to use. + :type route_filter: Optional[List[str]] + :return: The route choice. + :rtype: RouteChoice + """ if not self.index.is_ready(): raise ValueError("Index is not ready.") # if no vector provided, encode text to get vector @@ -505,6 +616,20 @@ class BaseRouter(BaseModel): simulate_static: bool = False, route_filter: Optional[List[str]] = None, ) -> RouteChoice: + """Asynchronously call the router to get a route choice. + + :param text: The text to route. + :type text: Optional[str] + :param vector: The vector to route. + :type vector: Optional[List[float] | np.ndarray] + :param simulate_static: Whether to simulate a static route (ie avoid dynamic route + LLM calls during fit or evaluate). + :type simulate_static: bool + :param route_filter: The route filter to use. + :type route_filter: Optional[List[str]] + :return: The route choice. + :rtype: RouteChoice + """ if not self.index.is_ready(): # TODO: need async version for qdrant raise ValueError("Index is not ready.") @@ -806,10 +931,15 @@ class BaseRouter(BaseModel): self.routes = new_routes def _check_threshold(self, scores: List[float], route: Optional[Route]) -> bool: + """Check if the route's score passes the specified threshold. + + :param scores: The scores to check. + :type scores: List[float] + :param route: The route to check. + :type route: Optional[Route] + :return: True if the route's score passes the threshold, otherwise False. + :rtype: bool """ - Check if the route's score passes the specified threshold. - """ - # TODO JB: do we need this? if route is None: return False threshold = ( @@ -828,6 +958,13 @@ class BaseRouter(BaseModel): @classmethod def from_json(cls, file_path: str): + """Load a RouterConfig from a JSON file. + + :param file_path: The path to the JSON file. + :type file_path: str + :return: The RouterConfig object. + :rtype: RouterConfig + """ config = RouterConfig.from_file(file_path) encoder = AutoEncoder(type=config.encoder_type, name=config.encoder_name).model if isinstance(encoder, DenseEncoder): @@ -837,6 +974,13 @@ class BaseRouter(BaseModel): @classmethod def from_yaml(cls, file_path: str): + """Load a RouterConfig from a YAML file. 
+ + :param file_path: The path to the YAML file. + :type file_path: str + :return: The RouterConfig object. + :rtype: RouterConfig + """ config = RouterConfig.from_file(file_path) encoder = AutoEncoder(type=config.encoder_type, name=config.encoder_name).model if isinstance(encoder, DenseEncoder): @@ -846,6 +990,13 @@ class BaseRouter(BaseModel): @classmethod def from_config(cls, config: RouterConfig, index: Optional[BaseIndex] = None): + """Create a Router from a RouterConfig object. + + :param config: The RouterConfig object. + :type config: RouterConfig + :param index: The index to use. + :type index: Optional[BaseIndex] + """ encoder = AutoEncoder(type=config.encoder_type, name=config.encoder_name).model if isinstance(encoder, DenseEncoder): return cls(encoder=encoder, routes=config.routes, index=index) @@ -885,6 +1036,13 @@ class BaseRouter(BaseModel): The name must exist within the local SemanticRouter, if not a KeyError will be raised. + + :param name: The name of the route to update. + :type name: str + :param threshold: The threshold to update. + :type threshold: Optional[float] + :param utterances: The utterances to update. + :type utterances: Optional[List[str]] """ # TODO JB: should modify update to take a Route object current_local_hash = self._get_hash() @@ -958,7 +1116,10 @@ class BaseRouter(BaseModel): ) def _refresh_routes(self): - """Pulls out the latest routes from the index.""" + """Pulls out the latest routes from the index. + + Not yet implemented for BaseRouter. + """ raise NotImplementedError("This method has not yet been implemented.") route_mapping = {route.name: route for route in self.routes} index_routes = self.index.get_utterances() @@ -975,16 +1136,31 @@ class BaseRouter(BaseModel): self.routes.append(route) def _get_hash(self) -> ConfigParameter: + """Get the hash of the current routes. + + :return: The hash of the current routes. + :rtype: ConfigParameter + """ config = self.to_config() return config.get_hash() def _write_hash(self) -> ConfigParameter: + """Write the hash of the current routes to the index. + + :return: The hash of the current routes. + :rtype: ConfigParameter + """ config = self.to_config() hash_config = config.get_hash() self.index._write_config(config=hash_config) return hash_config async def _async_write_hash(self) -> ConfigParameter: + """Write the hash of the current routes to the index asynchronously. + + :return: The hash of the current routes. + :rtype: ConfigParameter + """ config = self.to_config() hash_config = config.get_hash() await self.index._async_write_config(config=hash_config) @@ -1042,6 +1218,12 @@ class BaseRouter(BaseModel): This diff tells us that the remote has "route2: utterance3" and "route2: utterance4", which do not exist locally. + + :param include_metadata: Whether to include metadata in the diff. + :type include_metadata: bool + :return: A list of strings showing the difference between the local and remote + utterances. + :rtype: List[str] """ # first we get remote and local utterances remote_utterances = self.index.get_utterances(include_metadata=include_metadata) @@ -1072,6 +1254,12 @@ class BaseRouter(BaseModel): This diff tells us that the remote has "route2: utterance3" and "route2: utterance4", which do not exist locally. + + :param include_metadata: Whether to include metadata in the diff. + :type include_metadata: bool + :return: A list of strings showing the difference between the local and remote + utterances. 
@@ -1072,6 +1254,12 @@ class BaseRouter(BaseModel):

 This diff tells us that the remote has "route2: utterance3" and
 "route2: utterance4", which do not exist locally.
+
+ :param include_metadata: Whether to include metadata in the diff.
+ :type include_metadata: bool
+ :return: A list of strings showing the difference between the local and remote
+ utterances.
+ :rtype: List[str]
 """
 # first we get remote and local utterances
 remote_utterances = await self.index.aget_utterances(
@@ -1087,6 +1275,15 @@ class BaseRouter(BaseModel):
 def _extract_routes_details(
 self, routes: List[Route], include_metadata: bool = False
 ) -> Tuple:
+ """Extract the routes details.
+
+ :param routes: The routes to extract the details from.
+ :type routes: List[Route]
+ :param include_metadata: Whether to include metadata in the details.
+ :type include_metadata: bool
+ :return: A tuple of the route names, utterances, and function schemas.
+ """
+
 route_names = [route.name for route in routes for _ in route.utterances]
 utterances = [utterance for route in routes for utterance in route.utterances]
 function_schemas = [
@@ -1131,6 +1328,13 @@ class BaseRouter(BaseModel):
 raise NotImplementedError("This method should be implemented by subclasses.")

 def _set_aggregation_method(self, aggregation: str = "sum"):
+ """Set the aggregation method.
+
+ :param aggregation: The aggregation method to use.
+ :type aggregation: str
+ :return: The aggregation method.
+ :rtype: Callable
+ """
 # TODO is this really needed?
 if aggregation == "sum":
 return lambda x: sum(x)
@@ -1205,6 +1409,13 @@ class BaseRouter(BaseModel):
 return "", []

 def get(self, name: str) -> Optional[Route]:
+ """Get a route by name.
+
+ :param name: The name of the route to get.
+ :type name: str
+ :return: The route.
+ :rtype: Optional[Route]
+ """
 for route in self.routes:
 if route.name == name:
 return route
@@ -1215,6 +1426,14 @@ class BaseRouter(BaseModel):
 def _semantic_classify_multiple_routes(
 self, query_results: List[Dict]
 ) -> List[Tuple[str, float]]:
+ """Classify the query results into multiple routes based on the highest total score.
+
+ :param query_results: The query results to classify. Expected format is a list of
+ dictionaries with "route" and "score" keys.
+ :type query_results: List[Dict]
+ :return: A list of tuples containing the route name and its associated scores.
+ :rtype: List[Tuple[str, float]]
+ """
 scores_by_class = self.group_scores_by_class(query_results)

 # Filter classes based on threshold and find max score for each
@@ -1238,6 +1457,14 @@ class BaseRouter(BaseModel):
 def group_scores_by_class(
 self, query_results: List[Dict]
 ) -> Dict[str, List[float]]:
+ """Group the scores by class.
+
+ :param query_results: The query results to group. Expected format is a list of
+ dictionaries with "route" and "score" keys.
+ :type query_results: List[Dict]
+ :return: A dictionary of route names and their associated scores.
+ :rtype: Dict[str, List[float]]
+ """
 scores_by_class: Dict[str, List[float]] = {}
 for result in query_results:
 score = result["score"]
@@ -1251,6 +1478,14 @@ class BaseRouter(BaseModel):
 async def async_group_scores_by_class(
 self, query_results: List[Dict]
 ) -> Dict[str, List[float]]:
+ """Group the scores by class asynchronously.
+
+ :param query_results: The query results to group. Expected format is a list of
+ dictionaries with "route" and "score" keys.
+ :type query_results: List[Dict]
+ :return: A dictionary of route names and their associated scores.
+ :rtype: Dict[str, List[float]]
+ """
 scores_by_class: Dict[str, List[float]] = {}
 for result in query_results:
 score = result["score"]
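# --- usage sketch (illustrative only, not part of the patch) --------------------
# The query_results shape expected by group_scores_by_class and
# _semantic_classify_multiple_routes; the scores below are made up.
query_results = [
    {"route": "chitchat", "score": 0.82},
    {"route": "chitchat", "score": 0.74},
    {"route": "politics", "score": 0.61},
]
grouped = router.group_scores_by_class(query_results)
# grouped == {"chitchat": [0.82, 0.74], "politics": [0.61]}
# with the default "sum" aggregation, chitchat totals 1.56 vs 0.61 for politics
# before the per-route threshold check
# --------------------------------------------------------------------------------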
@@ -1317,6 +1552,11 @@ class BaseRouter(BaseModel):
 logger.error(f"Route `{route_name}` not found")

 def to_config(self) -> RouterConfig:
+ """Convert the router to a RouterConfig object.
+
+ :return: The RouterConfig object.
+ :rtype: RouterConfig
+ """
 return RouterConfig(
 encoder_type=self.encoder.type,
 encoder_name=self.encoder.name,
@@ -1324,14 +1564,29 @@ class BaseRouter(BaseModel):
 )

 def to_json(self, file_path: str):
+ """Convert the router to a JSON file.
+
+ :param file_path: The path to the JSON file.
+ :type file_path: str
+ """
 config = self.to_config()
 config.to_file(file_path)

 def to_yaml(self, file_path: str):
+ """Convert the router to a YAML file.
+
+ :param file_path: The path to the YAML file.
+ :type file_path: str
+ """
 config = self.to_config()
 config.to_file(file_path)

 def get_thresholds(self) -> Dict[str, float]:
+ """Get the score thresholds for each route.
+
+ :return: A dictionary of route names and their associated thresholds.
+ :rtype: Dict[str, float]
+ """
 thresholds = {
 route.name: route.score_threshold or self.score_threshold or 0.0
 for route in self.routes
@@ -1346,7 +1601,21 @@ class BaseRouter(BaseModel):
 max_iter: int = 500,
 local_execution: bool = False,
 ):
+ """Fit the router to the data. Works best with a large number of examples for each
+ route and with many `None` utterances.
+
+ :param X: The input data.
+ :type X: List[str]
+ :param y: The output data (expected route names).
+ :type y: List[str]
+ :param batch_size: The batch size to use for fitting.
+ :type batch_size: int
+ :param max_iter: The maximum number of iterations to use for fitting.
+ :type max_iter: int
+ :param local_execution: Whether to execute the fitting locally.
+ :type local_execution: bool
+ """
 original_index = self.index
 if local_execution:
 # Switch to a local index for fitting
 from semantic_router.index.local import LocalIndex
@@ -1401,8 +1669,16 @@ class BaseRouter(BaseModel):
 self.index = original_index

 def evaluate(self, X: List[str], y: List[str], batch_size: int = 500) -> float:
- """
- Evaluate the accuracy of the route selection.
+ """Evaluate the accuracy of the route selection.
+
+ :param X: The input data.
+ :type X: List[str]
+ :param y: The output data (expected route names).
+ :type y: List[str]
+ :param batch_size: The batch size to use for evaluation.
+ :type batch_size: int
+ :return: The accuracy of the route selection.
+ :rtype: float
 """
 Xq: List[List[float]] = []
 for i in tqdm(range(0, len(X), batch_size), desc="Generating embeddings"):
@@ -1415,8 +1691,14 @@ class BaseRouter(BaseModel):
 def _vec_evaluate(
 self, Xq_d: Union[List[float], Any], y: List[str], **kwargs
 ) -> float:
- """
- Evaluate the accuracy of the route selection.
+ """Evaluate the accuracy of the route selection.
+
+ :param Xq_d: The embedded input data.
+ :type Xq_d: Union[List[float], Any]
+ :param y: The output data (expected route names).
+ :type y: List[str]
+ :return: The accuracy of the route selection.
+ :rtype: float
 """
 correct = 0
 for xq, target_route in zip(Xq_d, y):
@@ -1428,6 +1710,11 @@ class BaseRouter(BaseModel):
 return accuracy

 def _get_route_names(self) -> List[str]:
+ """Get the names of the routes.
+
+ :return: The names of the routes.
+ :rtype: List[str]
+ """
 return [route.name for route in self.routes]

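# --- usage sketch (illustrative only, not part of the patch) --------------------
# The fit/evaluate loop documented above. X holds utterances, y the route name
# expected for each one; the string "None" marking no-route inputs follows the
# convention hinted at in the fit docstring (treat it as an assumption here).
X = ["how are you?", "who should I vote for?", "what is the capital of France?"]
y = ["chitchat", "politics", "None"]

router.fit(X, y, batch_size=500, max_iter=500, local_execution=True)
print(router.evaluate(X, y))      # fraction of inputs routed to the expected name
print(router.get_thresholds())    # per-route score thresholds after fitting
router.to_json("router.json")     # persist encoder type/name and routes
# --------------------------------------------------------------------------------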
@@ -1435,7 +1722,15 @@ def threshold_random_search(
 route_layer: BaseRouter,
 search_range: Union[int, float],
 ) -> Dict[str, float]:
- """Performs a random search iteration given a route layer and a search range."""
+ """Performs a random search iteration given a route layer and a search range.
+
+ :param route_layer: The route layer to search.
+ :type route_layer: BaseRouter
+ :param search_range: The search range to use.
+ :type search_range: Union[int, float]
+ :return: A dictionary of route names and their associated thresholds.
+ :rtype: Dict[str, float]
+ """
 # extract the route names
 routes = route_layer.get_thresholds()
 route_names = list(routes.keys())
diff --git a/semantic_router/routers/hybrid.py b/semantic_router/routers/hybrid.py
index b7857783cd3ddbdcb8da1374256b9fbd3da6b337..f3e9107c2db0cfb6c51d637ef3f68e23b36c2094 100644
--- a/semantic_router/routers/hybrid.py
+++ b/semantic_router/routers/hybrid.py
@@ -38,6 +38,13 @@ class HybridRouter(BaseRouter):
 auto_sync: Optional[str] = None,
 alpha: float = 0.3,
 ):
+ """Initialize the HybridRouter.
+
+ :param encoder: The dense encoder to use.
+ :type encoder: DenseEncoder
+ :param sparse_encoder: The sparse encoder to use.
+ :type sparse_encoder: Optional[SparseEncoder]
+ """
 if index is None:
 logger.warning("No index provided. Using default HybridLocalIndex.")
 index = HybridLocalIndex()
@@ -153,6 +160,13 @@ class HybridRouter(BaseRouter):
 self._write_hash()

 def _get_index(self, index: Optional[BaseIndex]) -> BaseIndex:
+ """Get the index.
+
+ :param index: The index to get.
+ :type index: Optional[BaseIndex]
+ :return: The index.
+ :rtype: BaseIndex
+ """
 if index is None:
 logger.warning("No index provided. Using default HybridLocalIndex.")
 index = HybridLocalIndex()
@@ -163,6 +177,13 @@ class HybridRouter(BaseRouter):
 def _get_sparse_encoder(
 self, sparse_encoder: Optional[SparseEncoder]
 ) -> Optional[SparseEncoder]:
+ """Get the sparse encoder.
+
+ :param sparse_encoder: The sparse encoder to get.
+ :type sparse_encoder: Optional[SparseEncoder]
+ :return: The sparse encoder.
+ :rtype: Optional[SparseEncoder]
+ """
 if sparse_encoder is None:
 logger.warning("No sparse_encoder provided. Using default BM25Encoder.")
 sparse_encoder = BM25Encoder()
@@ -173,6 +194,11 @@ class HybridRouter(BaseRouter):
 def _encode(self, text: list[str]) -> tuple[np.ndarray, list[SparseEmbedding]]:
 """Given some text, generates dense and sparse embeddings, then scales them
 using the chosen alpha value.
+
+ :param text: The text to encode.
+ :type text: list[str]
+ :return: A tuple of the dense and sparse embeddings.
+ :rtype: tuple[np.ndarray, list[SparseEmbedding]]
 """
 if self.sparse_encoder is None:
 raise ValueError("self.sparse_encoder is not set.")
@@ -193,6 +219,11 @@ class HybridRouter(BaseRouter):
 ) -> tuple[np.ndarray, list[SparseEmbedding]]:
 """Given some text, generates dense and sparse embeddings, then scales them
 using the chosen alpha value.
+
+ :param text: The text to encode.
+ :type text: List[str]
+ :return: A tuple of the dense and sparse embeddings.
+ :rtype: tuple[np.ndarray, list[SparseEmbedding]]
 """
 if self.sparse_encoder is None:
 raise ValueError("self.sparse_encoder is not set.")
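# --- usage sketch (illustrative only, not part of the patch) --------------------
# Construction sketch for the HybridRouter documented above; `my_encoder` and
# `routes` are the placeholders from the earlier sketch. Without an explicit index
# or sparse_encoder the code falls back to HybridLocalIndex / BM25Encoder, logging
# a warning as shown in the hunks above.
from semantic_router.routers import HybridRouter

hybrid = HybridRouter(
    encoder=my_encoder,   # dense side
    routes=routes,
    alpha=0.3,            # balance between dense and sparse scores
)
choice = hybrid("how is the weather today?")
print(choice.name)
# --------------------------------------------------------------------------------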
@@ -216,6 +247,19 @@ class HybridRouter(BaseRouter):
 route_filter: Optional[List[str]] = None,
 sparse_vector: dict[int, float] | SparseEmbedding | None = None,
 ) -> RouteChoice:
+ """Call the HybridRouter.
+
+ :param text: The text to route.
+ :type text: Optional[str]
+ :param vector: The vector to route.
+ :type vector: Optional[List[float] | np.ndarray]
+ :param simulate_static: Whether to simulate a static route.
+ :type simulate_static: bool
+ :param route_filter: The route filter to use.
+ :type route_filter: Optional[List[str]]
+ :param sparse_vector: The sparse vector to use.
+ :type sparse_vector: dict[int, float] | SparseEmbedding | None
+ """
 if not self.index.is_ready():
 raise ValueError("Index is not ready.")
 potential_sparse_vector: List[SparseEmbedding] | None = None
@@ -258,6 +302,13 @@ class HybridRouter(BaseRouter):
 def _convex_scaling(
 self, dense: np.ndarray, sparse: list[SparseEmbedding]
 ) -> tuple[np.ndarray, list[SparseEmbedding]]:
+ """Convex scaling of the dense and sparse vectors.
+
+ :param dense: The dense vectors to scale.
+ :type dense: np.ndarray
+ :param sparse: The sparse vectors to scale.
+ :type sparse: list[SparseEmbedding]
+ """
 # TODO: better way to do this?
 sparse_dicts = [sparse_vec.to_dict() for sparse_vec in sparse]
 # scale sparse and dense vecs
@@ -279,6 +330,19 @@ class HybridRouter(BaseRouter):
 max_iter: int = 500,
 local_execution: bool = False,
 ):
+ """Fit the HybridRouter.
+
+ :param X: The input data.
+ :type X: List[str]
+ :param y: The output data (expected route names).
+ :type y: List[str]
+ :param batch_size: The batch size to use for fitting.
+ :type batch_size: int
+ :param max_iter: The maximum number of iterations to use for fitting.
+ :type max_iter: int
+ :param local_execution: Whether to execute the fitting locally.
+ :type local_execution: bool
+ """
 original_index = self.index
 if self.sparse_encoder is None:
 raise ValueError("Sparse encoder is not set.")
@@ -343,8 +407,16 @@ class HybridRouter(BaseRouter):
 self.index = original_index

 def evaluate(self, X: List[str], y: List[str], batch_size: int = 500) -> float:
- """
- Evaluate the accuracy of the route selection.
+ """Evaluate the accuracy of the route selection.
+
+ :param X: The input data.
+ :type X: List[str]
+ :param y: The output data (expected route names).
+ :type y: List[str]
+ :param batch_size: The batch size to use for evaluation.
+ :type batch_size: int
+ :return: The accuracy of the route selection.
+ :rtype: float
 """
 if self.sparse_encoder is None:
 raise ValueError("Sparse encoder is not set.")
@@ -365,8 +437,16 @@ class HybridRouter(BaseRouter):
 Xq_s: list[SparseEmbedding],
 y: List[str],
 ) -> float:
- """
- Evaluate the accuracy of the route selection.
+ """Evaluate the accuracy of the route selection.
+
+ :param Xq_d: The dense vectors to evaluate.
+ :type Xq_d: Union[List[float], Any]
+ :param Xq_s: The sparse vectors to evaluate.
+ :type Xq_s: list[SparseEmbedding]
+ :param y: The output data (expected route names).
+ :type y: List[str]
+ :return: The accuracy of the route selection.
+ :rtype: float
 """
 correct = 0
 for xq_d, xq_s, target_route in zip(Xq_d, Xq_s, y):
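# --- illustrative sketch (not part of the patch) --------------------------------
# The _convex_scaling body is elided above, so the exact weighting is not shown in
# this patch; the usual convex combination, assumed here, weights dense scores by
# alpha and sparse scores by (1 - alpha) before the two are combined.
import numpy as np

alpha = 0.3
dense = np.array([[0.2, 0.4, 0.4]])        # one dense query embedding
sparse = {3: 1.2, 17: 0.8}                 # sparse embedding as token id -> weight

scaled_dense = dense * alpha                                      # alpha * dense
scaled_sparse = {k: v * (1 - alpha) for k, v in sparse.items()}   # (1 - alpha) * sparse
# --------------------------------------------------------------------------------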
diff --git a/semantic_router/routers/semantic.py b/semantic_router/routers/semantic.py
index a468e0de7669ad7e6efdc44ad2fe4bf93608781f..19d9903c5cc937cb3f5ada7d3d8405dbf8d4454e 100644
--- a/semantic_router/routers/semantic.py
+++ b/semantic_router/routers/semantic.py
@@ -11,6 +11,8 @@ from semantic_router.utils.logger import logger


 class SemanticRouter(BaseRouter):
+ """A router that uses a dense encoder to encode routes and utterances.
+ """
 def __init__(
 self,
 encoder: Optional[DenseEncoder] = None,
@@ -34,13 +36,25 @@ class SemanticRouter(BaseRouter):
 )

 def _encode(self, text: list[str]) -> Any:
- """Given some text, encode it."""
+ """Given some text, encode it.
+
+ :param text: The text to encode.
+ :type text: list[str]
+ :return: The encoded text.
+ :rtype: Any
+ """
 # create query vector
 xq = np.array(self.encoder(text))
 return xq

 async def _async_encode(self, text: list[str]) -> Any:
- """Given some text, encode it."""
+ """Given some text, encode it.
+
+ :param text: The text to encode.
+ :type text: list[str]
+ :return: The encoded text.
+ :rtype: Any
+ """
 # create query vector
 xq = np.array(await self.encoder.acall(docs=text))
 return xq
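# --- usage sketch (illustrative only, not part of the patch) --------------------
# The async path above awaits self.encoder.acall(docs=text), so the router's async
# entry point must run inside an event loop. `router` is the SemanticRouter from
# the earlier sketch, and the method name acall is assumed from the asynchronous
# call documented in base.py.
import asyncio

async def main():
    choice = await router.acall("what do you think of the election?")
    print(choice.name)

asyncio.run(main())
# --------------------------------------------------------------------------------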