From 2fc3c57157356ef8afdf0939eb774d40d280182c Mon Sep 17 00:00:00 2001
From: billytrend-cohere <144115527+billytrend-cohere@users.noreply.github.com>
Date: Thu, 11 Apr 2024 11:22:55 -0500
Subject: [PATCH] Update cohere and add async support (#12705)

---
 .../llama_index/embeddings/cohere/base.py     | 85 ++++++++++++++-----
 .../pyproject.toml                            |  4 +-
 2 files changed, 67 insertions(+), 22 deletions(-)

diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py
index 6c71851da7..a40ba92531 100644
--- a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/llama_index/embeddings/cohere/base.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, List, Optional
+from typing import List, Optional
 
 from llama_index.core.base.embeddings.base import (
     DEFAULT_EMBED_BATCH_SIZE,
@@ -7,6 +7,8 @@ from llama_index.core.base.embeddings.base import (
 )
 from llama_index.core.bridge.pydantic import Field
 from llama_index.core.callbacks import CallbackManager
+import cohere
+import httpx
 
 
 # Enums for validation and type safety
@@ -74,6 +76,15 @@ VALID_MODEL_INPUT_TYPES = {
     CAMN.MULTILINGUAL_V2: [None],
 }
 
+# v3 models require an input_type field
+V3_MODELS = [
+    CAMN.ENGLISH_V3,
+    CAMN.ENGLISH_LIGHT_V3,
+    CAMN.MULTILINGUAL_V3,
+    CAMN.MULTILINGUAL_LIGHT_V3,
+]
+
+
 # This list would be used for model name and embedding types validation
 # Embedding type can be float/ int8/ uint8/ binary/ ubinary based on model.
 VALID_MODEL_EMBEDDING_TYPES = {
@@ -94,7 +105,10 @@ class CohereEmbedding(BaseEmbedding):
     """CohereEmbedding uses the Cohere API to generate embeddings for text."""
 
     # Instance variables initialized via Pydantic's mechanism
-    cohere_client: Any = Field(description="CohereAI client")
+    _cohere_client: cohere.Client = Field(description="CohereAI client")
+    _cohere_async_client: cohere.AsyncClient = Field(
+        description="CohereAI Async client"
+    )
     truncate: str = Field(description="Truncation type - START/ END/ NONE")
     input_type: Optional[str] = Field(
         description="Model Input type. If not provided, search_document and search_query are used when needed."
@@ -112,6 +126,9 @@ class CohereEmbedding(BaseEmbedding):
         embedding_type: str = "float",
         embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
         callback_manager: Optional[CallbackManager] = None,
+        base_url: Optional[str] = None,
+        timeout: Optional[float] = None,
+        httpx_client: Optional[httpx.AsyncClient] = None,
     ):
         """
         A class representation for generating embeddings using the Cohere API.
@@ -126,13 +143,6 @@ class CohereEmbedding(BaseEmbedding):
             model_name (str): The name of the model to be used for generating embeddings. The class ensures that
                           this model is supported and that the input type provided is compatible with the model.
         """
-        try:
-            import cohere
-        except ImportError:
-            raise ImportError(
-                "`cohere` package not found. Please run `pip install 'cohere>=5.1.1,<6.0.0'."
-            )
-
         # Validate model_name and input_type
         if model_name not in VALID_MODEL_INPUT_TYPES:
             raise ValueError(f"{model_name} is not a valid model name")
@@ -150,7 +160,20 @@ class CohereEmbedding(BaseEmbedding):
             raise ValueError(f"truncate must be one of {VALID_TRUNCATE_OPTIONS}")
 
         super().__init__(
-            cohere_client=cohere.Client(cohere_api_key, client_name="llama_index"),
+            _cohere_client=cohere.Client(
+                cohere_api_key,
+                client_name="llama_index",
+                base_url=base_url,
+                timeout=timeout,
+                httpx_client=httpx_client,
+            ),
+            _cohere_async_client=cohere.AsyncClient(
+                cohere_api_key,
+                client_name="llama_index",
+                base_url=base_url,
+                timeout=timeout,
+                httpx_client=httpx_client,
+            ),
             cohere_api_key=cohere_api_key,
             model_name=model_name,
             input_type=input_type,
@@ -166,13 +189,8 @@ class CohereEmbedding(BaseEmbedding):
 
     def _embed(self, texts: List[str], input_type: str) -> List[List[float]]:
         """Embed sentences using Cohere."""
-        if self.model_name in [
-            CAMN.ENGLISH_V3,
-            CAMN.ENGLISH_LIGHT_V3,
-            CAMN.MULTILINGUAL_V3,
-            CAMN.MULTILINGUAL_LIGHT_V3,
-        ]:
-            result = self.cohere_client.embed(
+        if self.model_name in V3_MODELS:
+            result = self._cohere_client.embed(
                 texts=texts,
                 input_type=self.input_type or input_type,
                 embedding_types=[self.embedding_type],
@@ -180,7 +198,7 @@ class CohereEmbedding(BaseEmbedding):
                 truncate=self.truncate,
             ).embeddings
         else:
-            result = self.cohere_client.embed(
+            result = self._cohere_client.embed(
                 texts=texts,
                 model=self.model_name,
                 embedding_types=[self.embedding_type],
@@ -188,13 +206,36 @@ class CohereEmbedding(BaseEmbedding):
             ).embeddings
         return getattr(result, self.embedding_type, None)
 
+    async def _aembed(self, texts: List[str], input_type: str) -> List[List[float]]:
+        """Embed sentences asynchronously using Cohere."""
+        if self.model_name in V3_MODELS:
+            result = (
+                await self._cohere_async_client.embed(
+                    texts=texts,
+                    input_type=self.input_type or input_type,
+                    embedding_types=[self.embedding_type],
+                    model=self.model_name,
+                    truncate=self.truncate,
+                )
+            ).embeddings
+        else:
+            result = (
+                await self._cohere_async_client.embed(
+                    texts=texts,
+                    model=self.model_name,
+                    embedding_types=[self.embedding_type],
+                    truncate=self.truncate,
+                )
+            ).embeddings
+        return getattr(result, self.embedding_type, None)
+
     def _get_query_embedding(self, query: str) -> List[float]:
         """Get query embedding. For query embeddings, input_type='search_query'."""
         return self._embed([query], input_type="search_query")[0]
 
     async def _aget_query_embedding(self, query: str) -> List[float]:
         """Get query embedding async. For query embeddings, input_type='search_query'."""
-        return self._get_query_embedding(query)
+        return (await self._aembed([query], input_type="search_query"))[0]
 
     def _get_text_embedding(self, text: str) -> List[float]:
         """Get text embedding."""
@@ -202,8 +243,12 @@ class CohereEmbedding(BaseEmbedding):
 
     async def _aget_text_embedding(self, text: str) -> List[float]:
         """Get text embedding async."""
-        return self._get_text_embedding(text)
+        return (await self._aembed([text], input_type="search_document"))[0]
 
     def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
         """Get text embeddings."""
         return self._embed(texts, input_type="search_document")
+
+    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+        """Get text embeddings async."""
+        return await self._aembed(texts, input_type="search_document")
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml
index 2166c45862..cb76ff54d3 100644
--- a/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-cohere/pyproject.toml
@@ -27,12 +27,12 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-embeddings-cohere"
 readme = "README.md"
-version = "0.1.5"
+version = "0.1.6"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 llama-index-core = "^0.10.1"
-cohere = "^5.1.1"
+cohere = "^5.2.5"
 
 [tool.poetry.group.dev.dependencies]
 ipython = "8.10.0"
-- 
GitLab