Skip to content
Snippets Groups Projects
Commit 6cd8f468 authored by Stephen Witkowski's avatar Stephen Witkowski
Browse files

Refactor GoogleEncoder constructor parameter name

parent 043c0458
No related branches found
No related tags found
No related merge requests found
...@@ -38,7 +38,7 @@ class GoogleEncoder(BaseEncoder): ...@@ -38,7 +38,7 @@ class GoogleEncoder(BaseEncoder):
def __init__( def __init__(
self, self,
model_name: Optional[str] = None, name: Optional[str] = None,
score_threshold: float = 0.3, score_threshold: float = 0.3,
project_id: Optional[str] = None, project_id: Optional[str] = None,
location: Optional[str] = None, location: Optional[str] = None,
...@@ -61,10 +61,10 @@ class GoogleEncoder(BaseEncoder): ...@@ -61,10 +61,10 @@ class GoogleEncoder(BaseEncoder):
Raises: Raises:
ValueError: If the Google Project ID is not provided or if the AI Platform client fails to initialize. ValueError: If the Google Project ID is not provided or if the AI Platform client fails to initialize.
""" """
if model_name is None: if name is None:
model_name = EncoderDefault.GOOGLE.value["embedding_model"] name = EncoderDefault.GOOGLE.value["embedding_model"]
super().__init__(model_name=model_name, score_threshold=score_threshold) super().__init__(name=name, score_threshold=score_threshold)
project_id = project_id or os.getenv("GOOGLE_PROJECT_ID") project_id = project_id or os.getenv("GOOGLE_PROJECT_ID")
location = location or os.getenv("GOOGLE_LOCATION", "us-central1") location = location or os.getenv("GOOGLE_LOCATION", "us-central1")
...@@ -76,7 +76,7 @@ class GoogleEncoder(BaseEncoder): ...@@ -76,7 +76,7 @@ class GoogleEncoder(BaseEncoder):
aiplatform.init( aiplatform.init(
project=project_id, location=location, api_endpoint=api_endpoint project=project_id, location=location, api_endpoint=api_endpoint
) )
self.client = TextEmbeddingModel.from_pretrained(self.model_name) self.client = TextEmbeddingModel.from_pretrained(self.name)
except Exception as e: except Exception as e:
raise ValueError( raise ValueError(
f"Google AI Platform client failed to initialize. Error: {e}" f"Google AI Platform client failed to initialize. Error: {e}"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment