From 9854286c9ddb70362e38715e08a6cd662dc06c30 Mon Sep 17 00:00:00 2001 From: James Briggs <james.briggs@hotmail.com> Date: Thu, 13 Jun 2024 15:44:32 +0800 Subject: [PATCH] feat: add default openai values --- semantic_router/encoders/base.py | 1 - semantic_router/encoders/openai.py | 16 ++++++++++++---- semantic_router/schema.py | 1 + 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/semantic_router/encoders/base.py b/semantic_router/encoders/base.py index 3fe68d8c..bb9ce207 100644 --- a/semantic_router/encoders/base.py +++ b/semantic_router/encoders/base.py @@ -7,7 +7,6 @@ class BaseEncoder(BaseModel): name: str score_threshold: float type: str = Field(default="base") - is_async: bool = Field(default=False) class Config: arbitrary_types_allowed = True diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py index 6b2aeab3..ad328447 100644 --- a/semantic_router/encoders/openai.py +++ b/semantic_router/encoders/openai.py @@ -18,13 +18,16 @@ from semantic_router.utils.logger import logger model_configs = { "text-embedding-ada-002": EncoderInfo( - name="text-embedding-ada-002", token_limit=8192 + name="text-embedding-ada-002", token_limit=8192, + threshold=0.82, ), "text-embedding-3-small": EncoderInfo( - name="text-embedding-3-small", token_limit=8192 + name="text-embedding-3-small", token_limit=8192, + threshold=0.3, ), "text-embedding-3-large": EncoderInfo( - name="text-embedding-3-large", token_limit=8192 + name="text-embedding-3-large", token_limit=8192, + threshold=0.3, ), } @@ -43,11 +46,16 @@ class OpenAIEncoder(BaseEncoder): openai_base_url: Optional[str] = None, openai_api_key: Optional[str] = None, openai_org_id: Optional[str] = None, - score_threshold: float = 0.82, + score_threshold: Optional[float] = None, dimensions: Union[int, NotGiven] = NotGiven(), ): if name is None: name = EncoderDefault.OPENAI.value["embedding_model"] + if score_threshold is None and name in model_configs: + score_threshold = model_configs[name].threshold + elif score_threshold is None: + logger.warning(f"Score threshold not set for model: {name}. Using default value.") + score_threshold = 0.82 super().__init__( name=name, score_threshold=score_threshold, diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 86ab1233..b444c988 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -21,6 +21,7 @@ class EncoderType(Enum): class EncoderInfo(BaseModel): name: str token_limit: int + threshold: Optional[float] = None class RouteChoice(BaseModel): -- GitLab