From 6cd8f4681f0b7417d1a7129a8f6750c47dc28267 Mon Sep 17 00:00:00 2001 From: Stephen Witkowski <stephen.witkowski@66degrees.com> Date: Fri, 29 Mar 2024 16:11:44 -0400 Subject: [PATCH] Refactor GoogleEncoder constructor parameter name --- semantic_router/encoders/google.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/semantic_router/encoders/google.py b/semantic_router/encoders/google.py index fcb5997d..d1d6879f 100644 --- a/semantic_router/encoders/google.py +++ b/semantic_router/encoders/google.py @@ -38,7 +38,7 @@ class GoogleEncoder(BaseEncoder): def __init__( self, - model_name: Optional[str] = None, + name: Optional[str] = None, score_threshold: float = 0.3, project_id: Optional[str] = None, location: Optional[str] = None, @@ -61,10 +61,10 @@ class GoogleEncoder(BaseEncoder): Raises: ValueError: If the Google Project ID is not provided or if the AI Platform client fails to initialize. """ - if model_name is None: - model_name = EncoderDefault.GOOGLE.value["embedding_model"] + if name is None: + name = EncoderDefault.GOOGLE.value["embedding_model"] - super().__init__(model_name=model_name, score_threshold=score_threshold) + super().__init__(name=name, score_threshold=score_threshold) project_id = project_id or os.getenv("GOOGLE_PROJECT_ID") location = location or os.getenv("GOOGLE_LOCATION", "us-central1") @@ -76,7 +76,7 @@ class GoogleEncoder(BaseEncoder): aiplatform.init( project=project_id, location=location, api_endpoint=api_endpoint ) - self.client = TextEmbeddingModel.from_pretrained(self.model_name) + self.client = TextEmbeddingModel.from_pretrained(self.name) except Exception as e: raise ValueError( f"Google AI Platform client failed to initialize. Error: {e}" -- GitLab