From 3698b07fb4b67aa0a2c8242b488f504168170792 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?= <Danielgriffiths1790@gmail.com> Date: Mon, 8 Jan 2024 15:35:11 +0000 Subject: [PATCH] feat: added add route tfidf test --- tests/unit/test_hybrid_layer.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py index 567c5492..027d8750 100644 --- a/tests/unit/test_hybrid_layer.py +++ b/tests/unit/test_hybrid_layer.py @@ -30,11 +30,6 @@ def base_encoder(mocker): return mock_base_encoder -# @pytest.fixture -# def base_encoder(): -# return BaseEncoder(name="test-encoder") - - @pytest.fixture def cohere_encoder(mocker): mocker.patch.object(CohereEncoder, "__call__", side_effect=mock_encoder_call) @@ -165,5 +160,18 @@ class TestHybridRouteLayer: ) assert route_layer.score_threshold == 0.82 + def test_add_route_tfidf(self, cohere_encoder, tfidf_encoder, routes): + hybrid_route_layer = HybridRouteLayer( + dense_encoder=cohere_encoder, + sparse_encoder=tfidf_encoder, + routes=routes[:-1], + ) + hybrid_route_layer.add(routes[-1]) + all_utterances = [ + utterance for route in routes for utterance in route.utterances + ] + assert hybrid_route_layer.sparse_index is not None + assert len(hybrid_route_layer.sparse_index) == len(all_utterances) + # Add more tests for edge cases and error handling as needed. -- GitLab