diff --git a/llama-index-core/llama_index/core/evaluation/retrieval/metrics.py b/llama-index-core/llama_index/core/evaluation/retrieval/metrics.py index f708e63d7588ba74547e89cce422b1fb1159c0ca..9f8a206cd38c2e0daa00b8b304b7d26149cd3c49 100644 --- a/llama-index-core/llama_index/core/evaluation/retrieval/metrics.py +++ b/llama-index-core/llama_index/core/evaluation/retrieval/metrics.py @@ -119,7 +119,7 @@ class CohereRerankRelevancyMetric(BaseRetrievalMetric): query=query, documents=retrieved_texts, ) - relevance_scores = [r.relevance_score for r in results] + relevance_scores = [r.relevance_score for r in results.results] agg_func = self._get_agg_func(agg) return RetrievalMetricResult( diff --git a/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/utils.py b/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/utils.py index e1c3e24859ecbdcf1659ec53b81b58d7087ae5c5..b588510cec1a19e1c1baabb31ee90fcd72bbb9af 100644 --- a/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/utils.py +++ b/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/utils.py @@ -49,7 +49,7 @@ def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]: reraise=True, stop=stop_after_attempt(max_retries), wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), - retry=(retry_if_exception_type(cohere.error.CohereConnectionError)), + retry=(retry_if_exception_type(cohere.errors.ServiceUnavailableError)), before_sleep=before_sleep_log(logger, logging.WARNING), ) @@ -62,10 +62,17 @@ def completion_with_retry( @retry_decorator def _completion_with_retry(**kwargs: Any) -> Any: + is_stream = kwargs.pop("stream", False) if chat: - return client.chat(**kwargs) + if is_stream: + return client.chat_stream(**kwargs) + else: + return client.chat(**kwargs) else: - return client.generate(**kwargs) + if is_stream: + return client.generate_stream(**kwargs) + else: + return client.generate(**kwargs) return _completion_with_retry(**kwargs) @@ -81,10 +88,17 @@ async def acompletion_with_retry( @retry_decorator async def _completion_with_retry(**kwargs: Any) -> Any: + is_stream = kwargs.pop("stream", False) if chat: - return await aclient.chat(**kwargs) + if is_stream: + return await aclient.chat_stream(**kwargs) + else: + return await aclient.chat(**kwargs) else: - return await aclient.generate(**kwargs) + if is_stream: + return await aclient.generate_stream(**kwargs) + else: + return await aclient.generate(**kwargs) return await _completion_with_retry(**kwargs) @@ -107,6 +121,16 @@ def is_chat_model(model: str) -> bool: def messages_to_cohere_history( messages: Sequence[ChatMessage], ) -> List[Dict[str, Optional[str]]]: + role_map = { + "user": "USER", + "system": "SYSTEM", + "chatbot": "CHATBOT", + "assistant": "CHATBOT", + "model": "SYSTEM", + "function": "SYSTEM", + "tool": "SYSTEM", + } return [ - {"user_name": message.role, "message": message.content} for message in messages + {"role": role_map[message.role], "message": message.content} + for message in messages ] diff --git a/llama-index-integrations/llms/llama-index-llms-cohere/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-cohere/pyproject.toml index 7c89eccdefc7ced5c3a709334bc875eb76b3dd79..4c486980e327388920e55ac2e6b9b0e0ab8b4395 100644 --- a/llama-index-integrations/llms/llama-index-llms-cohere/pyproject.toml +++ b/llama-index-integrations/llms/llama-index-llms-cohere/pyproject.toml @@ -27,12 +27,12 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-llms-cohere" readme = "README.md" -version = "0.1.4" +version = "0.1.5" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" llama-index-core = "^0.10.1" -cohere = ">=4.44" +cohere = "^5.1.1" [tool.poetry.group.dev.dependencies] ipython = "8.10.0" diff --git a/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/llama_index/postprocessor/cohere_rerank/base.py b/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/llama_index/postprocessor/cohere_rerank/base.py index bf2fff2a2f6d6d30ddc7a528e6cef30892ae1549..21a9b2e631eadd46cb8c3d3ca517b4eb43dd0e4a 100644 --- a/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/llama_index/postprocessor/cohere_rerank/base.py +++ b/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/llama_index/postprocessor/cohere_rerank/base.py @@ -68,7 +68,7 @@ class CohereRerank(BaseNodePostprocessor): ) new_nodes = [] - for result in results: + for result in results.results: new_node_with_score = NodeWithScore( node=nodes[result.index].node, score=result.relevance_score ) diff --git a/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/pyproject.toml b/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/pyproject.toml index b090f5913c9f4a7ad49bb962b2aafbd62ffb6dba..c0fa1ea4ca55e5216d8e28154eb80e3e08919a93 100644 --- a/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/pyproject.toml +++ b/llama-index-integrations/postprocessor/llama-index-postprocessor-cohere-rerank/pyproject.toml @@ -27,12 +27,12 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-postprocessor-cohere-rerank" readme = "README.md" -version = "0.1.3" +version = "0.1.4" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" llama-index-core = "^0.10.1" -cohere = ">=4.45" +cohere = "^5.1.1" [tool.poetry.group.dev.dependencies] ipython = "8.10.0"