From 3698b07fb4b67aa0a2c8242b488f504168170792 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9CDaniel=20Griffiths=E2=80=9D?=
 <Danielgriffiths1790@gmail.com>
Date: Mon, 8 Jan 2024 15:35:11 +0000
Subject: [PATCH] feat: added add route tfidf test

---
 tests/unit/test_hybrid_layer.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py
index 567c5492..027d8750 100644
--- a/tests/unit/test_hybrid_layer.py
+++ b/tests/unit/test_hybrid_layer.py
@@ -30,11 +30,6 @@ def base_encoder(mocker):
     return mock_base_encoder
 
 
-# @pytest.fixture
-# def base_encoder():
-#     return BaseEncoder(name="test-encoder")
-
-
 @pytest.fixture
 def cohere_encoder(mocker):
     mocker.patch.object(CohereEncoder, "__call__", side_effect=mock_encoder_call)
@@ -165,5 +160,18 @@ class TestHybridRouteLayer:
         )
         assert route_layer.score_threshold == 0.82
 
+    def test_add_route_tfidf(self, cohere_encoder, tfidf_encoder, routes):
+        hybrid_route_layer = HybridRouteLayer(
+            dense_encoder=cohere_encoder,
+            sparse_encoder=tfidf_encoder,
+            routes=routes[:-1],
+        )
+        hybrid_route_layer.add(routes[-1])
+        all_utterances = [
+            utterance for route in routes for utterance in route.utterances
+        ]
+        assert hybrid_route_layer.sparse_index is not None
+        assert len(hybrid_route_layer.sparse_index) == len(all_utterances)
+
 
 # Add more tests for edge cases and error handling as needed.
-- 
GitLab