From 362063cd63130c5a9ff2e83811dbf4d60188a1ba Mon Sep 17 00:00:00 2001
From: Antoine Debugne <146744637+co-antwan@users.noreply.github.com>
Date: Thu, 9 May 2024 06:10:45 +0200
Subject: [PATCH] Call Cohere RAG inference with `documents` argument (#13196)

---
 .../core/prompts/default_prompt_selectors.py |  59 +++++-
 .../llama_index/llms/cohere/__init__.py      |  18 +-
 .../llama_index/llms/cohere/base.py          |  21 +-
 .../llama_index/llms/cohere/utils.py         | 181 +++++++++++++++++-
 .../tests/test_llms_cohere.py                |  78 +++++++-
 .../tests/test_rag_inference.py              |  60 ++++++
 6 files changed, 403 insertions(+), 14 deletions(-)
 create mode 100644 llama-index-integrations/llms/llama-index-llms-cohere/tests/test_rag_inference.py

diff --git a/llama-index-core/llama_index/core/prompts/default_prompt_selectors.py b/llama-index-core/llama_index/core/prompts/default_prompt_selectors.py
index 18e2866c7a..23a9d11912 100644
--- a/llama-index-core/llama_index/core/prompts/default_prompt_selectors.py
+++ b/llama-index-core/llama_index/core/prompts/default_prompt_selectors.py
@@ -14,22 +14,73 @@ from llama_index.core.prompts.default_prompts import (
 )
 from llama_index.core.prompts.utils import is_chat_model
 
+try:
+    from llama_index.llms.cohere import (
+        is_cohere_model,
+        COHERE_QA_TEMPLATE,
+        COHERE_REFINE_TEMPLATE,
+        COHERE_TREE_SUMMARIZE_TEMPLATE,
+        COHERE_REFINE_TABLE_CONTEXT_PROMPT,
+    )
+except ImportError:
+    COHERE_QA_TEMPLATE = None
+    COHERE_REFINE_TEMPLATE = None
+    COHERE_TREE_SUMMARIZE_TEMPLATE = None
+    COHERE_REFINE_TABLE_CONTEXT_PROMPT = None
+
+# Define prompt selectors for Text QA, Tree Summarize, Refine, and Refine Table.
+# Note: Cohere models accept a special `documents` argument for RAG calls. To pass retrieved documents
+# through to that argument, specialised templates are defined below. The conditionals ensure that these
+# templates are selected by default whenever a retriever is run with a Cohere model as the generator.
+
+# Text QA
+default_text_qa_conditionals = [(is_chat_model, CHAT_TEXT_QA_PROMPT)]
+if COHERE_QA_TEMPLATE is not None:
+    default_text_qa_conditionals = [
+        (is_cohere_model, COHERE_QA_TEMPLATE),
+        (is_chat_model, CHAT_TEXT_QA_PROMPT),
+    ]
+
 DEFAULT_TEXT_QA_PROMPT_SEL = SelectorPromptTemplate(
     default_template=DEFAULT_TEXT_QA_PROMPT,
-    conditionals=[(is_chat_model, CHAT_TEXT_QA_PROMPT)],
+    conditionals=default_text_qa_conditionals,
 )
 
+# Tree Summarize
+default_tree_summarize_conditionals = [(is_chat_model, CHAT_TREE_SUMMARIZE_PROMPT)]
+if COHERE_TREE_SUMMARIZE_TEMPLATE is not None:
+    default_tree_summarize_conditionals = [
+        (is_cohere_model, COHERE_TREE_SUMMARIZE_TEMPLATE),
+        (is_chat_model, CHAT_TREE_SUMMARIZE_PROMPT),
+    ]
+
 DEFAULT_TREE_SUMMARIZE_PROMPT_SEL = SelectorPromptTemplate(
     default_template=DEFAULT_TREE_SUMMARIZE_PROMPT,
-    conditionals=[(is_chat_model, CHAT_TREE_SUMMARIZE_PROMPT)],
+    conditionals=default_tree_summarize_conditionals,
 )
 
+# Refine
+default_refine_conditionals = [(is_chat_model, CHAT_REFINE_PROMPT)]
+if COHERE_REFINE_TEMPLATE is not None:
+    default_refine_conditionals = [
+        (is_cohere_model, COHERE_REFINE_TEMPLATE),
+        (is_chat_model, CHAT_REFINE_PROMPT),
+    ]
+
 DEFAULT_REFINE_PROMPT_SEL = SelectorPromptTemplate(
     default_template=DEFAULT_REFINE_PROMPT,
-    conditionals=[(is_chat_model, CHAT_REFINE_PROMPT)],
+    conditionals=default_refine_conditionals,
 )
 
+# Refine Table Context
+default_refine_table_conditionals = [(is_chat_model, CHAT_REFINE_TABLE_CONTEXT_PROMPT)]
+if COHERE_REFINE_TABLE_CONTEXT_PROMPT is not None:
+    default_refine_table_conditionals = [
+        (is_cohere_model, COHERE_REFINE_TABLE_CONTEXT_PROMPT),
+        (is_chat_model, CHAT_REFINE_TABLE_CONTEXT_PROMPT),
+    ]
+
 DEFAULT_REFINE_TABLE_CONTEXT_PROMPT_SEL = SelectorPromptTemplate(
     default_template=DEFAULT_REFINE_TABLE_CONTEXT_PROMPT,
-    conditionals=[(is_chat_model, CHAT_REFINE_TABLE_CONTEXT_PROMPT)],
+    conditionals=default_refine_table_conditionals,
 )
diff --git a/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/__init__.py b/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/__init__.py
index 872737738c..6218333755 100644
--- a/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/__init__.py
+++ b/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/__init__.py
@@ -1,3 +1,19 @@
 from llama_index.llms.cohere.base import Cohere
+from llama_index.llms.cohere.utils import (
+    COHERE_QA_TEMPLATE,
+    COHERE_REFINE_TEMPLATE,
+    COHERE_TREE_SUMMARIZE_TEMPLATE,
+    COHERE_REFINE_TABLE_CONTEXT_PROMPT,
+    DocumentMessage,
+    is_cohere_model,
+)
 
-__all__ = ["Cohere"]
+__all__ = [
+    "COHERE_QA_TEMPLATE",
+    "COHERE_REFINE_TEMPLATE",
+    "COHERE_TREE_SUMMARIZE_TEMPLATE",
+    "COHERE_REFINE_TABLE_CONTEXT_PROMPT",
+    "DocumentMessage",
+    "is_cohere_model",
+    "Cohere",
+]
diff --git a/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/base.py b/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/base.py
index 84e4e11598..f22b70807f 100644
--- a/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/base.py
@@ -26,6 +26,7 @@ from llama_index.llms.cohere.utils import (
     cohere_modelname_to_contextsize,
     completion_with_retry,
     messages_to_cohere_history,
+    remove_documents_from_messages,
 )
 
 import cohere
@@ -47,7 +48,9 @@ class Cohere(LLM):
     """
 
     model: str = Field(description="The cohere model to use.")
-    temperature: float = Field(description="The temperature to use for sampling.")
+    temperature: Optional[float] = Field(
+        description="The temperature to use for sampling.", default=None
+    )
     max_retries: int = Field(
         default=10, description="The maximum number of API retries."
     )
@@ -61,9 +64,9 @@ class Cohere(LLM):
 
     def __init__(
         self,
-        model: str = "command",
-        temperature: float = 0.5,
-        max_tokens: int = 512,
+        model: str = "command-r",
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = 8192,
         timeout: Optional[float] = None,
         max_retries: int = 10,
         api_key: Optional[str] = None,
@@ -130,8 +133,10 @@ class Cohere(LLM):
 
     @llm_chat_callback()
     def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
-        history = messages_to_cohere_history(messages[:-1])
         prompt = messages[-1].content
+        remaining, documents = remove_documents_from_messages(messages[:-1])
+        history = messages_to_cohere_history(remaining)
+
         all_kwargs = self._get_all_kwargs(**kwargs)
         if all_kwargs["model"] not in CHAT_MODELS:
             raise ValueError(f"{all_kwargs['model']} not supported for chat")
@@ -147,6 +152,7 @@ class Cohere(LLM):
             chat=True,
             message=prompt,
             chat_history=history,
+            documents=documents,
             **all_kwargs,
         )
         return ChatResponse(
@@ -182,8 +188,10 @@ class Cohere(LLM):
     def stream_chat(
         self, messages: Sequence[ChatMessage], **kwargs: Any
     ) -> ChatResponseGen:
-        history = messages_to_cohere_history(messages[:-1])
         prompt = messages[-1].content
+        remaining, documents = remove_documents_from_messages(messages[:-1])
+        history = messages_to_cohere_history(remaining)
+
         all_kwargs = self._get_all_kwargs(**kwargs)
         all_kwargs["stream"] = True
         if all_kwargs["model"] not in CHAT_MODELS:
@@ -194,6 +202,7 @@ class Cohere(LLM):
             chat=True,
             message=prompt,
             chat_history=history,
+            documents=documents,
             **all_kwargs,
         )
 
diff --git a/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/utils.py b/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/utils.py
index a95b7a5b3f..f85428059b 100644
--- a/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/utils.py
@@ -1,7 +1,12 @@
+from collections import Counter
 import logging
-from typing import Any, Callable, Dict, List, Optional, Sequence
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
+import re
 
-from llama_index.core.base.llms.types import ChatMessage
+from llama_index.core.base.llms.base import BaseLLM
+from llama_index.core.prompts import ChatPromptTemplate, ChatMessage, MessageRole
+
+from llama_index.core.prompts.chat_prompts import TEXT_QA_SYSTEM_PROMPT
 from tenacity import (
     before_sleep_log,
     retry,
@@ -30,9 +35,103 @@ REPRESENTATION_MODELS = {
 ALL_AVAILABLE_MODELS = {**COMMAND_MODELS, **GENERATION_MODELS, **REPRESENTATION_MODELS}
 CHAT_MODELS = {**COMMAND_MODELS}
 
+
+def is_cohere_model(llm: BaseLLM) -> bool:
+    from llama_index.llms.cohere import Cohere
+
+    return isinstance(llm, Cohere)
+
+
 logger = logging.getLogger(__name__)
 
+# Cohere models accept a special `documents` argument for RAG calls. To pass retrieved documents to
+# that argument as intended by the Cohere team, we define:
+# 1. A new ChatMessage subclass, DocumentMessage
+# 2. Specialised prompt templates for Text QA, Refine, Tree Summarize, and Refine Table that leverage DocumentMessage
+# These templates are applied by default when a retriever is called with a Cohere LLM, via custom logic in default_prompt_selectors.py.
+# See Cohere.chat for details on how the templates are unpacked.
+
+
+class DocumentMessage(ChatMessage):
+    role: MessageRole = MessageRole.SYSTEM
+
+
+# Define new templates with DocumentMessages to leverage Cohere's `documents` argument.
+# Text QA
+COHERE_QA_TEMPLATE = ChatPromptTemplate(
+    message_templates=[
+        TEXT_QA_SYSTEM_PROMPT,
+        DocumentMessage(content="{context_str}"),
+        ChatMessage(content="{query_str}", role=MessageRole.USER),
+    ]
+)
+# Refine (based on llama_index.core.chat_prompts::CHAT_REFINE_PROMPT)
+REFINE_SYSTEM_PROMPT = ChatMessage(
+    content=(
+        "You are an expert Q&A system that strictly operates in two modes "
+        "when refining existing answers:\n"
+        "1. **Rewrite** an original answer using the new context.\n"
+        "2. **Repeat** the original answer if the new context isn't useful.\n"
+        "Never reference the original answer or context directly in your answer.\n"
+        "When in doubt, just repeat the original answer.\n"
+    ),
+    role=MessageRole.USER,
+)
+COHERE_REFINE_TEMPLATE = ChatPromptTemplate(
+    message_templates=[
+        REFINE_SYSTEM_PROMPT,
+        DocumentMessage(content="{context_msg}"),
+        ChatMessage(
+            content=(
+                "Query: {query_str}\n"
+                "Original Answer: {existing_answer}\n"
+                "New Answer: "
+            ),
+            role=MessageRole.USER,
+        ),
+    ]
+)
+# Tree Summarize (based on llama_index.core.chat_prompts::CHAT_TREE_SUMMARIZE_PROMPT)
+COHERE_TREE_SUMMARIZE_TEMPLATE = ChatPromptTemplate(
+    message_templates=[
+        TEXT_QA_SYSTEM_PROMPT,
+        DocumentMessage(content="{context_str}"),
+        ChatMessage(
+            content=(
+                "Given the information from multiple sources and not prior knowledge, "
+                "answer the query.\n"
+                "Query: {query_str}\n"
+                "Answer: "
+            ),
+            role=MessageRole.USER,
+        ),
+    ]
+)
+# Table Context Refine (based on llama_index.core.chat_prompts::CHAT_REFINE_TABLE_CONTEXT_PROMPT)
+COHERE_REFINE_TABLE_CONTEXT_PROMPT = ChatPromptTemplate(
+    message_templates=[
+        ChatMessage(content="{query_str}", role=MessageRole.USER),
+        ChatMessage(content="{existing_answer}", role=MessageRole.ASSISTANT),
+        DocumentMessage(content="{context_msg}"),
+        ChatMessage(
+            content=(
+                "We have provided a table schema below. "
+                "---------------------\n"
+                "{schema}\n"
+                "---------------------\n"
+                "We have also provided some context information. "
+                "Given the context information and the table schema, "
+                "refine the original answer to better "
+                "answer the question. "
+                "If the context isn't useful, return the original answer."
+            ),
+            role=MessageRole.USER,
+        ),
+    ]
+)
+
+
 def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]:
     min_seconds = 4
     max_seconds = 10
@@ -119,6 +218,26 @@ def is_chat_model(model: str) -> bool:
     return model in COMMAND_MODELS
 
 
+def remove_documents_from_messages(
+    messages: Sequence[ChatMessage],
+) -> Tuple[Sequence[ChatMessage], Optional[List[Dict[str, str]]]]:
+    """
+    Splits messages into two lists: `remaining` and `documents`.
+    `remaining` contains all messages that aren't of type DocumentMessage (e.g. history, query).
+    `documents` contains the retrieved documents, formatted as expected by Cohere RAG inference calls.
+
+    NOTE: this will mix turns for multi-turn RAG.
+    """
+    documents = []
+    remaining = []
+    for msg in messages:
+        if isinstance(msg, DocumentMessage):
+            documents.append(msg)
+        else:
+            remaining.append(msg)
+    return remaining, messages_to_cohere_documents(documents)
+
+
 def messages_to_cohere_history(
     messages: Sequence[ChatMessage],
 ) -> List[Dict[str, Optional[str]]]:
@@ -135,3 +254,61 @@ def messages_to_cohere_history(
         {"role": role_map[message.role], "message": message.content}
         for message in messages
     ]
+
+
+def messages_to_cohere_documents(
+    messages: List[DocumentMessage],
+) -> Optional[List[Dict[str, str]]]:
+    """
+    Splits out individual documents from `messages`, in the format expected by Cohere.chat's `documents` argument.
+    Returns None if `messages` is an empty list, for compatibility with co.chat (where `documents` defaults to None).
+    """
+    if messages == []:
+        return None
+    documents = []
+    for msg in messages:
+        documents.extend(document_message_to_cohere_document(msg))
+    return documents
+
+
+def document_message_to_cohere_document(message: DocumentMessage) -> Dict:
+    # By construction, a single DocumentMessage contains all retrieved documents.
+    documents: List[Dict[str, str]] = []
+    # Capture all key: value pairs. They will be unpacked into separate documents, Cohere-style.
+    re_known_keywords = re.compile(
+        r"(file_path|file_name|file_type|file_size|creation_date|last_modified_date|last_accessed_date): (.+)\n+"
+    )
+    # Find the most frequent field. We assume that the most frequent field marks the boundary
+    # between consecutive documents, and break ties by taking whichever field appears first.
+    known_keywords = re.findall(re_known_keywords, message.content)
+    if len(known_keywords) == 0:
+        # The message doesn't contain the expected special fields. Return default formatting.
+        return [{"text": message.content}]
+
+    fields_counts = Counter([key for key, _ in known_keywords])
+    most_frequent = fields_counts.most_common()[0][0]
+
+    # Initialise
+    document = None
+    remaining_text = message.content
+    for key, value in known_keywords:
+        if key == most_frequent:
+            # Save the current document after extracting its text, then re-init `document` for the next one.
+            if document:  # skip the first iteration, where document is None
+                # Keep the text up to the next occurrence of the boundary field, then move on.
+                index = remaining_text.find(key)
+                document["text"] = remaining_text[:index].strip()
+                documents.append(document)
+            document = {}
+
+        # Catch all special fields. Convert them to key: value pairs.
+        document[key] = value
+        # Keep only the text that follows the (first instance of) the current `value`.
+        remaining_text = remaining_text[
+            remaining_text.find(value) + len(value) :
+        ].strip()
+
+    # Append the last document still under construction.
+    document["text"] = remaining_text.strip()
+    documents.append(document)
+    return documents
diff --git a/llama-index-integrations/llms/llama-index-llms-cohere/tests/test_llms_cohere.py b/llama-index-integrations/llms/llama-index-llms-cohere/tests/test_llms_cohere.py
index f18b9558c3..8097adae62 100644
--- a/llama-index-integrations/llms/llama-index-llms-cohere/tests/test_llms_cohere.py
+++ b/llama-index-integrations/llms/llama-index-llms-cohere/tests/test_llms_cohere.py
@@ -1,7 +1,83 @@
+from typing import Sequence, Optional, List
+from unittest import mock
+
+import pytest
+from cohere import NonStreamedChatResponse
+
 from llama_index.core.base.llms.base import BaseLLM
-from llama_index.llms.cohere import Cohere
+from llama_index.core.base.llms.types import ChatResponse, ChatMessage, MessageRole
+from llama_index.core.llms.mock import MockLLM
+
+from llama_index.llms.cohere import Cohere, DocumentMessage, is_cohere_model
+
+
+def test_is_cohere():
+    assert is_cohere_model(Cohere(api_key="mario"))
+    assert not is_cohere_model(MockLLM())
 
 
 def test_embedding_class():
     names_of_base_classes = [b.__name__ for b in Cohere.__mro__]
     assert BaseLLM.__name__ in names_of_base_classes
+
+
+@pytest.mark.parametrize(
+    "messages,expected_chat_history,expected_documents,expected_message",  # noqa: PT006
+    [
+        pytest.param(
+            [ChatMessage(content="Hello", role=MessageRole.USER)],
+            [],
+            None,
+            "Hello",
+            id="single user message",
+        ),
+        pytest.param(
+            [
+                ChatMessage(content="Earliest message", role=MessageRole.USER),
+                ChatMessage(content="Latest message", role=MessageRole.USER),
+            ],
+            [{"message": "Earliest message", "role": "USER"}],
+            None,
+            "Latest message",
+            id="messages with chat history",
+        ),
+        pytest.param(
+            [
+                ChatMessage(content="Earliest message", role=MessageRole.USER),
+                DocumentMessage(content="Document content"),
+                ChatMessage(content="Latest message", role=MessageRole.USER),
+            ],
+            [{"message": "Earliest message", "role": "USER"}],
+            [{"text": "Document content"}],
+            "Latest message",
+            id="messages with chat history and documents",
+        ),
+    ],
+)
+def test_chat(
+    messages: Sequence[ChatMessage],
+    expected_chat_history: Optional[List],
+    expected_documents: Optional[List],
+    expected_message: str,
+):
+    # Mock the API client.
+    with mock.patch("llama_index.llms.cohere.base.cohere.Client", autospec=True):
+        llm = Cohere(api_key="dummy", temperature=0.3)
+        # Mock the API response.
+        llm._client.chat.return_value = NonStreamedChatResponse(text="Placeholder reply")
+        expected = ChatResponse(
+            message=ChatMessage(role=MessageRole.ASSISTANT, content="Placeholder reply"),
+            raw=llm._client.chat.return_value.__dict__,
+        )
+
+        actual = llm.chat(messages)
+
+        assert expected == actual
+        # Assert that the mocked API client was called in the expected way.
+        llm._client.chat.assert_called_once_with(
+            chat_history=expected_chat_history,
+            documents=expected_documents,
+            message=expected_message,
+            model="command-r",
+            temperature=0.3,
+        )
diff --git a/llama-index-integrations/llms/llama-index-llms-cohere/tests/test_rag_inference.py b/llama-index-integrations/llms/llama-index-llms-cohere/tests/test_rag_inference.py
new file mode 100644
index 0000000000..2dbdaf040d
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-cohere/tests/test_rag_inference.py
@@ -0,0 +1,60 @@
+import pytest
+
+from llama_index.core.prompts import MessageRole
+from llama_index.llms.cohere import DocumentMessage
+from llama_index.llms.cohere.utils import document_message_to_cohere_document
+
+text1 = "birds flying high"
+text2 = "sun in the sky"
+text3 = "breeze driftin' on by"
+text4 = "fish in the sea"
+text5 = "river running free"
+texts = [text1, text2, text3, text4, text5]
+
+
+@pytest.mark.parametrize(
+    "message, expected",  # noqa: PT006
+    [
+        pytest.param(
+            DocumentMessage(
+                role=MessageRole.USER,
+                content="\n\n".join(
+                    [f"file_path: nina.txt\n\n{text}" for text in texts]
+                ),
+                additional_kwargs={},
+            ),
+            [{"file_path": "nina.txt", "text": text} for text in texts],
+            id="single field, multiple documents",
+        ),
+        pytest.param(
+            DocumentMessage(
+                role=MessageRole.USER,
+                content="\n\n".join(
+                    [
+                        f"file_path: nina.txt\n\nfile_name: greatest-hits\n\n{text}"
+                        for text in texts
+                    ]
+                ),
+                additional_kwargs={},
+            ),
+            [
+                {"file_path": "nina.txt", "file_name": "greatest-hits", "text": text}
+                for text in texts
+            ],
+            id="multiple fields (same count), multiple documents",
+        ),
+        pytest.param(
+            DocumentMessage(
+                role=MessageRole.USER,
+                content="\n\n".join(texts),
+                additional_kwargs={},
+            ),
+            [{"text": "\n\n".join(texts)}],
+            id="no fields (just text), multiple documents",
+        ),
+    ],
+)
+def test_document_message_to_cohere_document(message, expected):
+    res = document_message_to_cohere_document(message)
+    print(res)
+    assert res == expected
-- 
GitLab
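
Usage note (illustrative, not part of the patch): the sketch below shows how the new `documents` path added above might be exercised directly. It assumes a valid Cohere API key; the document text and query are placeholders, while the model name, classes, and methods come from the files changed in this diff.

    from llama_index.core.base.llms.types import ChatMessage, MessageRole
    from llama_index.llms.cohere import Cohere, DocumentMessage

    llm = Cohere(api_key="<COHERE_API_KEY>", model="command-r")

    messages = [
        # Retrieved context travels in a DocumentMessage. Cohere.chat strips it from the
        # chat history and forwards it to co.chat via the `documents` argument.
        DocumentMessage(content="file_path: nina.txt\n\nbirds flying high"),
        ChatMessage(role=MessageRole.USER, content="According to the file, what is flying high?"),
    ]

    response = llm.chat(messages)
    print(response.message.content)

When the history contains no DocumentMessage, `remove_documents_from_messages` returns None for the documents, so plain chat calls behave as before.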