Skip to content
Snippets Groups Projects
Commit 8597dee0 authored by James McKeown's avatar James McKeown
Browse files

make deployment_name optional in AzureOpenAIEncoder

parent 9c766ad1
Branches
Tags
No related merge requests found
from semantic_router.encoders.base import BaseEncoder
from semantic_router.encoders.bm25 import BM25Encoder
from semantic_router.encoders.cohere import CohereEncoder
from semantic_router.encoders.openai import OpenAIEncoder
from semantic_router.encoders.fastembed import FastEmbedEncoder
from semantic_router.encoders.openai import OpenAIEncoder
from semantic_router.encoders.zure import AzureOpenAIEncoder
__all__ = [
......
from typing import Any, List, Optional
import numpy as np
from pydantic import BaseModel, PrivateAttr
......
......@@ -43,8 +43,7 @@ class AzureOpenAIEncoder(BaseEncoder):
self.deployment_name = os.getenv(
"AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002"
)
if self.deployment_name is None:
raise ValueError("No Azure OpenAI deployment name provided.")
# deployment_name may still be None, but it is optional in the API
if self.azure_endpoint is None:
self.azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
if self.azure_endpoint is None:
......@@ -59,7 +58,6 @@ class AzureOpenAIEncoder(BaseEncoder):
raise ValueError("No Azure OpenAI model provided.")
assert (
self.api_key is not None
and self.deployment_name is not None
and self.azure_endpoint is not None
and self.api_version is not None
and self.model is not None
......@@ -67,7 +65,7 @@ class AzureOpenAIEncoder(BaseEncoder):
try:
self.client = openai.AzureOpenAI(
azure_deployment=str(deployment_name),
azure_deployment=str(deployment_name) if deployment_name else None,
api_key=str(api_key),
azure_endpoint=str(azure_endpoint),
api_version=str(api_version),
......
......@@ -8,7 +8,6 @@ from semantic_router.encoders import (
CohereEncoder,
OpenAIEncoder,
)
from semantic_router.utils.splitters import semantic_splitter
......
import numpy as np
from semantic_router.encoders import BaseEncoder
......
import pytest
from unittest.mock import Mock
from semantic_router.utils.splitters import semantic_splitter
import pytest
from semantic_router.schema import Conversation, Message
from semantic_router.utils.splitters import semantic_splitter
def test_semantic_splitter_consecutive_similarity_drop():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment