From d0f91f11d3e1f76c9540012e172c276628abf801 Mon Sep 17 00:00:00 2001 From: EugeneLightsOn <144219719+EugeneLightsOn@users.noreply.github.com> Date: Fri, 29 Mar 2024 18:11:15 +0200 Subject: [PATCH] Upgrade Llama Cohere client dependencies to version ^5.1.1 and Llama Cohere LLM updates (#12279) --- .../core/evaluation/retrieval/metrics.py | 2 +- .../llama_index/llms/cohere/utils.py | 36 +++++++++++++++---- .../llama-index-llms-cohere/pyproject.toml | 4 +-- .../postprocessor/cohere_rerank/base.py | 2 +- .../pyproject.toml | 4 +-- 5 files changed, 36 insertions(+), 12 deletions(-) diff --git a/llama-index-core/llama_index/core/evaluation/retrieval/metrics.py b/llama-index-core/llama_index/core/evaluation/retrieval/metrics.py index f708e63d75..9f8a206cd3 100644 --- a/llama-index-core/llama_index/core/evaluation/retrieval/metrics.py +++ b/llama-index-core/llama_index/core/evaluation/retrieval/metrics.py @@ -119,7 +119,7 @@ class CohereRerankRelevancyMetric(BaseRetrievalMetric): query=query, documents=retrieved_texts, ) - relevance_scores = [r.relevance_score for r in results] + relevance_scores = [r.relevance_score for r in results.results] agg_func = self._get_agg_func(agg) return RetrievalMetricResult( diff --git a/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/utils.py b/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/utils.py index e1c3e24859..b588510cec 100644 --- a/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/utils.py +++ b/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/utils.py @@ -49,7 +49,7 @@ def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]: reraise=True, stop=stop_after_attempt(max_retries), wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), - retry=(retry_if_exception_type(cohere.error.CohereConnectionError)), + retry=(retry_if_exception_type(cohere.errors.ServiceUnavailableError)), before_sleep=before_sleep_log(logger, logging.WARNING), ) @@ -62,10 +62,17 @@ def completion_with_retry( @retry_decorator def _completion_with_retry(**kwargs: Any) -> Any: + is_stream = kwargs.pop("stream", False) if chat: - return client.chat(**kwargs) + if is_stream: + return client.chat_stream(**kwargs) + else: + return client.chat(**kwargs) else: - return client.generate(**kwargs) + if is_stream: + return client.generate_stream(**kwargs) + else: + return client.generate(**kwargs) return _completion_with_retry(**kwargs) @@ -81,10 +88,17 @@ async def acompletion_with_retry( @retry_decorator async def _completion_with_retry(**kwargs: Any) -> Any: + is_stream = kwargs.pop("stream", False) if chat: - return await aclient.chat(**kwargs) + if is_stream: + return await aclient.chat_stream(**kwargs) + else: + return await aclient.chat(**kwargs) else: - return await aclient.generate(**kwargs) + if is_stream: + return await aclient.generate_stream(**kwargs) + else: + return await aclient.generate(**kwargs) return await _completion_with_retry(**kwargs) @@ -107,6 +121,16 @@ def is_chat_model(model: str) -> bool: def messages_to_cohere_history( messages: Sequence[ChatMessage], ) -> List[Dict[str, Optional[str]]]: + role_map = { + "user": "USER", + "system": "SYSTEM", + "chatbot": "CHATBOT", + "assistant": "CHATBOT", + "model": "SYSTEM", + "function": "SYSTEM", + "tool": "SYSTEM", + } return [ - {"user_name": message.role, "message": message.content} for message in messages + {"role": role_map[message.role], "message": message.content} + for message in messages ] diff --git a/llama-index-integrations/llms/llama-index-llms-cohere/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-cohere/pyproject.toml index 7c89eccdef..4c486980e3 100644 --- a/llama-index-integrations/llms/llama-index-llms-cohere/pyproject.toml +++ b/llama-index-integrations/llms/llama-index-llms-cohere/pyproject.toml @@ -27,12 +27,12 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-llms-cohere" readme = "README.md" -version = "0.1.4" +version = "0.1.5" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" llama-index-core = "^0.10.1" -cohere = ">=4.44" +cohere = "^5.1.1" [tool.poetry.group.dev.dependencies] ipython = "8.10.0" diff --git a/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/llama_index/postprocessor/cohere_rerank/base.py b/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/llama_index/postprocessor/cohere_rerank/base.py index bf2fff2a2f..21a9b2e631 100644 --- a/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/llama_index/postprocessor/cohere_rerank/base.py +++ b/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/llama_index/postprocessor/cohere_rerank/base.py @@ -68,7 +68,7 @@ class CohereRerank(BaseNodePostprocessor): ) new_nodes = [] - for result in results: + for result in results.results: new_node_with_score = NodeWithScore( node=nodes[result.index].node, score=result.relevance_score ) diff --git a/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/pyproject.toml b/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/pyproject.toml index b090f5913c..c0fa1ea4ca 100644 --- a/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/pyproject.toml +++ b/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/pyproject.toml @@ -27,12 +27,12 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-postprocessor-cohere-rerank" readme = "README.md" -version = "0.1.3" +version = "0.1.4" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" llama-index-core = "^0.10.1" -cohere = ">=4.45" +cohere = "^5.1.1" [tool.poetry.group.dev.dependencies] ipython = "8.10.0" -- GitLab