diff --git a/semantic_router/index/qdrant.py b/semantic_router/index/qdrant.py index e1339afc033aae4bf9701675583a03aaa4acf225..10c97019443cdff1755df7772162c6f47a8103d4 100644 --- a/semantic_router/index/qdrant.py +++ b/semantic_router/index/qdrant.py @@ -217,10 +217,35 @@ class QdrantIndex(BaseIndex): "vectors": collection_info.points_count, } - def query(self, vector: np.ndarray, top_k: int = 5) -> Tuple[np.ndarray, List[str]]: + def query( + self, + vector: np.ndarray, + top_k: int = 5, + route_filter: Optional[List[str]] = None, + ) -> Tuple[np.ndarray, List[str]]: + from qdrant_client import models + results = self.client.search( self.index_name, query_vector=vector, limit=top_k, with_payload=True ) + filter = None + if route_filter is not None: + filter = models.Filter( + must=[ + models.FieldCondition( + key=SR_ROUTE_PAYLOAD_KEY, + values=route_filter, + ) + ] + ) + + results = self.client.search( + self.index_name, + query_vector=vector, + limit=top_k, + with_payload=True, + filter=filter, + ) scores = [result.score for result in results] route_names = [result.payload[SR_ROUTE_PAYLOAD_KEY] for result in results] return np.array(scores), route_names diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py index cd32b35e08192b7d1ce5914c155711cccfe32643..bd3ceaecc74da5ff1625a719050e3e58b20354b1 100644 --- a/tests/unit/test_layer.py +++ b/tests/unit/test_layer.py @@ -244,9 +244,10 @@ class TestRouteLayer: query_result = route_layer(text="Hello").name assert query_result in ["Route 1", "Route 2"] - def test_query_filter(self, openai_encoder, routes, index_cls): - route_layer = RouteLayer(encoder=openai_encoder, routes=routes, index=index_cls()) + route_layer = RouteLayer( + encoder=openai_encoder, routes=routes, index=index_cls() + ) query_result = route_layer(text="Hello", route_filter=["Route 1"]).name assert query_result in ["Route 1"]