Skip to content
Snippets Groups Projects
Unverified Commit f437eeae authored by Gal Dayan's avatar Gal Dayan Committed by GitHub
Browse files

Merge pull request #1 from GENWAY-AI/add_dimension_azure

Add dimension azure
parents a8544d8f 82c06d51
Branches
Tags
No related merge requests found
import os
from time import sleep
from typing import List, Optional
from typing import List, Optional, Union
import openai
from openai._types import NotGiven
from openai import OpenAIError
from openai.types import CreateEmbeddingResponse
......@@ -13,6 +14,7 @@ from semantic_router.utils.logger import logger
class AzureOpenAIEncoder(BaseEncoder):
client: Optional[openai.AzureOpenAI] = None
dimensions: Union[int, NotGiven] = NotGiven()
type: str = "azure"
api_key: Optional[str] = None
deployment_name: Optional[str] = None
......@@ -28,6 +30,7 @@ class AzureOpenAIEncoder(BaseEncoder):
api_version: Optional[str] = None,
model: Optional[str] = None, # TODO we should change to `name` JB
score_threshold: float = 0.82,
dimensions: Union[int, NotGiven] = NotGiven(),
):
name = deployment_name
if name is None:
......@@ -38,6 +41,8 @@ class AzureOpenAIEncoder(BaseEncoder):
self.azure_endpoint = azure_endpoint
self.api_version = api_version
self.model = model
# set dimensions to support openai embed 3 dimensions param
self.dimensions = dimensions
if self.api_key is None:
self.api_key = os.getenv("AZURE_OPENAI_API_KEY")
if self.api_key is None:
......@@ -89,7 +94,7 @@ class AzureOpenAIEncoder(BaseEncoder):
for j in range(3):
try:
embeds = self.client.embeddings.create(
input=docs, model=str(self.model)
input=docs, model=str(self.model), dimensions=self.dimensions,
)
if embeds.data:
break
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment