From 8597dee0c170360ea77c666443afda8d34e481a9 Mon Sep 17 00:00:00 2001
From: James McKeown <jmckeown@watscoventures.com>
Date: Mon, 8 Jan 2024 09:11:19 -0500
Subject: [PATCH] make deployment_name optional in AzureOpenAIEncoder

---
 semantic_router/encoders/__init__.py  | 2 +-
 semantic_router/encoders/fastembed.py | 1 +
 semantic_router/encoders/zure.py      | 6 ++----
 semantic_router/schema.py             | 1 -
 semantic_router/utils/splitters.py    | 1 +
 tests/unit/test_splitters.py          | 6 ++++--
 6 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py
index d3fa4188..873d24e0 100644
--- a/semantic_router/encoders/__init__.py
+++ b/semantic_router/encoders/__init__.py
@@ -1,8 +1,8 @@
 from semantic_router.encoders.base import BaseEncoder
 from semantic_router.encoders.bm25 import BM25Encoder
 from semantic_router.encoders.cohere import CohereEncoder
-from semantic_router.encoders.openai import OpenAIEncoder
 from semantic_router.encoders.fastembed import FastEmbedEncoder
+from semantic_router.encoders.openai import OpenAIEncoder
 from semantic_router.encoders.zure import AzureOpenAIEncoder
 
 __all__ = [
diff --git a/semantic_router/encoders/fastembed.py b/semantic_router/encoders/fastembed.py
index d324058d..4bb46b85 100644
--- a/semantic_router/encoders/fastembed.py
+++ b/semantic_router/encoders/fastembed.py
@@ -1,4 +1,5 @@
 from typing import Any, List, Optional
+
 import numpy as np
 from pydantic import BaseModel, PrivateAttr
 
diff --git a/semantic_router/encoders/zure.py b/semantic_router/encoders/zure.py
index 792d16f0..b4949bf3 100644
--- a/semantic_router/encoders/zure.py
+++ b/semantic_router/encoders/zure.py
@@ -43,8 +43,7 @@ class AzureOpenAIEncoder(BaseEncoder):
             self.deployment_name = os.getenv(
                 "AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002"
             )
-            if self.deployment_name is None:
-                raise ValueError("No Azure OpenAI deployment name provided.")
+        # deployment_name may still be None, but it is optional in the API
         if self.azure_endpoint is None:
             self.azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
             if self.azure_endpoint is None:
@@ -59,7 +58,6 @@ class AzureOpenAIEncoder(BaseEncoder):
                 raise ValueError("No Azure OpenAI model provided.")
         assert (
             self.api_key is not None
-            and self.deployment_name is not None
             and self.azure_endpoint is not None
             and self.api_version is not None
             and self.model is not None
@@ -67,7 +65,7 @@ class AzureOpenAIEncoder(BaseEncoder):
 
         try:
             self.client = openai.AzureOpenAI(
-                azure_deployment=str(deployment_name),
+                azure_deployment=str(deployment_name) if deployment_name else None,
                 api_key=str(api_key),
                 azure_endpoint=str(azure_endpoint),
                 api_version=str(api_version),
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index 465cfaac..360442f6 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -8,7 +8,6 @@ from semantic_router.encoders import (
     CohereEncoder,
     OpenAIEncoder,
 )
-
 from semantic_router.utils.splitters import semantic_splitter
 
 
diff --git a/semantic_router/utils/splitters.py b/semantic_router/utils/splitters.py
index 514ae821..f299ff12 100644
--- a/semantic_router/utils/splitters.py
+++ b/semantic_router/utils/splitters.py
@@ -1,4 +1,5 @@
 import numpy as np
+
 from semantic_router.encoders import BaseEncoder
 
 
diff --git a/tests/unit/test_splitters.py b/tests/unit/test_splitters.py
index bcd8f62b..ac9c037c 100644
--- a/tests/unit/test_splitters.py
+++ b/tests/unit/test_splitters.py
@@ -1,7 +1,9 @@
-import pytest
 from unittest.mock import Mock
-from semantic_router.utils.splitters import semantic_splitter
+
+import pytest
+
 from semantic_router.schema import Conversation, Message
+from semantic_router.utils.splitters import semantic_splitter
 
 
 def test_semantic_splitter_consecutive_similarity_drop():
-- 
GitLab