From 7dd6eabd10e4f5c2de5383d8dce260bc7ef6bb5c Mon Sep 17 00:00:00 2001
From: zahid-syed <zahid.s2618@gmail.com>
Date: Wed, 20 Mar 2024 17:31:32 -0400
Subject: [PATCH] added filter to qdrant index

---
 semantic_router/index/qdrant.py | 27 ++++++++++++++++++++++++++-
 tests/unit/test_layer.py        |  5 +++--
 2 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py
index e1339afc..10c97019 100644
--- a/semantic_router/index/qdrant.py
+++ b/semantic_router/index/qdrant.py
@@ -217,10 +217,35 @@ class QdrantIndex(BaseIndex):
             "vectors": collection_info.points_count,
         }
 
-    def query(self, vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[str]]:
+    def query(
+        self,
+        vector: np.ndarray,
+        top_k: int = 5,
+        route_filter: Optional[List[str]] = None,
+    ) -> Tuple[np.ndarray, List[str]]:
+        from qdrant_client import models
+
         results = self.client.search(
             self.index_name, query_vector=vector, limit=top_k, with_payload=True
         )
+        filter = None
+        if route_filter is not None:
+            filter = models.Filter(
+                must=[
+                    models.FieldCondition(
+                        key=SR_ROUTE_PAYLOAD_KEY,
+                        values=route_filter,
+                    )
+                ]
+            )
+
+        results = self.client.search(
+            self.index_name,
+            query_vector=vector,
+            limit=top_k,
+            with_payload=True,
+            filter=filter,
+        )
         scores = [result.score for result in results]
         route_names = [result.payload[SR_ROUTE_PAYLOAD_KEY] for result in results]
         return np.array(scores), route_names
diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py
index cd32b35e..bd3ceaec 100644
--- a/tests/unit/test_layer.py
+++ b/tests/unit/test_layer.py
@@ -244,9 +244,10 @@ class TestRouteLayer:
         query_result = route_layer(text="Hello").name
         assert query_result in ["Route 1", "Route 2"]
 
-
     def test_query_filter(self, openai_encoder, routes, index_cls):
-        route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index_cls())
+        route_layer = RouteLayer(
+            encoder=openai_encoder, routes=routes, index=index_cls()
+        )
         query_result = route_layer(text="Hello", route_filter=["Route 1"]).name
         assert query_result in ["Route 1"]
 
-- 
GitLab