diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index 71598a15aa1235dd2a4cf018de5834287c5eed5d..4a9a368d77600fd8205e50e7f3aefbe2792bf7b9 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -24,6 +24,7 @@ class HybridRouteLayer: sparse_encoder: Optional[BM25Encoder] = None, routes: List[Route] = [], alpha: float = 0.3, + top_k: int = 5, ): self.encoder = encoder self.score_threshold = self.encoder.score_threshold @@ -35,6 +36,7 @@ class HybridRouteLayer: self.sparse_encoder = sparse_encoder self.alpha = alpha + self.top_k = top_k self.routes = routes if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr( self.sparse_encoder, "fit" @@ -48,7 +50,7 @@ class HybridRouteLayer: self._add_routes(routes) def __call__(self, text: str) -> Optional[str]: - results = self._query(text) + results = self._query(text, self.top_k) top_class, top_class_scores = self._semantic_classify(results) passed = self._pass_threshold(top_class_scores, self.score_threshold) if passed: