Skip to content
Snippets Groups Projects
Unverified Commit 9a4bea83 authored by James Briggs's avatar James Briggs
Browse files

add mock for encoder in new tests

parent 9a9249a2
No related branches found
No related tags found
No related merge requests found
import os
import pytest
from unittest.mock import mock_open, patch
from unittest.mock import Mock, mock_open, patch
from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder
from semantic_router.layer import LayerConfig, RouteLayer
......@@ -177,22 +177,24 @@ class TestRouteLayer:
route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
route_layer.to_json("test_output.json")
assert os.path.exists("test_output.json")
route_layer_from_file = RouteLayer.from_json("test_output.json")
assert (
route_layer_from_file.index is not None
and route_layer_from_file.categories is not None
)
with patch("semantic_router.schema.Encoder", new_callable=Mock):
route_layer_from_file = RouteLayer.from_json("test_output.json")
assert (
route_layer_from_file.index is not None
and route_layer_from_file.categories is not None
)
os.remove("test_output.json")
def test_yaml(self, openai_encoder, routes):
route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
route_layer.to_yaml("test_output.yaml")
assert os.path.exists("test_output.yaml")
route_layer_from_file = RouteLayer.from_yaml("test_output.yaml")
assert (
route_layer_from_file.index is not None
and route_layer_from_file.categories is not None
)
with patch("semantic_router.schema.Encoder", new_callable=Mock):
route_layer_from_file = RouteLayer.from_yaml("test_output.yaml")
assert (
route_layer_from_file.index is not None
and route_layer_from_file.categories is not None
)
os.remove("test_output.yaml")
def test_config(self, openai_encoder, routes):
......@@ -200,11 +202,12 @@ class TestRouteLayer:
# confirm route creation functions as expected
layer_config = route_layer.to_config()
assert layer_config.routes == routes
# now load from config and confirm it's the same
route_layer_from_config = RouteLayer.from_config(layer_config)
assert (route_layer_from_config.index == route_layer.index).all()
assert (route_layer_from_config.categories == route_layer.categories).all()
assert route_layer_from_config.score_threshold == route_layer.score_threshold
with patch("semantic_router.schema.Encoder", new_callable=Mock):
# now load from config and confirm it's the same
route_layer_from_config = RouteLayer.from_config(layer_config)
assert (route_layer_from_config.index == route_layer.index).all()
assert (route_layer_from_config.categories == route_layer.categories).all()
assert route_layer_from_config.score_threshold == route_layer.score_threshold
# Add more tests for edge cases and error handling as needed.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment