Skip to content
Snippets Groups Projects
Commit 4394759e authored by “Daniel Griffiths”'s avatar “Daniel Griffiths”
Browse files

update sparse_encoder if statements to type

update sparse_encoder if statements to type rather than name
parent e88a0c25
No related branches found
No related tags found
No related merge requests found
...@@ -7,6 +7,7 @@ from semantic_router.encoders import ( ...@@ -7,6 +7,7 @@ from semantic_router.encoders import (
BM25Encoder, BM25Encoder,
CohereEncoder, CohereEncoder,
OpenAIEncoder, OpenAIEncoder,
TfidfEncoder
) )
from semantic_router.schema import Route from semantic_router.schema import Route
from semantic_router.utils.logger import logger from semantic_router.utils.logger import logger
...@@ -33,7 +34,7 @@ class HybridRouteLayer: ...@@ -33,7 +34,7 @@ class HybridRouteLayer:
else: else:
self.score_threshold = 0.82 self.score_threshold = 0.82
# if routes list has been passed, we initialize index now # if routes list has been passed, we initialize index now
if self.sparse_encoder.name == 'tfidf': if isinstance(sparse_encoder, TfidfEncoder):
self.sparse_encoder.fit(routes) self.sparse_encoder.fit(routes)
if routes: if routes:
# initialize index now # initialize index now
...@@ -50,11 +51,11 @@ class HybridRouteLayer: ...@@ -50,11 +51,11 @@ class HybridRouteLayer:
return None return None
def add(self, route: Route): def add(self, route: Route):
if self.sparse_encoder.name == 'tfidf': if isinstance(self.sparse_encoder, TfidfEncoder):
self.sparse_encoder.fit(self.routes + [route]) self.sparse_encoder.fit(self.routes + [route])
self.sparse_index = None self.sparse_index = None
for r in self.routes: for r in self.routes:
self.calculate_sparse_embeds(r) self.compute_and_store_sparse_embeddings(r)
self.routes.append(route) self.routes.append(route)
self._add_route(route=route) self._add_route(route=route)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment