From 7dd6eabd10e4f5c2de5383d8dce260bc7ef6bb5c Mon Sep 17 00:00:00 2001 From: zahid-syed <zahid.s2618@gmail.com> Date: Wed, 20 Mar 2024 17:31:32 -0400 Subject: [PATCH] added filter to qdrant index --- semantic_router/index/qdrant.py | 27 ++++++++++++++++++++++++++- tests/unit/test_layer.py | 5 +++-- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py index e1339afc..10c97019 100644 --- a/semantic_router/index/qdrant.py +++ b/semantic_router/index/qdrant.py @@ -217,10 +217,35 @@ class QdrantIndex(BaseIndex): "vectors": collection_info.points_count, } - def query(self, vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[str]]: + def query( + self, + vector: np.ndarray, + top_k: int = 5, + route_filter: Optional[List[str]] = None, + ) -> Tuple[np.ndarray, List[str]]: + from qdrant_client import models + results = self.client.search( self.index_name, query_vector=vector, limit=top_k, with_payload=True ) + filter = None + if route_filter is not None: + filter = models.Filter( + must=[ + models.FieldCondition( + key=SR_ROUTE_PAYLOAD_KEY, + values=route_filter, + ) + ] + ) + + results = self.client.search( + self.index_name, + query_vector=vector, + limit=top_k, + with_payload=True, + filter=filter, + ) scores = [result.score for result in results] route_names = [result.payload[SR_ROUTE_PAYLOAD_KEY] for result in results] return np.array(scores), route_names diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py index cd32b35e..bd3ceaec 100644 --- a/tests/unit/test_layer.py +++ b/tests/unit/test_layer.py @@ -244,9 +244,10 @@ class TestRouteLayer: query_result = route_layer(text="Hello").name assert query_result in ["Route 1", "Route 2"] - def test_query_filter(self, openai_encoder, routes, index_cls): - route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index_cls()) + route_layer = RouteLayer( + encoder=openai_encoder, routes=routes, index=index_cls() + ) query_result = route_layer(text="Hello", route_filter=["Route 1"]).name assert query_result in ["Route 1"] -- GitLab