Skip to content
Snippets Groups Projects
Commit 7a5dfc6c authored by André Pedersen's avatar André Pedersen
Browse files

Add support for setting top_k for HybridRouteLayer through class API

parent 1e753d93
No related branches found
No related tags found
No related merge requests found
...@@ -24,6 +24,7 @@ class HybridRouteLayer: ...@@ -24,6 +24,7 @@ class HybridRouteLayer:
sparse_encoder: Optional[BM25Encoder] = None, sparse_encoder: Optional[BM25Encoder] = None,
routes: List[Route] = [], routes: List[Route] = [],
alpha: float = 0.3, alpha: float = 0.3,
top_k: int = 5,
): ):
self.encoder = encoder self.encoder = encoder
self.score_threshold = self.encoder.score_threshold self.score_threshold = self.encoder.score_threshold
...@@ -35,6 +36,7 @@ class HybridRouteLayer: ...@@ -35,6 +36,7 @@ class HybridRouteLayer:
self.sparse_encoder = sparse_encoder self.sparse_encoder = sparse_encoder
self.alpha = alpha self.alpha = alpha
self.top_k = top_k
self.routes = routes self.routes = routes
if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr( if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr(
self.sparse_encoder, "fit" self.sparse_encoder, "fit"
...@@ -48,7 +50,7 @@ class HybridRouteLayer: ...@@ -48,7 +50,7 @@ class HybridRouteLayer:
self._add_routes(routes) self._add_routes(routes)
def __call__(self, text: str) -> Optional[str]: def __call__(self, text: str) -> Optional[str]:
results = self._query(text) results = self._query(text, self.top_k)
top_class, top_class_scores = self._semantic_classify(results) top_class, top_class_scores = self._semantic_classify(results)
passed = self._pass_threshold(top_class_scores, self.score_threshold) passed = self._pass_threshold(top_class_scores, self.score_threshold)
if passed: if passed:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment