diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py index a40ba92531b49f70f4c3266da7a7cc7c48bda5a9..170fe2557a0bb81bfd5d66ed96cf7924910a962e 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py @@ -105,10 +105,8 @@ class CohereEmbedding(BaseEmbedding): """CohereEmbedding uses the Cohere API to generate embeddings for text.""" # Instance variables initialized via Pydantic's mechanism - _cohere_client: cohere.Client = Field(description="CohereAI client") - _cohere_async_client: cohere.AsyncClient = Field( - description="CohereAI Async client" - ) + cohere_client: cohere.Client = Field(description="CohereAI client") + cohere_async_client: cohere.AsyncClient = Field(description="CohereAI Async client") truncate: str = Field(description="Truncation type - START/ END/ NONE") input_type: Optional[str] = Field( description="Model Input type. If not provided, search_document and search_query are used when needed." @@ -160,14 +158,14 @@ class CohereEmbedding(BaseEmbedding): raise ValueError(f"truncate must be one of {VALID_TRUNCATE_OPTIONS}") super().__init__( - _cohere_client=cohere.Client( + cohere_client=cohere.Client( cohere_api_key, client_name="llama_index", base_url=base_url, timeout=timeout, httpx_client=httpx_client, ), - _cohere_async_client=cohere.AsyncClient( + cohere_async_client=cohere.AsyncClient( cohere_api_key, client_name="llama_index", base_url=base_url, @@ -190,7 +188,7 @@ class CohereEmbedding(BaseEmbedding): def _embed(self, texts: List[str], input_type: str) -> List[List[float]]: """Embed sentences using Cohere.""" if self.model_name in V3_MODELS: - result = self._cohere_client.embed( + result = self.cohere_client.embed( texts=texts, input_type=self.input_type or input_type, embedding_types=[self.embedding_type], @@ -198,7 +196,7 @@ class CohereEmbedding(BaseEmbedding): truncate=self.truncate, ).embeddings else: - result = self._cohere_client.embed( + result = self.cohere_client.embed( texts=texts, model=self.model_name, embedding_types=[self.embedding_type], @@ -210,7 +208,7 @@ class CohereEmbedding(BaseEmbedding): """Embed sentences using Cohere.""" if self.model_name in V3_MODELS: result = ( - await self._cohere_async_client.embed( + await self.cohere_async_client.embed( texts=texts, input_type=self.input_type or input_type, embedding_types=[self.embedding_type], @@ -220,7 +218,7 @@ class CohereEmbedding(BaseEmbedding): ).embeddings else: result = ( - await self._cohere_async_client.embed( + await self.cohere_async_client.embed( texts=texts, model=self.model_name, embedding_types=[self.embedding_type], diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml index cb76ff54d39c342fe2d122949b4498841d2c2254..39e3d005f39903145a988797a38a97cb1c45e3b1 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml +++ b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-embeddings-cohere" readme = "README.md" -version = "0.1.6" +version = "0.1.7" [tool.poetry.dependencies] python = ">=3.8.1,<4.0"