From b9a4f2c4c9563a4a9ca7ce7ab8e876a2251f5f0a Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Wed, 14 Feb 2024 03:24:36 +0400
Subject: [PATCH] fix tests

---
 tests/unit/test_layer.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py
index 9ddb1fe0..e5ba76a2 100644
--- a/tests/unit/test_layer.py
+++ b/tests/unit/test_layer.py
@@ -157,14 +157,18 @@ class TestRouteLayer:
 
         # Initially, the routes list should be empty
         assert route_layer.routes == []
+
         # Add route1 and check
         route_layer.add(route=route1)
         assert route_layer.routes == [route1]
-        # TODO add length check
+        assert route_layer.index is not None
+        # Use the describe method to get the number of vectors
+        assert route_layer.index.describe()["vectors"] == 2
+
         # Add route2 and check
         route_layer.add(route=route2)
         assert route_layer.routes == [route1, route2]
-        # TODO add length check
+        assert route_layer.index.describe()["vectors"] == 4
 
     def test_list_route_names(self, openai_encoder, routes):
         route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
@@ -202,7 +206,7 @@ class TestRouteLayer:
         route_layer = RouteLayer(encoder=openai_encoder)
         route_layer._add_routes(routes=routes)
         assert route_layer.index is not None
-        # TODO add length check
+        assert route_layer.index.describe()["vectors"] == 5
 
     def test_query_and_classification(self, openai_encoder, routes):
         route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
@@ -211,7 +215,8 @@ class TestRouteLayer:
 
     def test_query_with_no_index(self, openai_encoder):
         route_layer = RouteLayer(encoder=openai_encoder)
-        assert route_layer(text="Anything").name is None
+        with pytest.raises(ValueError):
+            assert route_layer(text="Anything").name is None
 
     def test_query_with_vector(self, openai_encoder, routes):
         route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
@@ -296,7 +301,7 @@ class TestRouteLayer:
         assert layer_config.routes == routes
         # now load from config and confirm it's the same
         route_layer_from_config = RouteLayer.from_config(layer_config)
-        assert route_layer_from_config.index == route_layer.index
+        assert (route_layer_from_config.index.index == route_layer.index.index).all()
         assert route_layer_from_config._get_route_names() == route_layer._get_route_names()
         assert route_layer_from_config.score_threshold == route_layer.score_threshold
 
-- 
GitLab