From 8597dee0c170360ea77c666443afda8d34e481a9 Mon Sep 17 00:00:00 2001 From: James McKeown <jmckeown@watscoventures.com> Date: Mon, 8 Jan 2024 09:11:19 -0500 Subject: [PATCH] make deployment_name optional in AzureOpenAIEncoder --- semantic_router/encoders/__init__.py | 2 +- semantic_router/encoders/fastembed.py | 1 + semantic_router/encoders/zure.py | 6 ++---- semantic_router/schema.py | 1 - semantic_router/utils/splitters.py | 1 + tests/unit/test_splitters.py | 6 ++++-- 6 files changed, 9 insertions(+), 8 deletions(-) diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py index d3fa4188..873d24e0 100644 --- a/semantic_router/encoders/__init__.py +++ b/semantic_router/encoders/__init__.py @@ -1,8 +1,8 @@ from semantic_router.encoders.base import BaseEncoder from semantic_router.encoders.bm25 import BM25Encoder from semantic_router.encoders.cohere import CohereEncoder -from semantic_router.encoders.openai import OpenAIEncoder from semantic_router.encoders.fastembed import FastEmbedEncoder +from semantic_router.encoders.openai import OpenAIEncoder from semantic_router.encoders.zure import AzureOpenAIEncoder __all__ = [ diff --git a/semantic_router/encoders/fastembed.py b/semantic_router/encoders/fastembed.py index d324058d..4bb46b85 100644 --- a/semantic_router/encoders/fastembed.py +++ b/semantic_router/encoders/fastembed.py @@ -1,4 +1,5 @@ from typing import Any, List, Optional + import numpy as np from pydantic import BaseModel, PrivateAttr diff --git a/semantic_router/encoders/zure.py b/semantic_router/encoders/zure.py index 792d16f0..b4949bf3 100644 --- a/semantic_router/encoders/zure.py +++ b/semantic_router/encoders/zure.py @@ -43,8 +43,7 @@ class AzureOpenAIEncoder(BaseEncoder): self.deployment_name = os.getenv( "AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002" ) - if self.deployment_name is None: - raise ValueError("No Azure OpenAI deployment name provided.") + # deployment_name may still be None, but it is optional in the API if self.azure_endpoint is None: self.azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") if self.azure_endpoint is None: @@ -59,7 +58,6 @@ class AzureOpenAIEncoder(BaseEncoder): raise ValueError("No Azure OpenAI model provided.") assert ( self.api_key is not None - and self.deployment_name is not None and self.azure_endpoint is not None and self.api_version is not None and self.model is not None @@ -67,7 +65,7 @@ class AzureOpenAIEncoder(BaseEncoder): try: self.client = openai.AzureOpenAI( - azure_deployment=str(deployment_name), + azure_deployment=str(deployment_name) if deployment_name else None, api_key=str(api_key), azure_endpoint=str(azure_endpoint), api_version=str(api_version), diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 465cfaac..360442f6 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -8,7 +8,6 @@ from semantic_router.encoders import ( CohereEncoder, OpenAIEncoder, ) - from semantic_router.utils.splitters import semantic_splitter diff --git a/semantic_router/utils/splitters.py b/semantic_router/utils/splitters.py index 514ae821..f299ff12 100644 --- a/semantic_router/utils/splitters.py +++ b/semantic_router/utils/splitters.py @@ -1,4 +1,5 @@ import numpy as np + from semantic_router.encoders import BaseEncoder diff --git a/tests/unit/test_splitters.py b/tests/unit/test_splitters.py index bcd8f62b..ac9c037c 100644 --- a/tests/unit/test_splitters.py +++ b/tests/unit/test_splitters.py @@ -1,7 +1,9 @@ -import pytest from unittest.mock import Mock -from semantic_router.utils.splitters import semantic_splitter + +import pytest + from semantic_router.schema import Conversation, Message +from semantic_router.utils.splitters import semantic_splitter def test_semantic_splitter_consecutive_similarity_drop(): -- GitLab