From 4394759e4b96969544a723a14f169e1d8e82e8f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= <Danielgriffiths1790@gmail.com> Date: Mon, 18 Dec 2023 19:08:40 +0000 Subject: [PATCH] update sparse_encoder if statements to type update sparse_encoder if statements to type rather than name --- semantic_router/hybrid_layer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index a68472d3..33a3269f 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -7,6 +7,7 @@ from semantic_router.encoders import ( BM25Encoder, CohereEncoder, OpenAIEncoder, + TfidfEncoder ) from semantic_router.schema import Route from semantic_router.utils.logger import logger @@ -33,7 +34,7 @@ class HybridRouteLayer: else: self.score_threshold = 0.82 # if routes list has been passed, we initialize index now - if self.sparse_encoder.name == 'tfidf': + if isinstance(sparse_encoder, TfidfEncoder): self.sparse_encoder.fit(routes) if routes: # initialize index now @@ -50,11 +51,11 @@ class HybridRouteLayer: return None def add(self, route: Route): - if self.sparse_encoder.name == 'tfidf': + if isinstance(self.sparse_encoder, TfidfEncoder): self.sparse_encoder.fit(self.routes + [route]) self.sparse_index = None for r in self.routes: - self.calculate_sparse_embeds(r) + self.compute_and_store_sparse_embeddings(r) self.routes.append(route) self._add_route(route=route) -- GitLab