From 7a5dfc6c56ac365bbb126d3b6c68b374357bd5f2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= <andrped94@gmail.com>
Date: Wed, 13 Mar 2024 18:35:42 +0100
Subject: [PATCH] Add support for setting top_k for HybridRouteLayer through
 class API

---
 semantic_router/hybrid_layer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py
index 71598a15..4a9a368d 100644
--- a/semantic_router/hybrid_layer.py
+++ b/semantic_router/hybrid_layer.py
@@ -24,6 +24,7 @@ class HybridRouteLayer:
         sparse_encoder: Optional[BM25Encoder] = None,
         routes: List[Route] = [],
         alpha: float = 0.3,
+        top_k: int = 5,
     ):
         self.encoder = encoder
         self.score_threshold = self.encoder.score_threshold
@@ -35,6 +36,7 @@ class HybridRouteLayer:
             self.sparse_encoder = sparse_encoder
 
         self.alpha = alpha
+        self.top_k = top_k
         self.routes = routes
         if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr(
             self.sparse_encoder, "fit"
@@ -48,7 +50,7 @@ class HybridRouteLayer:
             self._add_routes(routes)
 
     def __call__(self, text: str) -> Optional[str]:
-        results = self._query(text)
+        results = self._query(text, self.top_k)
         top_class, top_class_scores = self._semantic_classify(results)
         passed = self._pass_threshold(top_class_scores, self.score_threshold)
         if passed:
-- 
GitLab