diff --git a/semantic_router/encoders/zure.py b/semantic_router/encoders/zure.py index df2bf858aa32578e4751fac90d698e4db111f84b..7215e7f58cb4c5426d6fbd6ffed6ff9a597122f2 100644 --- a/semantic_router/encoders/zure.py +++ b/semantic_router/encoders/zure.py @@ -1,8 +1,9 @@ import os from time import sleep -from typing import List, Optional +from typing import List, Optional, Union import openai +from openai._types import NotGiven from openai import OpenAIError from openai.types import CreateEmbeddingResponse @@ -13,6 +14,7 @@ from semantic_router.utils.logger import logger class AzureOpenAIEncoder(BaseEncoder): client: Optional[openai.AzureOpenAI] = None + dimensions: Union[int, NotGiven] = NotGiven() type: str = "azure" api_key: Optional[str] = None deployment_name: Optional[str] = None @@ -28,6 +30,7 @@ class AzureOpenAIEncoder(BaseEncoder): api_version: Optional[str] = None, model: Optional[str] = None, # TODO we should change to `name` JB score_threshold: float = 0.82, + dimensions: Union[int, NotGiven] = NotGiven(), ): name = deployment_name if name is None: @@ -38,6 +41,8 @@ class AzureOpenAIEncoder(BaseEncoder): self.azure_endpoint = azure_endpoint self.api_version = api_version self.model = model + # set dimensions to support openai embed 3 dimensions param + self.dimensions = dimensions if self.api_key is None: self.api_key = os.getenv("AZURE_OPENAI_API_KEY") if self.api_key is None: @@ -89,7 +94,7 @@ class AzureOpenAIEncoder(BaseEncoder): for j in range(3): try: embeds = self.client.embeddings.create( - input=docs, model=str(self.model) + input=docs, model=str(self.model), dimensions=self.dimensions, ) if embeds.data: break