From 572f9b1bdd2f0ad2a95c420849859812b03d57c3 Mon Sep 17 00:00:00 2001 From: billytrend-cohere <144115527+billytrend-cohere@users.noreply.github.com> Date: Thu, 11 Apr 2024 11:35:25 -0500 Subject: [PATCH] Revert _ (#12746) --- .../llama_index/embeddings/cohere/base.py | 18 ++++++++---------- .../pyproject.toml | 2 +- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py index a40ba92531..170fe2557a 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py @@ -105,10 +105,8 @@ class CohereEmbedding(BaseEmbedding): """CohereEmbedding uses the Cohere API to generate embeddings for text.""" # Instance variables initialized via Pydantic's mechanism - _cohere_client: cohere.Client = Field(description="CohereAI client") - _cohere_async_client: cohere.AsyncClient = Field( - description="CohereAI Async client" - ) + cohere_client: cohere.Client = Field(description="CohereAI client") + cohere_async_client: cohere.AsyncClient = Field(description="CohereAI Async client") truncate: str = Field(description="Truncation type - START/ END/ NONE") input_type: Optional[str] = Field( description="Model Input type. If not provided, search_document and search_query are used when needed." @@ -160,14 +158,14 @@ class CohereEmbedding(BaseEmbedding): raise ValueError(f"truncate must be one of {VALID_TRUNCATE_OPTIONS}") super().__init__( - _cohere_client=cohere.Client( + cohere_client=cohere.Client( cohere_api_key, client_name="llama_index", base_url=base_url, timeout=timeout, httpx_client=httpx_client, ), - _cohere_async_client=cohere.AsyncClient( + cohere_async_client=cohere.AsyncClient( cohere_api_key, client_name="llama_index", base_url=base_url, @@ -190,7 +188,7 @@ class CohereEmbedding(BaseEmbedding): def _embed(self, texts: List[str], input_type: str) -> List[List[float]]: """Embed sentences using Cohere.""" if self.model_name in V3_MODELS: - result = self._cohere_client.embed( + result = self.cohere_client.embed( texts=texts, input_type=self.input_type or input_type, embedding_types=[self.embedding_type], @@ -198,7 +196,7 @@ class CohereEmbedding(BaseEmbedding): truncate=self.truncate, ).embeddings else: - result = self._cohere_client.embed( + result = self.cohere_client.embed( texts=texts, model=self.model_name, embedding_types=[self.embedding_type], @@ -210,7 +208,7 @@ class CohereEmbedding(BaseEmbedding): """Embed sentences using Cohere.""" if self.model_name in V3_MODELS: result = ( - await self._cohere_async_client.embed( + await self.cohere_async_client.embed( texts=texts, input_type=self.input_type or input_type, embedding_types=[self.embedding_type], @@ -220,7 +218,7 @@ class CohereEmbedding(BaseEmbedding): ).embeddings else: result = ( - await self._cohere_async_client.embed( + await self.cohere_async_client.embed( texts=texts, model=self.model_name, embedding_types=[self.embedding_type], diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml index cb76ff54d3..39e3d005f3 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml +++ b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-embeddings-cohere" readme = "README.md" -version = "0.1.6" +version = "0.1.7" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" -- GitLab