From 51d2bc9a41893b4e95679c6a6e5311844313c9bf Mon Sep 17 00:00:00 2001
From: Ismail Ashraq <ismailashraq@Ismails-MacBook-Pro.local>
Date: Tue, 16 Jul 2024 14:15:02 +0800
Subject: [PATCH] aquery method for local index

---
 semantic_router/index/local.py | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py
index 7e32f3a8..1116ffe4 100644
--- a/semantic_router/index/local.py
+++ b/semantic_router/index/local.py
@@ -98,6 +98,35 @@ class LocalIndex(BaseIndex):
             scores, idx = top_scores(sim, top_k)
             route_names = [self.routes[i] for i in idx]
         return scores, route_names
+    
+    async def aquery(
+        self,
+        vector: np.ndarray,
+        top_k: int = 5,
+        route_filter: Optional[List[str]] = None,
+    ) -> Tuple[np.ndarray, List[str]]:
+        """
+        Search the index for the query and return top_k results.
+        """
+        if self.index is None or self.routes is None:
+            raise ValueError("Index or routes are not populated.")
+        if route_filter is not None:
+            filtered_index = []
+            filtered_routes = []
+            for route, vec in zip(self.routes, self.index):
+                if route in route_filter:
+                    filtered_index.append(vec)
+                    filtered_routes.append(route)
+            if not filtered_routes:
+                raise ValueError("No routes found matching the filter criteria.")
+            sim = similarity_matrix(vector, np.array(filtered_index))
+            scores, idx = top_scores(sim, top_k)
+            route_names = [filtered_routes[i] for i in idx]
+        else:
+            sim = similarity_matrix(vector, self.index)
+            scores, idx = top_scores(sim, top_k)
+            route_names = [self.routes[i] for i in idx]
+        return scores, route_names
 
     def delete(self, route_name: str):
         """
-- 
GitLab