From 9854286c9ddb70362e38715e08a6cd662dc06c30 Mon Sep 17 00:00:00 2001
From: James Briggs <james.briggs@hotmail.com>
Date: Thu, 13 Jun 2024 15:44:32 +0800
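Subject: [PATCH] feat: add default score thresholds for OpenAI encoders

Add an optional `threshold` field to EncoderInfo and set it for each entry
in the OpenAI `model_configs`. When `score_threshold` is not passed,
OpenAIEncoder now uses the threshold configured for the selected model and
falls back to 0.82, with a warning, for models not in `model_configs`.

Also remove the `is_async` field from BaseEncoder.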

---
 semantic_router/encoders/base.py   |  1 -
 semantic_router/encoders/openai.py | 16 ++++++++++++----
 semantic_router/schema.py          |  1 +
 3 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/semantic_router/encoders/base.py b/semantic_router/encoders/base.py
index 3fe68d8c..bb9ce207 100644
--- a/semantic_router/encoders/base.py
+++ b/semantic_router/encoders/base.py
@@ -7,7 +7,6 @@ class BaseEncoder(BaseModel):
     name: str
     score_threshold: float
     type: str = Field(default="base")
-    is_async: bool = Field(default=False)
 
     class Config:
         arbitrary_types_allowed = True
diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py
index 6b2aeab3..ad328447 100644
--- a/semantic_router/encoders/openai.py
+++ b/semantic_router/encoders/openai.py
@@ -18,13 +18,16 @@ from semantic_router.utils.logger import logger
 
 model_configs = {
     "text-embedding-ada-002": EncoderInfo(
-        name="text-embedding-ada-002", token_limit=8192
+        name="text-embedding-ada-002", token_limit=8192,
+        threshold=0.82,
     ),
     "text-embedding-3-small": EncoderInfo(
-        name="text-embedding-3-small", token_limit=8192
+        name="text-embedding-3-small", token_limit=8192,
+        threshold=0.3,
     ),
     "text-embedding-3-large": EncoderInfo(
-        name="text-embedding-3-large", token_limit=8192
+        name="text-embedding-3-large", token_limit=8192,
+        threshold=0.3,
     ),
 }
 
@@ -43,11 +46,16 @@ class OpenAIEncoder(BaseEncoder):
         openai_base_url: Optional[str] = None,
         openai_api_key: Optional[str] = None,
         openai_org_id: Optional[str] = None,
-        score_threshold: float = 0.82,
+        score_threshold: Optional[float] = None,
         dimensions: Union[int, NotGiven] = NotGiven(),
     ):
         if name is None:
             name = EncoderDefault.OPENAI.value["embedding_model"]
+        if score_threshold is None and name in model_configs:
+            score_threshold = model_configs[name].threshold
+        elif score_threshold is None:
+            logger.warning(f"Score threshold not set for model: {name}. Using default value.")
+            score_threshold = 0.82
         super().__init__(
             name=name,
             score_threshold=score_threshold,
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index 86ab1233..b444c988 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -21,6 +21,7 @@ class EncoderType(Enum):
 class EncoderInfo(BaseModel):
     name: str
     token_limit: int
+    threshold: Optional[float] = None
 
 
 class RouteChoice(BaseModel):
-- 
GitLab