From d0f91f11d3e1f76c9540012e172c276628abf801 Mon Sep 17 00:00:00 2001
From: EugeneLightsOn <144219719+EugeneLightsOn@users.noreply.github.com>
Date: Fri, 29 Mar 2024 18:11:15 +0200
Subject: [PATCH] Upgrade LlamaIndex Cohere client dependencies to ^5.1.1 and
 update the Cohere LLM integration accordingly (#12279)

---
 .../core/evaluation/retrieval/metrics.py      |  2 +-
 .../llama_index/llms/cohere/utils.py          | 36 +++++++++++++++----
 .../llama-index-llms-cohere/pyproject.toml    |  4 +--
 .../postprocessor/cohere_rerank/base.py       |  2 +-
 .../pyproject.toml                            |  4 +--
 5 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/llama-index-core/llama_index/core/evaluation/retrieval/metrics.py b/llama-index-core/llama_index/core/evaluation/retrieval/metrics.py
index f708e63d75..9f8a206cd3 100644
--- a/llama-index-core/llama_index/core/evaluation/retrieval/metrics.py
+++ b/llama-index-core/llama_index/core/evaluation/retrieval/metrics.py
@@ -119,7 +119,7 @@ class CohereRerankRelevancyMetric(BaseRetrievalMetric):
             query=query,
             documents=retrieved_texts,
         )
-        relevance_scores = [r.relevance_score for r in results]
+        relevance_scores = [r.relevance_score for r in results.results]
         agg_func = self._get_agg_func(agg)
 
         return RetrievalMetricResult(
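
In the Cohere v5 SDK, Client.rerank() returns a response object whose ranked
entries live under a .results attribute (each item carries .index and
.relevance_score) instead of being iterable directly, which is what the hunk
above adapts to. A minimal sketch of the v5 call, assuming a COHERE_API_KEY
environment variable and the rerank-english-v2.0 model name (both assumptions,
not part of this patch):

    import os

    import cohere

    co = cohere.Client(os.environ["COHERE_API_KEY"])
    response = co.rerank(
        model="rerank-english-v2.0",  # assumed model name; any rerank model works
        query="What is the capital of France?",
        documents=["Paris is the capital of France.", "Berlin is in Germany."],
        top_n=2,
    )
    # v5 wraps the ranked items, so iterate response.results, not the response itself.
    relevance_scores = [r.relevance_score for r in response.results]
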
diff --git a/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/utils.py b/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/utils.py
index e1c3e24859..b588510cec 100644
--- a/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/utils.py
@@ -49,7 +49,7 @@ def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]:
         reraise=True,
         stop=stop_after_attempt(max_retries),
         wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
-        retry=(retry_if_exception_type(cohere.error.CohereConnectionError)),
+        retry=(retry_if_exception_type(cohere.errors.ServiceUnavailableError)),
         before_sleep=before_sleep_log(logger, logging.WARNING),
     )
 
@@ -62,10 +62,17 @@ def completion_with_retry(
 
     @retry_decorator
     def _completion_with_retry(**kwargs: Any) -> Any:
+        is_stream = kwargs.pop("stream", False)
         if chat:
-            return client.chat(**kwargs)
+            if is_stream:
+                return client.chat_stream(**kwargs)
+            else:
+                return client.chat(**kwargs)
         else:
-            return client.generate(**kwargs)
+            if is_stream:
+                return client.generate_stream(**kwargs)
+            else:
+                return client.generate(**kwargs)
 
     return _completion_with_retry(**kwargs)
 
@@ -81,10 +88,17 @@ async def acompletion_with_retry(
 
     @retry_decorator
     async def _completion_with_retry(**kwargs: Any) -> Any:
+        is_stream = kwargs.pop("stream", False)
         if chat:
-            return await aclient.chat(**kwargs)
+            if is_stream:
+                return await aclient.chat_stream(**kwargs)
+            else:
+                return await aclient.chat(**kwargs)
         else:
-            return await aclient.generate(**kwargs)
+            if is_stream:
+                return await aclient.generate_stream(**kwargs)
+            else:
+                return await aclient.generate(**kwargs)
 
     return await _completion_with_retry(**kwargs)
 
@@ -107,6 +121,16 @@ def is_chat_model(model: str) -> bool:
 def messages_to_cohere_history(
     messages: Sequence[ChatMessage],
 ) -> List[Dict[str, Optional[str]]]:
+    role_map = {
+        "user": "USER",
+        "system": "SYSTEM",
+        "chatbot": "CHATBOT",
+        "assistant": "CHATBOT",
+        "model": "SYSTEM",
+        "function": "SYSTEM",
+        "tool": "SYSTEM",
+    }
     return [
-        {"user_name": message.role, "message": message.content} for message in messages
+        {"role": role_map[message.role], "message": message.content}
+        for message in messages
     ]
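
The v5 client also changes two things these helpers now account for: streaming
moved from a stream=True keyword to dedicated chat_stream()/generate_stream()
methods, and chat history entries are keyed by role (USER / CHATBOT / SYSTEM)
rather than user_name. A rough sketch of driving the updated helpers, assuming
completion_with_retry keeps its (client, max_retries, chat, **kwargs) shape, a
valid COHERE_API_KEY, and the "command" model name (assumptions, not guaranteed
by this patch):

    import os

    import cohere
    from llama_index.core.llms import ChatMessage
    from llama_index.llms.cohere.utils import (
        completion_with_retry,
        messages_to_cohere_history,
    )

    client = cohere.Client(os.environ["COHERE_API_KEY"])
    # Under v5 the entries come back as {"role": "USER", "message": "..."} dicts.
    history = messages_to_cohere_history(
        [
            ChatMessage(role="user", content="Hi"),
            ChatMessage(role="assistant", content="Hello! How can I help?"),
        ]
    )
    # stream=True is popped inside the helper and routed to client.chat_stream(...).
    stream = completion_with_retry(
        client=client,
        max_retries=3,
        chat=True,
        stream=True,
        message="Summarize our conversation so far.",
        chat_history=history,
        model="command",  # assumed model name
    )
    for event in stream:
        if event.event_type == "text-generation":
            print(event.text, end="")
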
diff --git a/llama-index-integrations/llms/llama-index-llms-cohere/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-cohere/pyproject.toml
index 7c89eccdef..4c486980e3 100644
--- a/llama-index-integrations/llms/llama-index-llms-cohere/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-cohere/pyproject.toml
@@ -27,12 +27,12 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-cohere"
 readme = "README.md"
-version = "0.1.4"
+version = "0.1.5"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 llama-index-core = "^0.10.1"
-cohere = ">=4.44"
+cohere = "^5.1.1"
 
 [tool.poetry.group.dev.dependencies]
 ipython = "8.10.0"
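
The caret constraint is the substantive part of this bump: under Poetry,
cohere = "^5.1.1" resolves to >=5.1.1,<6.0.0, pinning the package to the v5 SDK
line that the code changes above require. A small sanity check one could run
after installing (illustrative only, not part of this patch):

    from importlib import metadata

    # Poetry's ^5.1.1 means >=5.1.1,<6.0.0, i.e. any 5.x release from 5.1.1 up.
    version = metadata.version("cohere")
    assert version.split(".")[0] == "5", f"expected cohere 5.x, found {version}"
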
diff --git a/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/llama_index/postprocessor/cohere_rerank/base.py b/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/llama_index/postprocessor/cohere_rerank/base.py
index bf2fff2a2f..21a9b2e631 100644
--- a/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/llama_index/postprocessor/cohere_rerank/base.py
+++ b/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/llama_index/postprocessor/cohere_rerank/base.py
@@ -68,7 +68,7 @@ class CohereRerank(BaseNodePostprocessor):
             )
 
             new_nodes = []
-            for result in results:
+            for result in results.results:
                 new_node_with_score = NodeWithScore(
                     node=nodes[result.index].node, score=result.relevance_score
                 )
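
As in the evaluation metric earlier in this patch, the reranker now iterates
results.results, where each item exposes an .index into the original node list
and a .relevance_score. Callers are unaffected; a minimal usage sketch of the
upgraded postprocessor (API key handling and the top_n value are assumptions):

    import os

    from llama_index.core.schema import NodeWithScore, TextNode
    from llama_index.postprocessor.cohere_rerank import CohereRerank

    reranker = CohereRerank(api_key=os.environ["COHERE_API_KEY"], top_n=2)
    nodes = [
        NodeWithScore(node=TextNode(text="Paris is the capital of France."), score=0.5),
        NodeWithScore(node=TextNode(text="Berlin is the capital of Germany."), score=0.5),
        NodeWithScore(node=TextNode(text="Cohere builds large language models."), score=0.5),
    ]
    # Scores on the returned nodes are the Cohere relevance scores, highest first.
    reranked = reranker.postprocess_nodes(nodes, query_str="capital of France")
    for n in reranked:
        print(round(n.score, 3), n.node.get_content())
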
diff --git a/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/pyproject.toml b/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/pyproject.toml
index b090f5913c..c0fa1ea4ca 100644
--- a/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/pyproject.toml
+++ b/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/pyproject.toml
@@ -27,12 +27,12 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-postprocessor-cohere-rerank"
 readme = "README.md"
-version = "0.1.3"
+version = "0.1.4"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 llama-index-core = "^0.10.1"
-cohere = ">=4.45"
+cohere = "^5.1.1"
 
 [tool.poetry.group.dev.dependencies]
 ipython = "8.10.0"
-- 
GitLab