diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py
index 6c71851da7abe480fef9ad793b9a647940108e21..a40ba92531b49f70f4c3266da7a7cc7c48bda5a9 100644
--- a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py
@@ -7,6 +7,8 @@ from llama_index.core.base.embeddings.base import (
 )
 from llama_index.core.bridge.pydantic import Field
 from llama_index.core.callbacks import CallbackManager
+import cohere
+import httpx
 
 
 # Enums for validation and type safety
@@ -74,6 +76,15 @@ VALID_MODEL_INPUT_TYPES = {
     CAMN.MULTILINGUAL_V2: [None],
 }
 
+# v3 models require an input_type field
+V3_MODELS = [
+    CAMN.ENGLISH_V3,
+    CAMN.ENGLISH_LIGHT_V3,
+    CAMN.MULTILINGUAL_V3,
+    CAMN.MULTILINGUAL_LIGHT_V3,
+]
+
+
 # This list would be used for model name and embedding types validation
 # Embedding type can be float/ int8/ uint8/ binary/ ubinary based on model.
 VALID_MODEL_EMBEDDING_TYPES = {
@@ -94,7 +105,10 @@ class CohereEmbedding(BaseEmbedding):
     """CohereEmbedding uses the Cohere API to generate embeddings for text."""
 
     # Instance variables initialized via Pydantic's mechanism
     cohere_client: Any = Field(description="CohereAI client")
+    cohere_async_client: Any = Field(
+        description="CohereAI Async client"
+    )
     truncate: str = Field(description="Truncation type - START/ END/ NONE")
     input_type: Optional[str] = Field(
         description="Model Input type. If not provided, search_document and search_query are used when needed."
@@ -112,6 +126,9 @@
         embedding_type: str = "float",
         embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
         callback_manager: Optional[CallbackManager] = None,
+        base_url: Optional[str] = None,
+        timeout: Optional[float] = None,
+        httpx_client: Optional[httpx.AsyncClient] = None,
     ):
         """
         A class representation for generating embeddings using the Cohere API.
@@ -126,13 +143,6 @@
             model_name (str): The name of the model to be used for generating embeddings. The class ensures that
                               this model is supported and that the input type provided is compatible with the model.
         """
-        try:
-            import cohere
-        except ImportError:
-            raise ImportError(
-                "`cohere` package not found. Please run `pip install 'cohere>=5.1.1,<6.0.0'."
-            )
-
         # Validate model_name and input_type
         if model_name not in VALID_MODEL_INPUT_TYPES:
             raise ValueError(f"{model_name} is not a valid model name")
@@ -150,7 +160,20 @@
             raise ValueError(f"truncate must be one of {VALID_TRUNCATE_OPTIONS}")
 
         super().__init__(
-            cohere_client=cohere.Client(cohere_api_key, client_name="llama_index"),
+            cohere_client=cohere.Client(
+                cohere_api_key,
+                client_name="llama_index",
+                base_url=base_url,
+                timeout=timeout,
+                httpx_client=httpx_client,
+            ),
+            cohere_async_client=cohere.AsyncClient(
+                cohere_api_key,
+                client_name="llama_index",
+                base_url=base_url,
+                timeout=timeout,
+                httpx_client=httpx_client,
+            ),
             cohere_api_key=cohere_api_key,
             model_name=model_name,
             input_type=input_type,
@@ -166,13 +189,8 @@
 
     def _embed(self, texts: List[str], input_type: str) -> List[List[float]]:
         """Embed sentences using Cohere."""
-        if self.model_name in [
-            CAMN.ENGLISH_V3,
-            CAMN.ENGLISH_LIGHT_V3,
-            CAMN.MULTILINGUAL_V3,
-            CAMN.MULTILINGUAL_LIGHT_V3,
-        ]:
+        if self.model_name in V3_MODELS:
             result = self.cohere_client.embed(
                 texts=texts,
                 input_type=self.input_type or input_type,
                 embedding_types=[self.embedding_type],
@@ -188,13 +206,36 @@
             ).embeddings
         return getattr(result, self.embedding_type, None)
 
+    async def _aembed(self, texts: List[str], input_type: str) -> List[List[float]]:
+        """Embed sentences using Cohere."""
+        if self.model_name in V3_MODELS:
+            result = (
+                await self.cohere_async_client.embed(
+                    texts=texts,
+                    input_type=self.input_type or input_type,
+                    embedding_types=[self.embedding_type],
+                    model=self.model_name,
+                    truncate=self.truncate,
+                )
+            ).embeddings
+        else:
+            result = (
+                await self.cohere_async_client.embed(
+                    texts=texts,
+                    model=self.model_name,
+                    embedding_types=[self.embedding_type],
+                    truncate=self.truncate,
+                )
+            ).embeddings
+        return getattr(result, self.embedding_type, None)
+
     def _get_query_embedding(self, query: str) -> List[float]:
         """Get query embedding. For query embeddings, input_type='search_query'."""
         return self._embed([query], input_type="search_query")[0]
 
     async def _aget_query_embedding(self, query: str) -> List[float]:
         """Get query embedding async. For query embeddings, input_type='search_query'."""
-        return self._get_query_embedding(query)
+        return (await self._aembed([query], input_type="search_query"))[0]
 
     def _get_text_embedding(self, text: str) -> List[float]:
         """Get text embedding."""
@@ -202,8 +243,12 @@
 
     async def _aget_text_embedding(self, text: str) -> List[float]:
         """Get text embedding async."""
-        return self._get_text_embedding(text)
+        return (await self._aembed([text], input_type="search_document"))[0]
 
     def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
         """Get text embeddings."""
         return self._embed(texts, input_type="search_document")
+
+    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+        """Get text embeddings."""
+        return await self._aembed(texts, input_type="search_document")
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml
index 2166c45862a0494382e20a2cf17f6dcce51ca7ef..cb76ff54d39c342fe2d122949b4498841d2c2254 100644
--- a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml
@@ -27,12 +27,12 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-embeddings-cohere"
 readme = "README.md"
-version = "0.1.5"
+version = "0.1.6"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 llama-index-core = "^0.10.1"
-cohere = "^5.1.1"
+cohere = "^5.2.5"
 
 [tool.poetry.group.dev.dependencies]
 ipython = "8.10.0"