Skip to content
Snippets Groups Projects
Commit 6cd8f468 authored by Stephen Witkowski's avatar Stephen Witkowski
Browse files

Refactor GoogleEncoder constructor parameter name

parent 043c0458
No related branches found
No related tags found
No related merge requests found
......@@ -38,7 +38,7 @@ class GoogleEncoder(BaseEncoder):
def __init__(
self,
model_name: Optional[str] = None,
name: Optional[str] = None,
score_threshold: float = 0.3,
project_id: Optional[str] = None,
location: Optional[str] = None,
......@@ -61,10 +61,10 @@ class GoogleEncoder(BaseEncoder):
Raises:
ValueError: If the Google Project ID is not provided or if the AI Platform client fails to initialize.
"""
if model_name is None:
model_name = EncoderDefault.GOOGLE.value["embedding_model"]
if name is None:
name = EncoderDefault.GOOGLE.value["embedding_model"]
super().__init__(model_name=model_name, score_threshold=score_threshold)
super().__init__(name=name, score_threshold=score_threshold)
project_id = project_id or os.getenv("GOOGLE_PROJECT_ID")
location = location or os.getenv("GOOGLE_LOCATION", "us-central1")
......@@ -76,7 +76,7 @@ class GoogleEncoder(BaseEncoder):
aiplatform.init(
project=project_id, location=location, api_endpoint=api_endpoint
)
self.client = TextEmbeddingModel.from_pretrained(self.model_name)
self.client = TextEmbeddingModel.from_pretrained(self.name)
except Exception as e:
raise ValueError(
f"Google AI Platform client failed to initialize. Error: {e}"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment