From 7a5dfc6c56ac365bbb126d3b6c68b374357bd5f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= <andrped94@gmail.com> Date: Wed, 13 Mar 2024 18:35:42 +0100 Subject: [PATCH] Add support for setting top_k for HybridRouteLayer through class API --- semantic_router/hybrid_layer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index 71598a15..4a9a368d 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -24,6 +24,7 @@ class HybridRouteLayer: sparse_encoder: Optional[BM25Encoder] = None, routes: List[Route] = [], alpha: float = 0.3, + top_k: int = 5, ): self.encoder = encoder self.score_threshold = self.encoder.score_threshold @@ -35,6 +36,7 @@ class HybridRouteLayer: self.sparse_encoder = sparse_encoder self.alpha = alpha + self.top_k = top_k self.routes = routes if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr( self.sparse_encoder, "fit" @@ -48,7 +50,7 @@ class HybridRouteLayer: self._add_routes(routes) def __call__(self, text: str) -> Optional[str]: - results = self._query(text) + results = self._query(text, self.top_k) top_class, top_class_scores = self._semantic_classify(results) passed = self._pass_threshold(top_class_scores, self.score_threshold) if passed: -- GitLab