diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index 3d3910834f0912307ab69d1b21e5d1ef5bb38d6f..d07045e70692ffc2be2ddadcae0408f05b1a0cea 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -99,7 +99,11 @@ class BaseIndex(BaseModel): raise NotImplementedError("This method should be implemented by subclasses.") def _sync_index( - self, local_route_names: List[str], local_utterances: List[str], dimensions: int + self, + local_route_names: List[str], + local_utterances: List[str], + dimensions: int, + local_function_schemas: List[str] = None, # type: ignore ): """ Synchronize the local index with the remote index based on the specified mode. diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py index 802455cb77fbf1d18ac07929508b190684d984f4..65446f460fb3d6664e30f88616d8a900fccaca56 100644 --- a/semantic_router/index/local.py +++ b/semantic_router/index/local.py @@ -49,7 +49,11 @@ class LocalIndex(BaseIndex): logger.warning("Sync remove is not implemented for LocalIndex.") def _sync_index( - self, local_route_names: List[str], local_utterances: List[str], dimensions: int + self, + local_route_names: List[str], + local_utterances: List[str], + dimensions: int, + local_function_schemas: List[str] = None, # type: ignore ): if self.sync is not None: logger.error("Sync remove is not implemented for LocalIndex.") diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index da11ec6ea4df8ee762fb2d84b1a66a0de504d0c2..7686993cac4887006b115b856514bdefa36db8af 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -207,7 +207,11 @@ class PineconeIndex(BaseIndex): self.host = index_stats["host"] if index_stats else None def _sync_index( - self, local_route_names: List[str], local_utterances: List[str], dimensions: int + self, + local_route_names: List[str], + local_utterances: List[str], + dimensions: int, + local_function_schemas: List[str] = None, # type: ignore ): if self.index is None: self.dimensions = self.dimensions or dimensions diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py index 0fc6aa52c1ee8be1d769a8463ff68a6e31e3bfe5..5801c6debc0427e2c6da2fa4b17929541a5f1a1e 100644 --- a/semantic_router/index/qdrant.py +++ b/semantic_router/index/qdrant.py @@ -165,7 +165,11 @@ class QdrantIndex(BaseIndex): logger.error("Sync remove is not implemented for QdrantIndex.") def _sync_index( - self, local_route_names: List[str], local_utterances: List[str], dimensions: int + self, + local_route_names: List[str], + local_utterances: List[str], + dimensions: int, + local_function_schemas: List[str] = None, # type: ignore ): if self.sync is not None: logger.error("Sync remove is not implemented for QdrantIndex.") diff --git a/semantic_router/layer.py b/semantic_router/layer.py index bae67544940c9a30333ae7331a0abcf354873ab8..c16700bfb10981ced0e54e27ee2805ac028ecac3 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -180,7 +180,7 @@ class RouteLayer: self, encoder: Optional[BaseEncoder] = None, llm: Optional[BaseLLM] = None, - routes: Optional[List[Route]] = None, + routes: List[Route] = [], index: Optional[BaseIndex] = None, # type: ignore top_k: int = 5, aggregation: str = "sum", @@ -195,7 +195,7 @@ class RouteLayer: else: self.encoder = encoder self.llm = llm - self.routes: List[Route] = routes if routes is not None else [] + self.routes = routes if self.encoder.score_threshold is None: raise ValueError( "No score threshold provided for encoder. Please set the score threshold " @@ -216,15 +216,13 @@ class RouteLayer: for route in self.routes: if route.score_threshold is None: route.score_threshold = self.score_threshold + + if self.routes: + self._add_routes(routes=self.routes) + # if routes list has been passed, we initialize index now if self.index.sync: - # initialize index now - if len(self.routes) > 0: - self._add_and_sync_routes(routes=self.routes) - else: - self._add_and_sync_routes(routes=[]) - elif len(self.routes) > 0: - self._add_routes(routes=self.routes) + self._add_and_sync_routes(routes=self.routes) def check_for_matching_routes(self, top_class: str) -> Optional[Route]: matching_routes = [route for route in self.routes if route.name == top_class] @@ -482,32 +480,52 @@ class RouteLayer: def _add_routes(self, routes: List[Route]): # create embeddings for all routes - route_names, all_utterances, function_schemas = self._extract_routes_details( - routes - ) - embedded_utterances = self.encoder(all_utterances) + # route_names, all_utterances, function_schemas = self._extract_routes_details( + # routes + # ) + # embedded_utterances = self.encoder(all_utterances) # create route array # add everything to the index - self.index.add( - embeddings=embedded_utterances, - routes=route_names, - utterances=all_utterances, - function_schemas=function_schemas, - ) + if routes: + for route in routes: + logger.info(f"Adding `{route.name}` route") + embeddings = self.encoder(route.utterances) + if route.score_threshold is None: + route.score_threshold = self.score_threshold + + try: + self.index.add( + embeddings=embeddings, + routes=[route.name] * len(route.utterances), + utterances=route.utterances, + function_schemas=( + route.function_schemas * len(route.utterances) + if route.function_schemas + else [""] * len(route.utterances) # type: ignore + ), + ) + except Exception as e: + logger.error(f"index error: {e}") + raise Exception(f"index error: {e}") from e def _add_and_sync_routes(self, routes: List[Route]): # create embeddings for all routes and sync at startup with remote ones based on sync setting local_route_names, local_utterances, local_function_schemas = ( self._extract_routes_details(routes) ) + routes_to_add, routes_to_delete, layer_routes_dict = self.index._sync_index( local_route_names=local_route_names, local_utterances=local_utterances, + local_function_schemas=local_function_schemas, dimensions=len(self.encoder(["dummy"])[0]), ) layer_routes = [ - Route(name=route, utterances=layer_routes_dict[route]) + Route( + name=route, + utterances=layer_routes_dict[route], + ) for route in layer_routes_dict.keys() ] @@ -536,10 +554,9 @@ class RouteLayer: route_names = [route.name for route in routes for _ in route.utterances] utterances = [utterance for route in routes for utterance in route.utterances] function_schemas = [ - function_schema if function_schema is not None else "" + route.function_schemas if route.function_schemas is not None else "" for route in routes - if route.function_schemas is not None - for function_schema in route.function_schemas + for _ in route.utterances ] return route_names, utterances, function_schemas