diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-voyageai/llama_index/embeddings/voyageai/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-voyageai/llama_index/embeddings/voyageai/base.py index 2b1262bb1cf263a234cc468f7b42880af1f4b38f..ac9e92797fdd7456e1059ed93c209913fa011b6e 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-voyageai/llama_index/embeddings/voyageai/base.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-voyageai/llama_index/embeddings/voyageai/base.py @@ -1,14 +1,15 @@ """Voyage embeddings file.""" +import logging from typing import Any, List, Optional from llama_index.core.base.embeddings.base import BaseEmbedding -from llama_index.core.bridge.pydantic import PrivateAttr from llama_index.core.callbacks.base import CallbackManager import voyageai +from pydantic import PrivateAttr -DEFAULT_VOYAGE_BATCH_SIZE = 8 +logger = logging.getLogger(__name__) class VoyageEmbedding(BaseEmbedding): @@ -22,19 +23,28 @@ class VoyageEmbedding(BaseEmbedding): You can either specify the key here or store it as an environment variable. """ - _model: Any = PrivateAttr() + client: voyageai.Client = PrivateAttr(None) + aclient: voyageai.client_async.AsyncClient = PrivateAttr() + truncation: Optional[bool] = None def __init__( self, - model_name: str = "voyage-01", + model_name: str, voyage_api_key: Optional[str] = None, - embed_batch_size: int = DEFAULT_VOYAGE_BATCH_SIZE, + embed_batch_size: Optional[int] = None, + truncation: Optional[bool] = None, callback_manager: Optional[CallbackManager] = None, **kwargs: Any, ): - if voyage_api_key: - voyageai.api_key = voyage_api_key - self._model = voyageai + if model_name == "voyage-01": + logger.warning( + "voyage-01 is not the latest model by Voyage AI. Please note that `model_name` " + "will be a required argument in the future. We recommend setting it explicitly. Please see " + "https://docs.voyageai.com/docs/embeddings for the latest models offered by Voyage AI." + ) + + if embed_batch_size is None: + embed_batch_size = 72 if model_name in ["voyage-2", "voyage-02"] else 7 super().__init__( model_name=model_name, @@ -43,58 +53,68 @@ class VoyageEmbedding(BaseEmbedding): **kwargs, ) + self.client = voyageai.Client(api_key=voyage_api_key) + self.aclient = voyageai.AsyncClient(api_key=voyage_api_key) + self.truncation = truncation + @classmethod def class_name(cls) -> str: return "VoyageEmbedding" + def _get_embedding(self, texts: List[str], input_type: str) -> List[List[float]]: + return self.client.embed( + texts, + model=self.model_name, + input_type=input_type, + truncation=self.truncation, + ).embeddings + + async def _aget_embedding( + self, texts: List[str], input_type: str + ) -> List[List[float]]: + r = await self.aclient.embed( + texts, + model=self.model_name, + input_type=input_type, + truncation=self.truncation, + ) + return r.embeddings + def _get_query_embedding(self, query: str) -> List[float]: """Get query embedding.""" - return self._model.get_embedding( - query, model=self.model_name, input_type="query" - ) + return self._get_embedding([query], input_type="query")[0] async def _aget_query_embedding(self, query: str) -> List[float]: """The asynchronous version of _get_query_embedding.""" - return await self._model.aget_embedding( - query, model=self.model_name, input_type="query" - ) + r = await self._aget_embedding([query], input_type="query") + return r[0] def _get_text_embedding(self, text: str) -> List[float]: """Get text embedding.""" - return self._model.get_embedding( - text, model=self.model_name, input_type="document" - ) + return self._get_embedding([text], input_type="document")[0] async def _aget_text_embedding(self, text: str) -> List[float]: """Asynchronously get text embedding.""" - return await self._model.aget_embedding( - text, model=self.model_name, input_type="document" - ) + r = await self._aget_embedding([text], input_type="document") + return r[0] def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: """Get text embeddings.""" - return self._model.get_embeddings( - texts, model=self.model_name, input_type="document" - ) + return self._get_embedding(texts, input_type="document") async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]: """Asynchronously get text embeddings.""" - return await self._model.aget_embeddings( - texts, model=self.model_name, input_type="document" - ) + return await self._aget_embedding(texts, input_type="document") def get_general_text_embedding( self, text: str, input_type: Optional[str] = None ) -> List[float]: """Get general text embedding with input_type.""" - return self._model.get_embedding( - text, model=self.model_name, input_type=input_type - ) + return self._get_embedding([text], input_type=input_type)[0] async def aget_general_text_embedding( self, text: str, input_type: Optional[str] = None ) -> List[float]: """Asynchronously get general text embedding with input_type.""" - return await self._model.aget_embedding( - text, model=self.model_name, input_type=input_type - ) + r = await self._aget_embedding([text], input_type=input_type) + return r[0] diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-voyageai/tests/test_embeddings_voyageai.py b/llama-index-integrations/embeddings/llama-index-embeddings-voyageai/tests/test_embeddings_voyageai.py index fd2e43e1b1cfc3f88fe2d50f84b85e5d1f3002cc..678016600f37f843816f5558862392dd24e85e88 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-voyageai/tests/test_embeddings_voyageai.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-voyageai/tests/test_embeddings_voyageai.py @@ -3,5 +3,32 @@ from llama_index.embeddings.voyageai import VoyageEmbedding def test_embedding_class(): - emb = VoyageEmbedding(model_name="") + emb = VoyageEmbedding(model_name="", voyage_api_key="NOT_A_VALID_KEY") assert isinstance(emb, BaseEmbedding) + assert emb.embed_batch_size == 7 + assert emb.model_name == "" + + +def test_embedding_class_voyage_2(): + emb = VoyageEmbedding( + model_name="voyage-2", voyage_api_key="NOT_A_VALID_KEY", truncation=True + ) + assert isinstance(emb, BaseEmbedding) + assert emb.embed_batch_size == 72 + assert emb.model_name == "voyage-2" + assert emb.truncation + + +def test_embedding_class_voyage_2_with_batch_size(): + emb = VoyageEmbedding( + model_name="voyage-2", voyage_api_key="NOT_A_VALID_KEY", embed_batch_size=49 + ) + assert isinstance(emb, BaseEmbedding) + assert emb.embed_batch_size == 49 + assert emb.model_name == "voyage-2" + assert emb.truncation is None + + +def test_voyageai_embedding_class(): + names_of_base_classes = [b.__name__ for b in VoyageEmbedding.__mro__] + assert BaseEmbedding.__name__ in names_of_base_classes