From 51d2bc9a41893b4e95679c6a6e5311844313c9bf Mon Sep 17 00:00:00 2001 From: Ismail Ashraq <ismailashraq@Ismails-MacBook-Pro.local> Date: Tue, 16 Jul 2024 14:15:02 +0800 Subject: [PATCH] aquery method for local index --- semantic_router/index/local.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py index 7e32f3a8..1116ffe4 100644 --- a/semantic_router/index/local.py +++ b/semantic_router/index/local.py @@ -98,6 +98,35 @@ class LocalIndex(BaseIndex): scores, idx = top_scores(sim, top_k) route_names = [self.routes[i] for i in idx] return scores, route_names + + async def aquery( + self, + vector: np.ndarray, + top_k: int = 5, + route_filter: Optional[List[str]] = None, + ) -> Tuple[np.ndarray, List[str]]: + """ + Search the index for the query and return top_k results. + """ + if self.index is None or self.routes is None: + raise ValueError("Index or routes are not populated.") + if route_filter is not None: + filtered_index = [] + filtered_routes = [] + for route, vec in zip(self.routes, self.index): + if route in route_filter: + filtered_index.append(vec) + filtered_routes.append(route) + if not filtered_routes: + raise ValueError("No routes found matching the filter criteria.") + sim = similarity_matrix(vector, np.array(filtered_index)) + scores, idx = top_scores(sim, top_k) + route_names = [filtered_routes[i] for i in idx] + else: + sim = similarity_matrix(vector, self.index) + scores, idx = top_scores(sim, top_k) + route_names = [self.routes[i] for i in idx] + return scores, route_names def delete(self, route_name: str): """ -- GitLab