From 7e4fe6d5cb2eeee3d0fb0e6bec93a93289840a57 Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Thu, 28 Dec 2023 10:54:15 +0100
Subject: [PATCH] add env var to other tests

---
 tests/unit/test_layer.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py
index 43f8a0c1..1c8c18f2 100644
--- a/tests/unit/test_layer.py
+++ b/tests/unit/test_layer.py
@@ -1,6 +1,6 @@
 import os
 import pytest
-from unittest.mock import Mock, mock_open, patch
+from unittest.mock import mock_open, patch
 
 from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder
 from semantic_router.layer import LayerConfig, RouteLayer
@@ -186,6 +186,7 @@ class TestRouteLayer:
         os.remove("test_output.json")
 
     def test_yaml(self, openai_encoder, routes):
+        os.environ["OPENAI_API_KEY"] = "test_api_key"
         route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
         route_layer.to_yaml("test_output.yaml")
         assert os.path.exists("test_output.yaml")
@@ -197,18 +198,18 @@ class TestRouteLayer:
         os.remove("test_output.yaml")
 
     def test_config(self, openai_encoder, routes):
+        os.environ["OPENAI_API_KEY"] = "test_api_key"
         route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
         # confirm route creation functions as expected
         layer_config = route_layer.to_config()
         assert layer_config.routes == routes
-        with patch("semantic_router.encoders.OpenAIEncoder", new_callable=Mock):
-            # now load from config and confirm it's the same
-            route_layer_from_config = RouteLayer.from_config(layer_config)
-            assert (route_layer_from_config.index == route_layer.index).all()
-            assert (route_layer_from_config.categories == route_layer.categories).all()
-            assert (
-                route_layer_from_config.score_threshold == route_layer.score_threshold
-            )
+        # now load from config and confirm it's the same
+        route_layer_from_config = RouteLayer.from_config(layer_config)
+        assert (route_layer_from_config.index == route_layer.index).all()
+        assert (route_layer_from_config.categories == route_layer.categories).all()
+        assert (
+            route_layer_from_config.score_threshold == route_layer.score_threshold
+        )
 
 
 # Add more tests for edge cases and error handling as needed.
-- 
GitLab