From f517a558b7e4599e04e1ae668d223593d29af101 Mon Sep 17 00:00:00 2001
From: zahid-syed <zahid.s2618@gmail.com>
Date: Tue, 19 Mar 2024 16:14:41 -0400
Subject: [PATCH] Initial commit to review changes

---
 semantic_router/index/base.py     |  7 ++++++-
 semantic_router/index/local.py    | 30 ++++++++++++++++++++++++------
 semantic_router/index/pinecone.py | 14 +++++++++++++-
 semantic_router/layer.py          | 18 +++++++++++++-----
 4 files changed, 56 insertions(+), 13 deletions(-)

diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py
index 5271b897..f8a965df 100644
--- a/semantic_router/index/base.py
+++ b/semantic_router/index/base.py
@@ -43,7 +43,12 @@ class BaseIndex(BaseModel):
         """
         raise NotImplementedError("This method should be implemented by subclasses.")
 
-    def query(self, vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[str]]:
+    def query(
+        self,
+        vector: np.ndarray,
+        top_k: int = 5,
+        route_filter: Optional[List[str]] = None,
+    ) -> Tuple[np.ndarray, List[str]]:
         """
         Search the index for the query_vector and return top_k results.
         This method should be implemented by subclasses.
diff --git a/semantic_router/index/local.py b/semantic_router/index/local.py
index 81da71d1..4bf212dc 100644
--- a/semantic_router/index/local.py
+++ b/semantic_router/index/local.py
@@ -56,17 +56,35 @@ class LocalIndex(BaseIndex):
             "vectors": self.index.shape[0] if self.index is not None else 0,
         }
 
-    def query(self, vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[str]]:
+    def query(
+        self,
+        vector: np.ndarray,
+        top_k: int = 5,
+        route_filter: Optional[List[str]] = None,
+    ) -> Tuple[np.ndarray, List[str]]:
         """
         Search the index for the query and return top_k results.
         """
         if self.index is None or self.routes is None:
             raise ValueError("Index or routes are not populated.")
-        sim = similarity_matrix(vector, self.index)
-        # extract the index values of top scoring vectors
-        scores, idx = top_scores(sim, top_k)
-        # get routes from index values
-        route_names = self.routes[idx].copy()
+        if route_filter is not None:
+            print(f"Filtering routes with filter: {route_filter}")
+            filtered_index = []
+            filtered_routes = []
+            for route, vec in zip(self.routes, self.index):
+                if route in route_filter:
+                    filtered_index.append(vec)
+                    filtered_routes.append(route)
+            if not filtered_routes:
+                raise ValueError("No routes found matching the filter criteria.")
+            sim = similarity_matrix(vector, np.array(filtered_index))
+            scores, idx = top_scores(sim, top_k)
+            route_names = [filtered_routes[i] for i in idx]
+        else:
+            sim = similarity_matrix(vector, self.index)
+            scores, idx = top_scores(sim, top_k)
+            route_names = [self.routes[i] for i in idx]
+        print(f"Routes considered for similarity calculation: {route_names}")
         return scores, route_names
 
     def delete(self, route_name: str):
diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py
index 321dee30..500b78e9 100644
--- a/semantic_router/index/pinecone.py
+++ b/semantic_router/index/pinecone.py
@@ -219,17 +219,29 @@ class PineconeIndex(BaseIndex):
         else:
             raise ValueError("Index is None, cannot describe index stats.")
 
-    def query(self, vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[str]]:
+    def query(
+        self,
+        vector: np.ndarray,
+        top_k: int = 5,
+        route_filter: Optional[List[str]] = None,
+    ) -> Tuple[np.ndarray, List[str]]:
         if self.index is None:
             raise ValueError("Index is not populated.")
         query_vector_list = vector.tolist()
+        if route_filter is not None:
+            print(f"Filtering routes with filter: {route_filter}")
+            filter_query = {"sr_route": {"$in": route_filter}}
+        else:
+            filter_query = None
         results = self.index.query(
             vector=[query_vector_list],
             top_k=top_k,
+            filter=filter_query,
             include_metadata=True,
         )
         scores = [result["score"] for result in results["matches"]]
         route_names = [result["metadata"]["sr_route"] for result in results["matches"]]
+        print(f"Routes considered for similarity calculation: {route_names}")
         return np.array(scores), route_names
 
     def delete_index(self):
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 221de2be..2d5eee0c 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -232,6 +232,7 @@ class RouteLayer:
         text: Optional[str] = None,
         vector: Optional[List[float]] = None,
         simulate_static: bool = False,
+        route_filter: Optional[List[str]] = None,
     ) -> RouteChoice:
         # if no vector provided, encode text to get vector
         if vector is None:
@@ -239,7 +240,8 @@ class RouteLayer:
                 raise ValueError("Either text or vector must be provided")
             vector = self._encode(text=text)
 
-        route, top_class_scores = self._retrieve_top_route(vector)
+        route, top_class_scores = self._retrieve_top_route(vector, route_filter)
+        print(f"Selected route: {route.name if route else 'None'}")
         passed = self._check_threshold(top_class_scores, route)
 
         if passed and route is not None and not simulate_static:
@@ -271,14 +273,16 @@ class RouteLayer:
             return RouteChoice()
 
     def _retrieve_top_route(
-        self, vector: List[float]
+        self, vector: List[float], route_filter: Optional[List[str]] = None
     ) -> Tuple[Optional[Route], List[float]]:
         """
         Retrieve the top matching route based on the given vector.
         Returns a tuple of the route (if any) and the scores of the top class.
         """
         # get relevant results (scores and routes)
-        results = self._retrieve(xq=np.array(vector), top_k=self.top_k)
+        results = self._retrieve(
+            xq=np.array(vector), top_k=self.top_k, route_filter=route_filter
+        )
         # decide most relevant routes
         top_class, top_class_scores = self._semantic_classify(results)
         # TODO do we need this check?
@@ -397,10 +401,14 @@ class RouteLayer:
         xq = np.squeeze(xq)  # Reduce to 1d array.
         return xq
 
-    def _retrieve(self, xq: Any, top_k: int = 5) -> List[dict]:
+    def _retrieve(
+        self, xq: Any, top_k: int = 5, route_filter: Optional[List[str]] = None
+    ) -> List[dict]:
         """Given a query vector, retrieve the top_k most similar records."""
         # get scores and routes
-        scores, routes = self.index.query(vector=xq, top_k=top_k)
+        scores, routes = self.index.query(
+            vector=xq, top_k=top_k, route_filter=route_filter
+        )
         return [{"route": d, "score": s.item()} for d, s in zip(routes, scores)]
 
     def _set_aggregation_method(self, aggregation: str = "sum"):
-- 
GitLab