diff --git a/semantic_router/encoders/google.py b/semantic_router/encoders/google.py index 42996dee51f640d269ab5bc8908b8a753045e71d..088d4bba943360202c455a1bb4fbf1b6dc51b927 100644 --- a/semantic_router/encoders/google.py +++ b/semantic_router/encoders/google.py @@ -17,10 +17,7 @@ Classes: """ import os -from typing import List, Optional - -from google.cloud import aiplatform -from vertexai.language_models import TextEmbeddingModel +from typing import Any, List, Optional from semantic_router.encoders import BaseEncoder from semantic_router.utils.defaults import EncoderDefault @@ -34,7 +31,7 @@ class GoogleEncoder(BaseEncoder): type: The type of the encoder, which is "google". """ - client: Optional[TextEmbeddingModel] = None + client: Optional[Any] = None type: str = "google" def __init__( @@ -49,39 +46,75 @@ class GoogleEncoder(BaseEncoder): Args: model_name: The name of the pre-trained model to use for embedding. - If not provided, the default model specified in EncoderDefault will be used. - score_threshold: The threshold for similarity scores. Default is 0.3. + If not provided, the default model specified in EncoderDefault will + be used. + score_threshold: The threshold for similarity scores. project_id: The Google Cloud project ID. - If not provided, it will be retrieved from the GOOGLE_PROJECT_ID environment variable. + If not provided, it will be retrieved from the GOOGLE_PROJECT_ID + environment variable. location: The location of the AI Platform resources. - If not provided, it will be retrieved from the GOOGLE_LOCATION environment variable, - defaulting to "us-central1". + If not provided, it will be retrieved from the GOOGLE_LOCATION + environment variable, defaulting to "us-central1". api_endpoint: The API endpoint for the AI Platform. - If not provided, it will be retrieved from the GOOGLE_API_ENDPOINT environment variable. + If not provided, it will be retrieved from the GOOGLE_API_ENDPOINT + environment variable. Raises: - ValueError: If the Google Project ID is not provided or if the AI Platform client fails to initialize. + ValueError: If the Google Project ID is not provided or if the AI Platform + client fails to initialize. """ if name is None: name = EncoderDefault.GOOGLE.value["embedding_model"] super().__init__(name=name, score_threshold=score_threshold) + self.client = self._initialize_client(project_id, location, api_endpoint) + + def _initialize_client(self, project_id, location, api_endpoint): + """Initializes the Google AI Platform client. + + Args: + project_id: The Google Cloud project ID. + location: The location of the AI Platform resources. + api_endpoint: The API endpoint for the AI Platform. + + Returns: + An instance of the TextEmbeddingModel client. + + Raises: + ImportError: If the required Google Cloud or Vertex AI libraries are not + installed. + ValueError: If the Google Project ID is not provided or if the AI Platform + client fails to initialize. + """ + try: + from google.cloud import aiplatform + from vertexai.language_models import TextEmbeddingModel + except ImportError: + raise ImportError( + "Please install Google Cloud and Vertex AI libraries to use GoogleEncoder. " + "You can install them with: " + "`pip install google-cloud-aiplatform vertexai-language-models`" + ) + project_id = project_id or os.getenv("GOOGLE_PROJECT_ID") location = location or os.getenv("GOOGLE_LOCATION", "us-central1") api_endpoint = api_endpoint or os.getenv("GOOGLE_API_ENDPOINT") if project_id is None: raise ValueError("Google Project ID cannot be 'None'.") + try: aiplatform.init( project=project_id, location=location, api_endpoint=api_endpoint ) - self.client = TextEmbeddingModel.from_pretrained(self.name) - except Exception as e: + client = TextEmbeddingModel.from_pretrained(self.name) + except Exception as err: raise ValueError( - f"Google AI Platform client failed to initialize. Error: {e}" - ) from e + f"Google AI Platform client failed to initialize. Error: {err}" + ) from err + + return client def __call__(self, docs: List[str]) -> List[List[float]]: """Generates embeddings for the given documents. @@ -90,10 +123,12 @@ class GoogleEncoder(BaseEncoder): docs: A list of strings representing the documents to embed. Returns: - A list of lists, where each inner list contains the embedding values for a document. + A list of lists, where each inner list contains the embedding values for a + document. Raises: - ValueError: If the Google AI Platform client is not initialized or if the API call fails. + ValueError: If the Google AI Platform client is not initialized or if the + API call fails. """ if self.client is None: raise ValueError("Google AI Platform client is not initialized.") diff --git a/semantic_router/encoders/mistral.py b/semantic_router/encoders/mistral.py index 544c629f81dcd2199b0b5b10e704c177c7c144d3..974f11284a162979420ac897474413c266fff1a5 100644 --- a/semantic_router/encoders/mistral.py +++ b/semantic_router/encoders/mistral.py @@ -2,12 +2,12 @@ import os from time import sleep -from typing import List, Optional, Any +from typing import Any, List, Optional +from pydantic.v1 import PrivateAttr from semantic_router.encoders import BaseEncoder from semantic_router.utils.defaults import EncoderDefault -from pydantic.v1 import PrivateAttr class MistralEncoder(BaseEncoder): diff --git a/semantic_router/llms/llamacpp.py b/semantic_router/llms/llamacpp.py index bcf01b6d602fbb6b5be1e743eca75e1b4154da20..5a737b869b229fd36b8bf623e0ef6920423dbbd3 100644 --- a/semantic_router/llms/llamacpp.py +++ b/semantic_router/llms/llamacpp.py @@ -2,12 +2,12 @@ from contextlib import contextmanager from pathlib import Path from typing import Any, Optional +from pydantic.v1 import PrivateAttr + from semantic_router.llms.base import BaseLLM from semantic_router.schema import Message from semantic_router.utils.logger import logger -from pydantic.v1 import PrivateAttr - class LlamaCppLLM(BaseLLM): llm: Any diff --git a/semantic_router/llms/mistral.py b/semantic_router/llms/mistral.py index 647d4073e5c7b50591dd7cd686536c0bebba0714..8ddd1482975f4cb091816bcd197c416ac3e39da7 100644 --- a/semantic_router/llms/mistral.py +++ b/semantic_router/llms/mistral.py @@ -1,14 +1,13 @@ import os -from typing import List, Optional, Any +from typing import Any, List, Optional +from pydantic.v1 import PrivateAttr from semantic_router.llms import BaseLLM from semantic_router.schema import Message from semantic_router.utils.defaults import EncoderDefault from semantic_router.utils.logger import logger -from pydantic.v1 import PrivateAttr - class MistralAILLM(BaseLLM): _client: Any = PrivateAttr() diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 035ca8a00b77ca7b3681d833f5e18c19bff0054b..6b7485ac26255358a0834a763c3df90f27798cca 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -8,9 +8,9 @@ from semantic_router.encoders import ( BaseEncoder, CohereEncoder, FastEmbedEncoder, + GoogleEncoder, MistralEncoder, OpenAIEncoder, - GoogleEncoder, ) diff --git a/tests/unit/encoders/test_google.py b/tests/unit/encoders/test_google.py index 11c19cc81c04890b0ac76da7234df2663bd570a8..42d52b5805ee27fed96d3df83535781ef25b9649 100644 --- a/tests/unit/encoders/test_google.py +++ b/tests/unit/encoders/test_google.py @@ -1,7 +1,7 @@ import pytest +from google.api_core.exceptions import GoogleAPICallError from vertexai.language_models import TextEmbedding from vertexai.language_models._language_models import TextEmbeddingStatistics -from google.api_core.exceptions import GoogleAPICallError from semantic_router.encoders import GoogleEncoder diff --git a/tests/unit/encoders/test_mistral.py b/tests/unit/encoders/test_mistral.py index f36f5037abaab256b03ab912714831b33a23fe2a..25dba6b759bc2a033c64ae0055e0b1b45d626f48 100644 --- a/tests/unit/encoders/test_mistral.py +++ b/tests/unit/encoders/test_mistral.py @@ -1,11 +1,11 @@ +from unittest.mock import patch + import pytest from mistralai.exceptions import MistralException from mistralai.models.embeddings import EmbeddingObject, EmbeddingResponse, UsageInfo from semantic_router.encoders import MistralEncoder -from unittest.mock import patch - @pytest.fixture def mistralai_encoder(mocker): diff --git a/tests/unit/llms/test_llm_llamacpp.py b/tests/unit/llms/test_llm_llamacpp.py index 63d92ee8dd0ff4dad8b0c60846433e7197ee1699..9f579cdf44713c6eff6dd0c3a5ef8fbcba0e10b5 100644 --- a/tests/unit/llms/test_llm_llamacpp.py +++ b/tests/unit/llms/test_llm_llamacpp.py @@ -1,11 +1,11 @@ +from unittest.mock import patch + import pytest from llama_cpp import Llama from semantic_router.llms.llamacpp import LlamaCppLLM from semantic_router.schema import Message -from unittest.mock import patch - @pytest.fixture def llamacpp_llm(mocker): diff --git a/tests/unit/llms/test_llm_mistral.py b/tests/unit/llms/test_llm_mistral.py index d73406e726c721242743f36e9de062b7460daaf0..3b318e2acb4feb0c9ac723b3c5ea96e0abe40097 100644 --- a/tests/unit/llms/test_llm_mistral.py +++ b/tests/unit/llms/test_llm_mistral.py @@ -1,10 +1,10 @@ +from unittest.mock import patch + import pytest from semantic_router.llms import MistralAILLM from semantic_router.schema import Message -from unittest.mock import patch - @pytest.fixture def mistralai_llm(mocker):