Skip to content
Snippets Groups Projects
Commit 7a5dfc6c authored by André Pedersen's avatar André Pedersen
Browse files

Add support for setting top_k for HybridRouteLayer through class API

parent 1e753d93
No related branches found
No related tags found
No related merge requests found
......@@ -24,6 +24,7 @@ class HybridRouteLayer:
sparse_encoder: Optional[BM25Encoder] = None,
routes: List[Route] = [],
alpha: float = 0.3,
top_k: int = 5,
):
self.encoder = encoder
self.score_threshold = self.encoder.score_threshold
......@@ -35,6 +36,7 @@ class HybridRouteLayer:
self.sparse_encoder = sparse_encoder
self.alpha = alpha
self.top_k = top_k
self.routes = routes
if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr(
self.sparse_encoder, "fit"
......@@ -48,7 +50,7 @@ class HybridRouteLayer:
self._add_routes(routes)
def __call__(self, text: str) -> Optional[str]:
results = self._query(text)
results = self._query(text, self.top_k)
top_class, top_class_scores = self._semantic_classify(results)
passed = self._pass_threshold(top_class_scores, self.score_threshold)
if passed:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment