Per-model default score thresholds for the OpenAI encoder.

Summary of the hunks below:
  * encoders/base.py: removes the `is_async` field from BaseEncoder.
  * encoders/openai.py: each `model_configs` entry gains a `threshold`
    (0.82 for text-embedding-ada-002, 0.3 for the two text-embedding-3
    models). `OpenAIEncoder.__init__` now defaults `score_threshold` to
    None; when it is None the value is taken from `model_configs[name]`,
    and for a model name that is not in `model_configs` a warning is
    logged and 0.82 is used.
  * schema.py: adds `threshold: Optional[float] = None` to EncoderInfo.

NOTE(review): `EncoderInfo.threshold` is optional, so a `model_configs`
entry without a threshold would hand None to BaseEncoder, whose
`score_threshold` field is declared as `float`. Every entry in this patch
sets a threshold, so this is latent only; confirm before adding entries.

diff --git a/semantic_router/encoders/base.py b/semantic_router/encoders/base.py
index 3fe68d8ca7c104e2ebcbb686a0762b169d66cf4a..bb9ce20738135b7273900b81390497db7b9b61d7 100644
--- a/semantic_router/encoders/base.py
+++ b/semantic_router/encoders/base.py
@@ -7,7 +7,6 @@ class BaseEncoder(BaseModel):
     name: str
     score_threshold: float
     type: str = Field(default="base")
-    is_async: bool = Field(default=False)
 
     class Config:
         arbitrary_types_allowed = True
diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py
index 6b2aeab38bf143687cb08bfdc92d742237aafd77..ad32844765705f6d9f6517ddecddeeb2044a6748 100644
--- a/semantic_router/encoders/openai.py
+++ b/semantic_router/encoders/openai.py
@@ -18,13 +18,16 @@ from semantic_router.utils.logger import logger
 
 model_configs = {
     "text-embedding-ada-002": EncoderInfo(
-        name="text-embedding-ada-002", token_limit=8192
+        name="text-embedding-ada-002", token_limit=8192,
+        threshold=0.82,
     ),
     "text-embedding-3-small": EncoderInfo(
-        name="text-embedding-3-small", token_limit=8192
+        name="text-embedding-3-small", token_limit=8192,
+        threshold=0.3,
     ),
     "text-embedding-3-large": EncoderInfo(
-        name="text-embedding-3-large", token_limit=8192
+        name="text-embedding-3-large", token_limit=8192,
+        threshold=0.3,
     ),
 }
 
@@ -43,11 +46,16 @@ class OpenAIEncoder(BaseEncoder):
         openai_base_url: Optional[str] = None,
         openai_api_key: Optional[str] = None,
         openai_org_id: Optional[str] = None,
-        score_threshold: float = 0.82,
+        score_threshold: Optional[float] = None,
         dimensions: Union[int, NotGiven] = NotGiven(),
     ):
         if name is None:
             name = EncoderDefault.OPENAI.value["embedding_model"]
+        if score_threshold is None and name in model_configs:
+            score_threshold = model_configs[name].threshold
+        elif score_threshold is None:
+            logger.warning(f"Score threshold not set for model: {name}. Using default value.")
+            score_threshold = 0.82
         super().__init__(
             name=name,
             score_threshold=score_threshold,
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index 86ab123318812dbc3c18a8a6f98a7e4bb186669c..b444c98884089f77935775c135234eaf5552b85e 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -21,6 +21,7 @@ class EncoderType(Enum):
 class EncoderInfo(BaseModel):
     name: str
     token_limit: int
+    threshold: Optional[float] = None
 
 
 class RouteChoice(BaseModel):