Skip to content
Snippets Groups Projects
Unverified Commit 572f9b1b authored by billytrend-cohere's avatar billytrend-cohere Committed by GitHub
Browse files

Revert _ (#12746)

parent 2fc3c571
No related branches found
No related tags found
No related merge requests found
...@@ -105,10 +105,8 @@ class CohereEmbedding(BaseEmbedding): ...@@ -105,10 +105,8 @@ class CohereEmbedding(BaseEmbedding):
"""CohereEmbedding uses the Cohere API to generate embeddings for text.""" """CohereEmbedding uses the Cohere API to generate embeddings for text."""
# Instance variables initialized via Pydantic's mechanism # Instance variables initialized via Pydantic's mechanism
_cohere_client: cohere.Client = Field(description="CohereAI client") cohere_client: cohere.Client = Field(description="CohereAI client")
_cohere_async_client: cohere.AsyncClient = Field( cohere_async_client: cohere.AsyncClient = Field(description="CohereAI Async client")
description="CohereAI Async client"
)
truncate: str = Field(description="Truncation type - START/ END/ NONE") truncate: str = Field(description="Truncation type - START/ END/ NONE")
input_type: Optional[str] = Field( input_type: Optional[str] = Field(
description="Model Input type. If not provided, search_document and search_query are used when needed." description="Model Input type. If not provided, search_document and search_query are used when needed."
...@@ -160,14 +158,14 @@ class CohereEmbedding(BaseEmbedding): ...@@ -160,14 +158,14 @@ class CohereEmbedding(BaseEmbedding):
raise ValueError(f"truncate must be one of {VALID_TRUNCATE_OPTIONS}") raise ValueError(f"truncate must be one of {VALID_TRUNCATE_OPTIONS}")
super().__init__( super().__init__(
_cohere_client=cohere.Client( cohere_client=cohere.Client(
cohere_api_key, cohere_api_key,
client_name="llama_index", client_name="llama_index",
base_url=base_url, base_url=base_url,
timeout=timeout, timeout=timeout,
httpx_client=httpx_client, httpx_client=httpx_client,
), ),
_cohere_async_client=cohere.AsyncClient( cohere_async_client=cohere.AsyncClient(
cohere_api_key, cohere_api_key,
client_name="llama_index", client_name="llama_index",
base_url=base_url, base_url=base_url,
...@@ -190,7 +188,7 @@ class CohereEmbedding(BaseEmbedding): ...@@ -190,7 +188,7 @@ class CohereEmbedding(BaseEmbedding):
def _embed(self, texts: List[str], input_type: str) -> List[List[float]]: def _embed(self, texts: List[str], input_type: str) -> List[List[float]]:
"""Embed sentences using Cohere.""" """Embed sentences using Cohere."""
if self.model_name in V3_MODELS: if self.model_name in V3_MODELS:
result = self._cohere_client.embed( result = self.cohere_client.embed(
texts=texts, texts=texts,
input_type=self.input_type or input_type, input_type=self.input_type or input_type,
embedding_types=[self.embedding_type], embedding_types=[self.embedding_type],
...@@ -198,7 +196,7 @@ class CohereEmbedding(BaseEmbedding): ...@@ -198,7 +196,7 @@ class CohereEmbedding(BaseEmbedding):
truncate=self.truncate, truncate=self.truncate,
).embeddings ).embeddings
else: else:
result = self._cohere_client.embed( result = self.cohere_client.embed(
texts=texts, texts=texts,
model=self.model_name, model=self.model_name,
embedding_types=[self.embedding_type], embedding_types=[self.embedding_type],
...@@ -210,7 +208,7 @@ class CohereEmbedding(BaseEmbedding): ...@@ -210,7 +208,7 @@ class CohereEmbedding(BaseEmbedding):
"""Embed sentences using Cohere.""" """Embed sentences using Cohere."""
if self.model_name in V3_MODELS: if self.model_name in V3_MODELS:
result = ( result = (
await self._cohere_async_client.embed( await self.cohere_async_client.embed(
texts=texts, texts=texts,
input_type=self.input_type or input_type, input_type=self.input_type or input_type,
embedding_types=[self.embedding_type], embedding_types=[self.embedding_type],
...@@ -220,7 +218,7 @@ class CohereEmbedding(BaseEmbedding): ...@@ -220,7 +218,7 @@ class CohereEmbedding(BaseEmbedding):
).embeddings ).embeddings
else: else:
result = ( result = (
await self._cohere_async_client.embed( await self.cohere_async_client.embed(
texts=texts, texts=texts,
model=self.model_name, model=self.model_name,
embedding_types=[self.embedding_type], embedding_types=[self.embedding_type],
......
...@@ -27,7 +27,7 @@ exclude = ["**/BUILD"] ...@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT" license = "MIT"
name = "llama-index-embeddings-cohere" name = "llama-index-embeddings-cohere"
readme = "README.md" readme = "README.md"
version = "0.1.6" version = "0.1.7"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.8.1,<4.0" python = ">=3.8.1,<4.0"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment