Skip to content
Snippets Groups Projects
Commit d5512014 authored by Ismail Ashraq's avatar Ismail Ashraq
Browse files

Update hybrid layer to use score_threshold from encoders

parent 7862efdb
No related branches found
No related tags found
No related merge requests found
......@@ -19,7 +19,11 @@ class BM25Encoder(BaseEncoder):
"You can install it with: `pip install semantic-router[hybrid]`"
)
logger.info("Downloading and initializing BM25 model parameters.")
self.model = encoder.default()
# self.model = encoder.default()
self.model = encoder()
self.model.fit(
corpus=["test test", "this is another message", "hello how are you"]
)
params = self.model.get_params()
doc_freq = params["doc_freq"]
......
......@@ -4,8 +4,6 @@ from numpy.linalg import norm
from semantic_router.encoders import (
BaseEncoder,
BM25Encoder,
CohereEncoder,
OpenAIEncoder,
)
from semantic_router.route import Route
from semantic_router.utils.logger import logger
......@@ -15,21 +13,15 @@ class HybridRouteLayer:
index = None
sparse_index = None
categories = None
score_threshold = 0.82
score_threshold: float
def __init__(
self, encoder: BaseEncoder, routes: list[Route] = [], alpha: float = 0.3
):
self.encoder = encoder
self.score_threshold = self.encoder.score_threshold
self.sparse_encoder = BM25Encoder()
self.alpha = alpha
# decide on default threshold based on encoder
if isinstance(encoder, OpenAIEncoder):
self.score_threshold = 0.82
elif isinstance(encoder, CohereEncoder):
self.score_threshold = 0.3
else:
self.score_threshold = 0.82
# if routes list has been passed, we initialize index now
if routes:
# initialize index now
......
......@@ -19,7 +19,7 @@ def mock_encoder_call(utterances):
@pytest.fixture
def base_encoder():
return BaseEncoder(name="test-encoder")
return BaseEncoder(name="test-encoder", score_threshold=0.5)
@pytest.fixture
......@@ -46,6 +46,7 @@ class TestHybridRouteLayer:
def test_initialization(self, openai_encoder, routes):
route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes)
assert route_layer.index is not None and route_layer.categories is not None
assert openai_encoder.score_threshold == 0.82
assert route_layer.score_threshold == 0.82
assert len(route_layer.index) == 5
assert len(set(route_layer.categories)) == 2
......@@ -112,7 +113,8 @@ class TestHybridRouteLayer:
def test_failover_score_threshold(self, base_encoder):
route_layer = HybridRouteLayer(encoder=base_encoder)
assert route_layer.score_threshold == 0.82
assert base_encoder.score_threshold == 0.50
assert route_layer.score_threshold == 0.50
# Add more tests for edge cases and error handling as needed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment