From c2a3b3e1da4978b24088b63190032b73da26ab89 Mon Sep 17 00:00:00 2001 From: Simonas <20096648+simjak@users.noreply.github.com> Date: Mon, 8 Apr 2024 12:19:25 +0300 Subject: [PATCH] fix: Google client dependencies --- semantic_router/encoders/google.py | 71 +++++++++++++++++++++------- semantic_router/encoders/mistral.py | 4 +- semantic_router/llms/llamacpp.py | 4 +- semantic_router/llms/mistral.py | 5 +- semantic_router/schema.py | 2 +- tests/unit/encoders/test_google.py | 2 +- tests/unit/encoders/test_mistral.py | 4 +- tests/unit/llms/test_llm_llamacpp.py | 4 +- tests/unit/llms/test_llm_mistral.py | 4 +- 9 files changed, 67 insertions(+), 33 deletions(-) diff --git a/semantic_router/encoders/google.py b/semantic_router/encoders/google.py index 42996dee..088d4bba 100644 --- a/semantic_router/encoders/google.py +++ b/semantic_router/encoders/google.py @@ -17,10 +17,7 @@ Classes: """ import os -from typing import List, Optional - -from google.cloud import aiplatform -from vertexai.language_models import TextEmbeddingModel +from typing import Any, List, Optional from semantic_router.encoders import BaseEncoder from semantic_router.utils.defaults import EncoderDefault @@ -34,7 +31,7 @@ class GoogleEncoder(BaseEncoder): type: The type of the encoder, which is "google". """ - client: Optional[TextEmbeddingModel] = None + client: Optional[Any] = None type: str = "google" def __init__( @@ -49,39 +46,75 @@ class GoogleEncoder(BaseEncoder): Args: model_name: The name of the pre-trained model to use for embedding. - If not provided, the default model specified in EncoderDefault will be used. - score_threshold: The threshold for similarity scores. Default is 0.3. + If not provided, the default model specified in EncoderDefault will + be used. + score_threshold: The threshold for similarity scores. project_id: The Google Cloud project ID. - If not provided, it will be retrieved from the GOOGLE_PROJECT_ID environment variable. + If not provided, it will be retrieved from the GOOGLE_PROJECT_ID + environment variable. location: The location of the AI Platform resources. - If not provided, it will be retrieved from the GOOGLE_LOCATION environment variable, - defaulting to "us-central1". + If not provided, it will be retrieved from the GOOGLE_LOCATION + environment variable, defaulting to "us-central1". api_endpoint: The API endpoint for the AI Platform. - If not provided, it will be retrieved from the GOOGLE_API_ENDPOINT environment variable. + If not provided, it will be retrieved from the GOOGLE_API_ENDPOINT + environment variable. Raises: - ValueError: If the Google Project ID is not provided or if the AI Platform client fails to initialize. + ValueError: If the Google Project ID is not provided or if the AI Platform + client fails to initialize. """ if name is None: name = EncoderDefault.GOOGLE.value["embedding_model"] super().__init__(name=name, score_threshold=score_threshold) + self.client = self._initialize_client(project_id, location, api_endpoint) + + def _initialize_client(self, project_id, location, api_endpoint): + """Initializes the Google AI Platform client. + + Args: + project_id: The Google Cloud project ID. + location: The location of the AI Platform resources. + api_endpoint: The API endpoint for the AI Platform. + + Returns: + An instance of the TextEmbeddingModel client. + + Raises: + ImportError: If the required Google Cloud or Vertex AI libraries are not + installed. + ValueError: If the Google Project ID is not provided or if the AI Platform + client fails to initialize. + """ + try: + from google.cloud import aiplatform + from vertexai.language_models import TextEmbeddingModel + except ImportError: + raise ImportError( + "Please install Google Cloud and Vertex AI libraries to use GoogleEncoder. " + "You can install them with: " + "`pip install google-cloud-aiplatform vertexai-language-models`" + ) + project_id = project_id or os.getenv("GOOGLE_PROJECT_ID") location = location or os.getenv("GOOGLE_LOCATION", "us-central1") api_endpoint = api_endpoint or os.getenv("GOOGLE_API_ENDPOINT") if project_id is None: raise ValueError("Google Project ID cannot be 'None'.") + try: aiplatform.init( project=project_id, location=location, api_endpoint=api_endpoint ) - self.client = TextEmbeddingModel.from_pretrained(self.name) - except Exception as e: + client = TextEmbeddingModel.from_pretrained(self.name) + except Exception as err: raise ValueError( - f"Google AI Platform client failed to initialize. Error: {e}" - ) from e + f"Google AI Platform client failed to initialize. Error: {err}" + ) from err + + return client def __call__(self, docs: List[str]) -> List[List[float]]: """Generates embeddings for the given documents. @@ -90,10 +123,12 @@ class GoogleEncoder(BaseEncoder): docs: A list of strings representing the documents to embed. Returns: - A list of lists, where each inner list contains the embedding values for a document. + A list of lists, where each inner list contains the embedding values for a + document. Raises: - ValueError: If the Google AI Platform client is not initialized or if the API call fails. + ValueError: If the Google AI Platform client is not initialized or if the + API call fails. """ if self.client is None: raise ValueError("Google AI Platform client is not initialized.") diff --git a/semantic_router/encoders/mistral.py b/semantic_router/encoders/mistral.py index 544c629f..974f1128 100644 --- a/semantic_router/encoders/mistral.py +++ b/semantic_router/encoders/mistral.py @@ -2,12 +2,12 @@ import os from time import sleep -from typing import List, Optional, Any +from typing import Any, List, Optional +from pydantic.v1 import PrivateAttr from semantic_router.encoders import BaseEncoder from semantic_router.utils.defaults import EncoderDefault -from pydantic.v1 import PrivateAttr class MistralEncoder(BaseEncoder): diff --git a/semantic_router/llms/llamacpp.py b/semantic_router/llms/llamacpp.py index bcf01b6d..5a737b86 100644 --- a/semantic_router/llms/llamacpp.py +++ b/semantic_router/llms/llamacpp.py @@ -2,12 +2,12 @@ from contextlib import contextmanager from pathlib import Path from typing import Any, Optional +from pydantic.v1 import PrivateAttr + from semantic_router.llms.base import BaseLLM from semantic_router.schema import Message from semantic_router.utils.logger import logger -from pydantic.v1 import PrivateAttr - class LlamaCppLLM(BaseLLM): llm: Any diff --git a/semantic_router/llms/mistral.py b/semantic_router/llms/mistral.py index 647d4073..8ddd1482 100644 --- a/semantic_router/llms/mistral.py +++ b/semantic_router/llms/mistral.py @@ -1,14 +1,13 @@ import os -from typing import List, Optional, Any +from typing import Any, List, Optional +from pydantic.v1 import PrivateAttr from semantic_router.llms import BaseLLM from semantic_router.schema import Message from semantic_router.utils.defaults import EncoderDefault from semantic_router.utils.logger import logger -from pydantic.v1 import PrivateAttr - class MistralAILLM(BaseLLM): _client: Any = PrivateAttr() diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 035ca8a0..6b7485ac 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -8,9 +8,9 @@ from semantic_router.encoders import ( BaseEncoder, CohereEncoder, FastEmbedEncoder, + GoogleEncoder, MistralEncoder, OpenAIEncoder, - GoogleEncoder, ) diff --git a/tests/unit/encoders/test_google.py b/tests/unit/encoders/test_google.py index 11c19cc8..42d52b58 100644 --- a/tests/unit/encoders/test_google.py +++ b/tests/unit/encoders/test_google.py @@ -1,7 +1,7 @@ import pytest +from google.api_core.exceptions import GoogleAPICallError from vertexai.language_models import TextEmbedding from vertexai.language_models._language_models import TextEmbeddingStatistics -from google.api_core.exceptions import GoogleAPICallError from semantic_router.encoders import GoogleEncoder diff --git a/tests/unit/encoders/test_mistral.py b/tests/unit/encoders/test_mistral.py index f36f5037..25dba6b7 100644 --- a/tests/unit/encoders/test_mistral.py +++ b/tests/unit/encoders/test_mistral.py @@ -1,11 +1,11 @@ +from unittest.mock import patch + import pytest from mistralai.exceptions import MistralException from mistralai.models.embeddings import EmbeddingObject, EmbeddingResponse, UsageInfo from semantic_router.encoders import MistralEncoder -from unittest.mock import patch - @pytest.fixture def mistralai_encoder(mocker): diff --git a/tests/unit/llms/test_llm_llamacpp.py b/tests/unit/llms/test_llm_llamacpp.py index 63d92ee8..9f579cdf 100644 --- a/tests/unit/llms/test_llm_llamacpp.py +++ b/tests/unit/llms/test_llm_llamacpp.py @@ -1,11 +1,11 @@ +from unittest.mock import patch + import pytest from llama_cpp import Llama from semantic_router.llms.llamacpp import LlamaCppLLM from semantic_router.schema import Message -from unittest.mock import patch - @pytest.fixture def llamacpp_llm(mocker): diff --git a/tests/unit/llms/test_llm_mistral.py b/tests/unit/llms/test_llm_mistral.py index d73406e7..3b318e2a 100644 --- a/tests/unit/llms/test_llm_mistral.py +++ b/tests/unit/llms/test_llm_mistral.py @@ -1,10 +1,10 @@ +from unittest.mock import patch + import pytest from semantic_router.llms import MistralAILLM from semantic_router.schema import Message -from unittest.mock import patch - @pytest.fixture def mistralai_llm(mocker): -- GitLab