Skip to content
Snippets Groups Projects
Unverified Commit 909c95ed authored by Yujie Qian's avatar Yujie Qian Committed by GitHub
Browse files

Updates to VoyageEmbedding (#11721)


* - refactor the base
 - adding new test case to validate the code works with API key stored as env.variable

* Batch size default value

* Remove unused constant
Correct the code

* Corrections due to comments

* Corrections
New tests

* Adding truncation attribute
Refactoring
Remove useless test

* Remove default model
Add tests
Reformat, etc

---------

Co-authored-by: fodizoltan <zoltan@conway.expert>
parent d972aebe
Branches
Tags
No related merge requests found
"""Voyage embeddings file.""" """Voyage embeddings file."""
import logging
from typing import Any, List, Optional from typing import Any, List, Optional
from llama_index.core.base.embeddings.base import BaseEmbedding from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.callbacks.base import CallbackManager from llama_index.core.callbacks.base import CallbackManager
import voyageai import voyageai
from pydantic import PrivateAttr
DEFAULT_VOYAGE_BATCH_SIZE = 8 logger = logging.getLogger(__name__)
class VoyageEmbedding(BaseEmbedding):
    """Class for Voyage embeddings.

    Args:
        model_name (str): Name of the Voyage model to use.
        voyage_api_key (Optional[str]): Voyage API key.
            You can either specify the key here or store it as an environment variable.
    """

    # Synchronous and asynchronous Voyage API clients; assigned in __init__.
    # NOTE(review): PrivateAttr defaults on non-underscore pydantic fields are
    # unusual — confirm this resolves as intended under the bridged pydantic.
    client: voyageai.Client = PrivateAttr(None)
    aclient: voyageai.client_async.AsyncClient = PrivateAttr()
    # Whether the API should truncate over-length inputs (None = API default).
    truncation: Optional[bool] = None
def __init__( def __init__(
self, self,
model_name: str = "voyage-01", model_name: str,
voyage_api_key: Optional[str] = None, voyage_api_key: Optional[str] = None,
embed_batch_size: int = DEFAULT_VOYAGE_BATCH_SIZE, embed_batch_size: Optional[int] = None,
truncation: Optional[bool] = None,
callback_manager: Optional[CallbackManager] = None, callback_manager: Optional[CallbackManager] = None,
**kwargs: Any, **kwargs: Any,
): ):
if voyage_api_key: if model_name == "voyage-01":
voyageai.api_key = voyage_api_key logger.warning(
self._model = voyageai "voyage-01 is not the latest model by Voyage AI. Please note that `model_name` "
"will be a required argument in the future. We recommend setting it explicitly. Please see "
"https://docs.voyageai.com/docs/embeddings for the latest models offered by Voyage AI."
)
if embed_batch_size is None:
embed_batch_size = 72 if model_name in ["voyage-2", "voyage-02"] else 7
super().__init__( super().__init__(
model_name=model_name, model_name=model_name,
...@@ -43,58 +53,68 @@ class VoyageEmbedding(BaseEmbedding): ...@@ -43,58 +53,68 @@ class VoyageEmbedding(BaseEmbedding):
**kwargs, **kwargs,
) )
self.client = voyageai.Client(api_key=voyage_api_key)
self.aclient = voyageai.AsyncClient(api_key=voyage_api_key)
self.truncation = truncation
@classmethod @classmethod
def class_name(cls) -> str: def class_name(cls) -> str:
return "VoyageEmbedding" return "VoyageEmbedding"
def _get_embedding(self, texts: List[str], input_type: str) -> List[List[float]]:
return self.client.embed(
texts,
model=self.model_name,
input_type=input_type,
truncation=self.truncation,
).embeddings
async def _aget_embedding(
self, texts: List[str], input_type: str
) -> List[List[float]]:
r = await self.aclient.embed(
texts,
model=self.model_name,
input_type=input_type,
truncation=self.truncation,
)
return r.embeddings
def _get_query_embedding(self, query: str) -> List[float]: def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding.""" """Get query embedding."""
return self._model.get_embedding( return self._get_embedding([query], input_type="query")[0]
query, model=self.model_name, input_type="query"
)
async def _aget_query_embedding(self, query: str) -> List[float]: async def _aget_query_embedding(self, query: str) -> List[float]:
"""The asynchronous version of _get_query_embedding.""" """The asynchronous version of _get_query_embedding."""
return await self._model.aget_embedding( r = await self._aget_embedding([query], input_type="query")
query, model=self.model_name, input_type="query" return r[0]
)
def _get_text_embedding(self, text: str) -> List[float]: def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding.""" """Get text embedding."""
return self._model.get_embedding( return self._get_embedding([text], input_type="document")[0]
text, model=self.model_name, input_type="document"
)
async def _aget_text_embedding(self, text: str) -> List[float]: async def _aget_text_embedding(self, text: str) -> List[float]:
"""Asynchronously get text embedding.""" """Asynchronously get text embedding."""
return await self._model.aget_embedding( r = await self._aget_embedding([text], input_type="document")
text, model=self.model_name, input_type="document" return r[0]
)
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get text embeddings.""" """Get text embeddings."""
return self._model.get_embeddings( return self._get_embedding(texts, input_type="document")
texts, model=self.model_name, input_type="document"
)
async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]: async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Asynchronously get text embeddings.""" """Asynchronously get text embeddings."""
return await self._model.aget_embeddings( return await self._aget_embedding(texts, input_type="document")
texts, model=self.model_name, input_type="document"
)
def get_general_text_embedding( def get_general_text_embedding(
self, text: str, input_type: Optional[str] = None self, text: str, input_type: Optional[str] = None
) -> List[float]: ) -> List[float]:
"""Get general text embedding with input_type.""" """Get general text embedding with input_type."""
return self._model.get_embedding( return self._get_embedding([text], input_type=input_type)[0]
text, model=self.model_name, input_type=input_type
)
async def aget_general_text_embedding( async def aget_general_text_embedding(
self, text: str, input_type: Optional[str] = None self, text: str, input_type: Optional[str] = None
) -> List[float]: ) -> List[float]:
"""Asynchronously get general text embedding with input_type.""" """Asynchronously get general text embedding with input_type."""
return await self._model.aget_embedding( r = await self._aget_embedding([text], input_type=input_type)
text, model=self.model_name, input_type=input_type return r[0]
)
...@@ -3,5 +3,32 @@ from llama_index.embeddings.voyageai import VoyageEmbedding ...@@ -3,5 +3,32 @@ from llama_index.embeddings.voyageai import VoyageEmbedding
def test_embedding_class():
    """Default batch size for a non-voyage-2 model name is 7."""
    emb = VoyageEmbedding(model_name="", voyage_api_key="NOT_A_VALID_KEY")
    assert isinstance(emb, BaseEmbedding)
    assert emb.embed_batch_size == 7
    assert emb.model_name == ""
def test_embedding_class_voyage_2():
    """voyage-2 gets a default batch size of 72; truncation is stored."""
    emb = VoyageEmbedding(
        model_name="voyage-2", voyage_api_key="NOT_A_VALID_KEY", truncation=True
    )
    assert isinstance(emb, BaseEmbedding)
    assert emb.embed_batch_size == 72
    assert emb.model_name == "voyage-2"
    assert emb.truncation
def test_embedding_class_voyage_2_with_batch_size():
    """An explicit embed_batch_size overrides the model-based default."""
    emb = VoyageEmbedding(
        model_name="voyage-2", voyage_api_key="NOT_A_VALID_KEY", embed_batch_size=49
    )
    assert isinstance(emb, BaseEmbedding)
    assert emb.embed_batch_size == 49
    assert emb.model_name == "voyage-2"
    assert emb.truncation is None
def test_voyageai_embedding_class():
    """VoyageEmbedding subclasses BaseEmbedding."""
    names_of_base_classes = [b.__name__ for b in VoyageEmbedding.__mro__]
    assert BaseEmbedding.__name__ in names_of_base_classes
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment