From 975fc4ecf5197e3ef934d06fda8a35cb1a309c4c Mon Sep 17 00:00:00 2001 From: James McKeown <jmckeown@watscoventures.com> Date: Thu, 4 Jan 2024 22:17:29 -0500 Subject: [PATCH] do not require strict response validation --- semantic_router/encoders/zure.py | 6 +++++- tests/unit/test_hybrid_layer.py | 19 ++++++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/semantic_router/encoders/zure.py b/semantic_router/encoders/zure.py index 5b8099b9..792d16f0 100644 --- a/semantic_router/encoders/zure.py +++ b/semantic_router/encoders/zure.py @@ -71,7 +71,7 @@ class AzureOpenAIEncoder(BaseEncoder): api_key=str(api_key), azure_endpoint=str(azure_endpoint), api_version=str(api_version), - _strict_response_validation=True, + # _strict_response_validation=True, ) except Exception as e: raise ValueError(f"OpenAI API client failed to initialize. Error: {e}") @@ -91,6 +91,10 @@ class AzureOpenAIEncoder(BaseEncoder): if embeds.data: break except OpenAIError as e: + # print full traceback + import traceback + + traceback.print_exc() sleep(2**j) error_message = str(e) logger.warning(f"Retrying in {2**j} seconds...") diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py index f87cb1d2..fbaec6c2 100644 --- a/tests/unit/test_hybrid_layer.py +++ b/tests/unit/test_hybrid_layer.py @@ -1,6 +1,11 @@ import pytest -from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder +from semantic_router.encoders import ( + AzureOpenAIEncoder, + BaseEncoder, + CohereEncoder, + OpenAIEncoder, +) from semantic_router.hybrid_layer import HybridRouteLayer from semantic_router.route import Route @@ -34,6 +39,18 @@ def openai_encoder(mocker): return OpenAIEncoder(name="test-openai-encoder", openai_api_key="test_api_key") +@pytest.fixture +def azure_encoder(mocker): + mocker.patch.object(AzureOpenAIEncoder, "__call__", side_effect=mock_encoder_call) + return AzureOpenAIEncoder( + deployment_name="test-deployment", + azure_endpoint="test_endpoint", + api_key="test_api_key", + api_version="test_version", + model="test_model", + ) + + @pytest.fixture def routes(): return [ -- GitLab