From 362e98ef89959f978c2653753e52722de09ddd29 Mon Sep 17 00:00:00 2001 From: jamescalam <james.briggs@hotmail.com> Date: Thu, 28 Nov 2024 16:11:41 +0100 Subject: [PATCH] fix: router test bug fixes --- semantic_router/index/local.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py index 83cc4f51..8f674c62 100644 --- a/semantic_router/index/local.py +++ b/semantic_router/index/local.py @@ -42,8 +42,23 @@ class LocalIndex(BaseIndex): self.routes = np.concatenate([self.routes, routes_arr]) self.utterances = np.concatenate([self.utterances, utterances_arr]) - def _remove_and_sync(self, routes_to_delete: dict): - logger.warning(f"Sync remove is not implemented for {self.__class__.__name__}.") + def _remove_and_sync(self, routes_to_delete: dict) -> np.ndarray: + if self.index is None or self.routes is None or self.utterances is None: + raise ValueError("Index, routes, or utterances are not populated.") + # TODO JB: implement routes and utterances as a numpy array + route_utterances = np.array([self.routes, self.utterances]).T + # initialize our mask with all true values (ie keep all) + mask = np.ones(len(route_utterances), dtype=bool) + for route, utterances in routes_to_delete.items(): + # TODO JB: we should be able to vectorize this? + for utterance in utterances: + mask &= ~((route_utterances[:, 0] == route) & (route_utterances[:, 1] == utterance)) + # apply the mask to index, routes, and utterances + self.index = self.index[mask] + self.routes = self.routes[mask] + self.utterances = self.utterances[mask] + # return what was removed + return route_utterances[~mask] def get_utterances(self) -> List[Utterance]: """ -- GitLab