Skip to content
Snippets Groups Projects
Unverified Commit 2fc3c571 authored by billytrend-cohere's avatar billytrend-cohere Committed by GitHub
Browse files

Update cohere and add async support (#12705)

parent 7ef27694
No related branches found
No related tags found
No related merge requests found
from enum import Enum
from typing import Any, List, Optional
from typing import List, Optional
from llama_index.core.base.embeddings.base import (
DEFAULT_EMBED_BATCH_SIZE,
......@@ -7,6 +7,8 @@ from llama_index.core.base.embeddings.base import (
)
from llama_index.core.bridge.pydantic import Field
from llama_index.core.callbacks import CallbackManager
import cohere
import httpx
# Enums for validation and type safety
......@@ -74,6 +76,15 @@ VALID_MODEL_INPUT_TYPES = {
CAMN.MULTILINGUAL_V2: [None],
}
# v3 models require an input_type field
V3_MODELS = [
CAMN.ENGLISH_V3,
CAMN.ENGLISH_LIGHT_V3,
CAMN.MULTILINGUAL_V3,
CAMN.MULTILINGUAL_LIGHT_V3,
]
# This list would be used for model name and embedding types validation
# Embedding type can be float/ int8/ uint8/ binary/ ubinary based on model.
VALID_MODEL_EMBEDDING_TYPES = {
......@@ -94,7 +105,10 @@ class CohereEmbedding(BaseEmbedding):
"""CohereEmbedding uses the Cohere API to generate embeddings for text."""
# Instance variables initialized via Pydantic's mechanism
cohere_client: Any = Field(description="CohereAI client")
_cohere_client: cohere.Client = Field(description="CohereAI client")
_cohere_async_client: cohere.AsyncClient = Field(
description="CohereAI Async client"
)
truncate: str = Field(description="Truncation type - START/ END/ NONE")
input_type: Optional[str] = Field(
description="Model Input type. If not provided, search_document and search_query are used when needed."
......@@ -112,6 +126,9 @@ class CohereEmbedding(BaseEmbedding):
embedding_type: str = "float",
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
callback_manager: Optional[CallbackManager] = None,
base_url: Optional[str] = None,
timeout: Optional[float] = None,
httpx_client: Optional[httpx.AsyncClient] = None,
):
"""
A class representation for generating embeddings using the Cohere API.
......@@ -126,13 +143,6 @@ class CohereEmbedding(BaseEmbedding):
model_name (str): The name of the model to be used for generating embeddings. The class ensures that
this model is supported and that the input type provided is compatible with the model.
"""
try:
import cohere
except ImportError:
raise ImportError(
"`cohere` package not found. Please run `pip install 'cohere>=5.1.1,<6.0.0'."
)
# Validate model_name and input_type
if model_name not in VALID_MODEL_INPUT_TYPES:
raise ValueError(f"{model_name} is not a valid model name")
......@@ -150,7 +160,20 @@ class CohereEmbedding(BaseEmbedding):
raise ValueError(f"truncate must be one of {VALID_TRUNCATE_OPTIONS}")
super().__init__(
cohere_client=cohere.Client(cohere_api_key, client_name="llama_index"),
_cohere_client=cohere.Client(
cohere_api_key,
client_name="llama_index",
base_url=base_url,
timeout=timeout,
httpx_client=httpx_client,
),
_cohere_async_client=cohere.AsyncClient(
cohere_api_key,
client_name="llama_index",
base_url=base_url,
timeout=timeout,
httpx_client=httpx_client,
),
cohere_api_key=cohere_api_key,
model_name=model_name,
input_type=input_type,
......@@ -166,13 +189,8 @@ class CohereEmbedding(BaseEmbedding):
def _embed(self, texts: List[str], input_type: str) -> List[List[float]]:
"""Embed sentences using Cohere."""
if self.model_name in [
CAMN.ENGLISH_V3,
CAMN.ENGLISH_LIGHT_V3,
CAMN.MULTILINGUAL_V3,
CAMN.MULTILINGUAL_LIGHT_V3,
]:
result = self.cohere_client.embed(
if self.model_name in V3_MODELS:
result = self._cohere_client.embed(
texts=texts,
input_type=self.input_type or input_type,
embedding_types=[self.embedding_type],
......@@ -180,7 +198,7 @@ class CohereEmbedding(BaseEmbedding):
truncate=self.truncate,
).embeddings
else:
result = self.cohere_client.embed(
result = self._cohere_client.embed(
texts=texts,
model=self.model_name,
embedding_types=[self.embedding_type],
......@@ -188,13 +206,36 @@ class CohereEmbedding(BaseEmbedding):
).embeddings
return getattr(result, self.embedding_type, None)
async def _aembed(self, texts: List[str], input_type: str) -> List[List[float]]:
"""Embed sentences using Cohere."""
if self.model_name in V3_MODELS:
result = (
await self._cohere_async_client.embed(
texts=texts,
input_type=self.input_type or input_type,
embedding_types=[self.embedding_type],
model=self.model_name,
truncate=self.truncate,
)
).embeddings
else:
result = (
await self._cohere_async_client.embed(
texts=texts,
model=self.model_name,
embedding_types=[self.embedding_type],
truncate=self.truncate,
)
).embeddings
return getattr(result, self.embedding_type, None)
def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding. For query embeddings, input_type='search_query'."""
return self._embed([query], input_type="search_query")[0]
async def _aget_query_embedding(self, query: str) -> List[float]:
"""Get query embedding async. For query embeddings, input_type='search_query'."""
return self._get_query_embedding(query)
return (await self._aembed([query], input_type="search_query"))[0]
def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
......@@ -202,8 +243,12 @@ class CohereEmbedding(BaseEmbedding):
async def _aget_text_embedding(self, text: str) -> List[float]:
"""Get text embedding async."""
return self._get_text_embedding(text)
return (await self._aembed([text], input_type="search_document"))[0]
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get text embeddings."""
return self._embed(texts, input_type="search_document")
async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
"""Get text embeddings."""
return await self._aembed(texts, input_type="search_document")
......@@ -27,12 +27,12 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-embeddings-cohere"
readme = "README.md"
version = "0.1.5"
version = "0.1.6"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
llama-index-core = "^0.10.1"
cohere = "^5.1.1"
cohere = "^5.2.5"
[tool.poetry.group.dev.dependencies]
ipython = "8.10.0"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment