Unverified Commit 362063cd authored by Antoine Debugne, committed by GitHub

Call Cohere RAG inference with `documents` argument (#13196)

parent dd18a978
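For orientation, here is a minimal end-to-end sketch of the behaviour this commit adds, assuming the `llama-index-llms-cohere` package with this change is installed and "YOUR_API_KEY" stands in for a real key: retrieved context travels in a DocumentMessage, and Cohere.chat forwards it to Cohere's `documents` argument instead of inlining it into the prompt.

from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.llms.cohere import Cohere, DocumentMessage

llm = Cohere(api_key="YOUR_API_KEY", model="command-r")

messages = [
    # Retrieved context is wrapped in a DocumentMessage; Cohere.chat pulls it out
    # of the history and passes it to the Cohere chat API as `documents`.
    DocumentMessage(content="file_path: nina.txt\n\nbirds flying high"),
    ChatMessage(content="What is flying high?", role=MessageRole.USER),
]
response = llm.chat(messages)
print(response.message.content)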
@@ -14,22 +14,73 @@ from llama_index.core.prompts.default_prompts import (
)
from llama_index.core.prompts.utils import is_chat_model
try:
from llama_index.llms.cohere import (
is_cohere_model,
COHERE_QA_TEMPLATE,
COHERE_REFINE_TEMPLATE,
COHERE_TREE_SUMMARIZE_TEMPLATE,
COHERE_REFINE_TABLE_CONTEXT_PROMPT,
)
except ImportError:
COHERE_QA_TEMPLATE = None
COHERE_REFINE_TEMPLATE = None
COHERE_TREE_SUMMARIZE_TEMPLATE = None
COHERE_REFINE_TABLE_CONTEXT_PROMPT = None
# Define prompt selectors for Text QA, Tree Summarize, Refine, and Refine Table.
# Note: Cohere models accept a special `documents` argument for RAG calls. To pass retrieved documents on to the `documents` argument,
# specialised templates have been defined. The conditionals below ensure that these templates are used by default when a retriever
# is called with a Cohere model as the generator.
# Text QA
default_text_qa_conditionals = [(is_chat_model, CHAT_TEXT_QA_PROMPT)]
if COHERE_QA_TEMPLATE is not None:
default_text_qa_conditionals = [
(is_cohere_model, COHERE_QA_TEMPLATE),
(is_chat_model, CHAT_TEXT_QA_PROMPT),
]
DEFAULT_TEXT_QA_PROMPT_SEL = SelectorPromptTemplate(
default_template=DEFAULT_TEXT_QA_PROMPT,
conditionals=[(is_chat_model, CHAT_TEXT_QA_PROMPT)],
conditionals=default_text_qa_conditionals,
)
# Tree Summarize
default_tree_summarize_conditionals = [(is_chat_model, CHAT_TREE_SUMMARIZE_PROMPT)]
if COHERE_TREE_SUMMARIZE_TEMPLATE is not None:
default_tree_summarize_conditionals = [
(is_cohere_model, COHERE_TREE_SUMMARIZE_TEMPLATE),
(is_chat_model, CHAT_TREE_SUMMARIZE_PROMPT),
]
DEFAULT_TREE_SUMMARIZE_PROMPT_SEL = SelectorPromptTemplate(
default_template=DEFAULT_TREE_SUMMARIZE_PROMPT,
conditionals=[(is_chat_model, CHAT_TREE_SUMMARIZE_PROMPT)],
conditionals=default_tree_summarize_conditionals,
)
# Refine
default_refine_conditionals = [(is_chat_model, CHAT_REFINE_PROMPT)]
if COHERE_REFINE_TEMPLATE is not None:
default_refine_conditionals = [
(is_cohere_model, COHERE_REFINE_TEMPLATE),
(is_chat_model, CHAT_REFINE_PROMPT),
]
DEFAULT_REFINE_PROMPT_SEL = SelectorPromptTemplate(
default_template=DEFAULT_REFINE_PROMPT,
conditionals=[(is_chat_model, CHAT_REFINE_PROMPT)],
conditionals=default_refine_conditionals,
)
# Refine Table Context
default_refine_table_conditionals = [(is_chat_model, CHAT_REFINE_TABLE_CONTEXT_PROMPT)]
if COHERE_REFINE_TABLE_CONTEXT_PROMPT is not None:
default_refine_table_conditionals = [
(is_cohere_model, COHERE_REFINE_TABLE_CONTEXT_PROMPT),
(is_chat_model, CHAT_REFINE_TABLE_CONTEXT_PROMPT),
]
DEFAULT_REFINE_TABLE_CONTEXT_PROMPT_SEL = SelectorPromptTemplate(
default_template=DEFAULT_REFINE_TABLE_CONTEXT_PROMPT,
conditionals=[(is_chat_model, CHAT_REFINE_TABLE_CONTEXT_PROMPT)],
conditionals=default_refine_table_conditionals,
)
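To illustrate how the conditionals above resolve, a small sketch (assuming the `llama-index-llms-cohere` package is importable; `SelectorPromptTemplate.select` returns the first conditional template whose predicate matches the given LLM, else the default):

from llama_index.core.prompts.default_prompt_selectors import DEFAULT_TEXT_QA_PROMPT_SEL
from llama_index.llms.cohere import Cohere

cohere_llm = Cohere(api_key="YOUR_API_KEY", model="command-r")

# For a Cohere LLM, is_cohere_model matches first, so the Cohere-specific QA
# template is returned and retrieved context later ends up in a DocumentMessage
# rather than being inlined into the prompt text. Non-Cohere chat models still
# get CHAT_TEXT_QA_PROMPT, and everything else falls back to DEFAULT_TEXT_QA_PROMPT.
qa_template = DEFAULT_TEXT_QA_PROMPT_SEL.select(llm=cohere_llm)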
from llama_index.llms.cohere.base import Cohere
from llama_index.llms.cohere.utils import (
COHERE_QA_TEMPLATE,
COHERE_REFINE_TEMPLATE,
COHERE_TREE_SUMMARIZE_TEMPLATE,
COHERE_REFINE_TABLE_CONTEXT_PROMPT,
DocumentMessage,
is_cohere_model,
)
__all__ = ["Cohere"]
__all__ = [
"COHERE_QA_TEMPLATE",
"COHERE_REFINE_TEMPLATE",
"COHERE_TREE_SUMMARIZE_TEMPLATE",
"COHERE_REFINE_TABLE_CONTEXT_PROMPT",
"DocumentMessage",
"is_cohere_model",
"Cohere",
]
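With the expanded `__all__`, the RAG helpers can be imported directly from the package root, for example:

from llama_index.llms.cohere import (
    COHERE_QA_TEMPLATE,
    DocumentMessage,
    is_cohere_model,
)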
@@ -26,6 +26,7 @@ from llama_index.llms.cohere.utils import (
cohere_modelname_to_contextsize,
completion_with_retry,
messages_to_cohere_history,
remove_documents_from_messages,
)
import cohere
@@ -47,7 +48,9 @@ class Cohere(LLM):
"""
model: str = Field(description="The cohere model to use.")
temperature: float = Field(description="The temperature to use for sampling.")
temperature: Optional[float] = Field(
description="The temperature to use for sampling.", default=None
)
max_retries: int = Field(
default=10, description="The maximum number of API retries."
)
@@ -61,9 +64,9 @@ class Cohere(LLM):
def __init__(
self,
model: str = "command",
temperature: float = 0.5,
max_tokens: int = 512,
model: str = "command-r",
temperature: Optional[float] = None,
max_tokens: Optional[int] = 8192,
timeout: Optional[float] = None,
max_retries: int = 10,
api_key: Optional[str] = None,
@@ -130,8 +133,10 @@ class Cohere(LLM):
@llm_chat_callback()
def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
history = messages_to_cohere_history(messages[:-1])
prompt = messages[-1].content
remaining, documents = remove_documents_from_messages(messages[:-1])
history = messages_to_cohere_history(remaining)
all_kwargs = self._get_all_kwargs(**kwargs)
if all_kwargs["model"] not in CHAT_MODELS:
raise ValueError(f"{all_kwargs['model']} not supported for chat")
@@ -147,6 +152,7 @@ class Cohere(LLM):
chat=True,
message=prompt,
chat_history=history,
documents=documents,
**all_kwargs,
)
return ChatResponse(
@@ -182,8 +188,10 @@ class Cohere(LLM):
def stream_chat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponseGen:
history = messages_to_cohere_history(messages[:-1])
prompt = messages[-1].content
remaining, documents = remove_documents_from_messages(messages[:-1])
history = messages_to_cohere_history(remaining)
all_kwargs = self._get_all_kwargs(**kwargs)
all_kwargs["stream"] = True
if all_kwargs["model"] not in CHAT_MODELS:
@@ -194,6 +202,7 @@ class Cohere(LLM):
chat=True,
message=prompt,
chat_history=history,
documents=documents,
**all_kwargs,
)
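A small sketch of the split that now happens at the top of `chat` and `stream_chat`, using the helpers defined in utils.py below:

from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.llms.cohere.utils import (
    DocumentMessage,
    messages_to_cohere_history,
    remove_documents_from_messages,
)

messages = [
    ChatMessage(content="Earlier turn", role=MessageRole.USER),
    DocumentMessage(content="Document content"),
    ChatMessage(content="Latest question", role=MessageRole.USER),
]

# Only messages[:-1] count as history; the last message becomes the prompt.
remaining, documents = remove_documents_from_messages(messages[:-1])
history = messages_to_cohere_history(remaining)
# history   -> [{"role": "USER", "message": "Earlier turn"}]
# documents -> [{"text": "Document content"}]
# The final message ("Latest question") is sent as `message=`, alongside
# chat_history=history and documents=documents, to the Cohere chat endpoint.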
from collections import Counter
import logging
from typing import Any, Callable, Dict, List, Optional, Sequence
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
import re
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.base.llms.base import BaseLLM
from llama_index.core.prompts import ChatPromptTemplate, ChatMessage, MessageRole
from llama_index.core.prompts.chat_prompts import TEXT_QA_SYSTEM_PROMPT
from tenacity import (
before_sleep_log,
retry,
@@ -30,9 +35,103 @@ REPRESENTATION_MODELS = {
ALL_AVAILABLE_MODELS = {**COMMAND_MODELS, **GENERATION_MODELS, **REPRESENTATION_MODELS}
CHAT_MODELS = {**COMMAND_MODELS}
def is_cohere_model(llm: BaseLLM) -> bool:
from llama_index.llms.cohere import Cohere
return isinstance(llm, Cohere)
logger = logging.getLogger(__name__)
# Cohere models accept a special `documents` argument for RAG calls. To pass retrieved documents on to the `documents` argument
# as intended by the Cohere team, we define:
# 1. A new ChatMessage subclass, called DocumentMessage
# 2. Specialised prompt templates for Text QA, Refine, Tree Summarize, and Refine Table that leverage DocumentMessage
# These templates are applied by default when a retriever is called with a Cohere LLM, via custom logic inside default_prompt_selectors.py.
# See Cohere.chat for details on how the templates are unpacked.
class DocumentMessage(ChatMessage):
role: MessageRole = MessageRole.SYSTEM
# Define new templates that use DocumentMessage to leverage Cohere's `documents` argument
# Text QA
COHERE_QA_TEMPLATE = ChatPromptTemplate(
message_templates=[
TEXT_QA_SYSTEM_PROMPT,
DocumentMessage(content="{context_str}"),
ChatMessage(content="{query_str}", role=MessageRole.USER),
]
)
# Refine (based on llama_index.core.prompts.chat_prompts::CHAT_REFINE_PROMPT)
REFINE_SYSTEM_PROMPT = ChatMessage(
content=(
"You are an expert Q&A system that strictly operates in two modes "
"when refining existing answers:\n"
"1. **Rewrite** an original answer using the new context.\n"
"2. **Repeat** the original answer if the new context isn't useful.\n"
"Never reference the original answer or context directly in your answer.\n"
"When in doubt, just repeat the original answer.\n"
),
role=MessageRole.USER,
)
COHERE_REFINE_TEMPLATE = ChatPromptTemplate(
message_templates=[
REFINE_SYSTEM_PROMPT,
DocumentMessage(content="{context_msg}"),
ChatMessage(
content=(
"Query: {query_str}\n"
"Original Answer: {existing_answer}\n"
"New Answer: "
),
role=MessageRole.USER,
),
]
)
# Tree summarize (based on llama_index.core.prompts.chat_prompts::CHAT_TREE_SUMMARIZE_PROMPT)
COHERE_TREE_SUMMARIZE_TEMPLATE = ChatPromptTemplate(
message_templates=[
TEXT_QA_SYSTEM_PROMPT,
DocumentMessage(content="{context_str}"),
ChatMessage(
content=(
"Given the information from multiple sources and not prior knowledge, "
"answer the query.\n"
"Query: {query_str}\n"
"Answer: "
),
role=MessageRole.USER,
),
]
)
# Table context refine (based on llama_index.core.prompts.chat_prompts::CHAT_REFINE_TABLE_CONTEXT_PROMPT)
COHERE_REFINE_TABLE_CONTEXT_PROMPT = ChatPromptTemplate(
message_templates=[
ChatMessage(content="{query_str}", role=MessageRole.USER),
ChatMessage(content="{existing_answer}", role=MessageRole.ASSISTANT),
DocumentMessage(content="{context_msg}"),
ChatMessage(
content=(
"We have provided a table schema below. "
"---------------------\n"
"{schema}\n"
"---------------------\n"
"We have also provided some context information. "
"Given the context information and the table schema, "
"refine the original answer to better "
"answer the question. "
"If the context isn't useful, return the original answer."
),
role=MessageRole.USER,
),
]
)
def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]:
min_seconds = 4
max_seconds = 10
@@ -119,6 +218,26 @@ def is_chat_model(model: str) -> bool:
return model in COMMAND_MODELS
def remove_documents_from_messages(
messages: Sequence[ChatMessage],
) -> Tuple[Sequence[ChatMessage], Optional[List[Dict[str, str]]]]:
"""
Splits messages into two lists: `remaining` and `documents`.
`remaining` contains all messages that aren't of type DocumentMessage (e.g. history, query).
`documents` contains the retrieved documents, formatted as expected by Cohere RAG inference calls.
NOTE: because all DocumentMessages are collected into a single list, documents from different turns are mixed together in multi-turn RAG.
"""
documents = []
remaining = []
for msg in messages:
if isinstance(msg, DocumentMessage):
documents.append(msg)
else:
remaining.append(msg)
return remaining, messages_to_cohere_documents(documents)
def messages_to_cohere_history(
messages: Sequence[ChatMessage],
) -> List[Dict[str, Optional[str]]]:
@@ -135,3 +254,61 @@ def messages_to_cohere_history(
{"role": role_map[message.role], "message": message.content}
for message in messages
]
def messages_to_cohere_documents(
messages: List[DocumentMessage],
) -> Optional[List[Dict[str, str]]]:
"""
Splits out individual documents from `messages` in the format expected by Cohere.chat's `documents`.
Returns None if `messages` is an empty list for compatibility with co.chat (where `documents` is None by default).
"""
if messages == []:
return None
documents = []
for msg in messages:
documents.extend(document_message_to_cohere_document(msg))
return documents
def document_message_to_cohere_document(message: DocumentMessage) -> Dict:
# By construction, a single DocumentMessage contains all retrieved documents
documents: List[Dict[str, str]] = []
# Capture all key: value pairs. They will be unpacked into separate documents, Cohere-style.
re_known_keywords = re.compile(
r"(file_path|file_name|file_type|file_size|creation_date|last_modified_date|last_accessed_date): (.+)\n+"
)
# Find most frequent field. We assume that the most frequent field denotes the boundary
# between consecutive documents, and break ties by taking whichever field appears first.
known_keywords = re.findall(re_known_keywords, message.content)
if len(known_keywords) == 0:
# Document doesn't contain expected special fields. Return default formatting.
return [{"text": message.content}]
fields_counts = Counter([key for key, _ in known_keywords])
most_frequent = fields_counts.most_common()[0][0]
# Initialise
document = None
remaining_text = message.content
for key, value in known_keywords:
if key == most_frequent:
# Save current document after extracting text, then reinit `document` to move to next document
if document: # skip first iteration, where document is None
# Extract the text up to the next occurrence of the most frequent field, then move on to the next document
index = remaining_text.find(key)
document["text"] = remaining_text[:index].strip()
documents.append(document)
document = {}
# Catch all special fields. Convert them to key: value pairs.
document[key] = value
# Keep only the text that follows the (first instance of) the current `value`
remaining_text = remaining_text[
remaining_text.find(value) + len(value) :
].strip()
# Append the last document still under construction
document["text"] = remaining_text.strip()
documents.append(document)
return documents
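As a worked example of the parsing above (the same behaviour the tests further down exercise), a node whose text carries `file_path` metadata headers is split on the most frequent field:

from llama_index.llms.cohere.utils import (
    DocumentMessage,
    document_message_to_cohere_document,
)

msg = DocumentMessage(
    content=(
        "file_path: a.txt\n\nbirds flying high\n\n"
        "file_path: b.txt\n\nsun in the sky"
    )
)
print(document_message_to_cohere_document(msg))
# `file_path` is the most frequent field, so it marks the document boundaries:
# [{'file_path': 'a.txt', 'text': 'birds flying high'},
#  {'file_path': 'b.txt', 'text': 'sun in the sky'}]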
from typing import Sequence, Optional, List
from unittest import mock
import pytest
from cohere import NonStreamedChatResponse
from llama_index.core.base.llms.base import BaseLLM
from llama_index.llms.cohere import Cohere
from llama_index.core.base.llms.types import ChatResponse, ChatMessage, MessageRole
from llama_index.core.llms.mock import MockLLM
from llama_index.llms.cohere import Cohere, DocumentMessage, is_cohere_model
def test_is_cohere():
assert is_cohere_model(Cohere(api_key="mario"))
assert not is_cohere_model(MockLLM())
def test_embedding_class():
names_of_base_classes = [b.__name__ for b in Cohere.__mro__]
assert BaseLLM.__name__ in names_of_base_classes
@pytest.mark.parametrize(
"messages,expected_chat_history,expected_documents,expected_message", # noqa: PT006
[
pytest.param(
[ChatMessage(content="Hello", role=MessageRole.USER)],
[],
None,
"Hello",
id="single user message",
),
pytest.param(
[
ChatMessage(content="Earliest message", role=MessageRole.USER),
ChatMessage(content="Latest message", role=MessageRole.USER),
],
[{"message": "Earliest message", "role": "USER"}],
None,
"Latest message",
id="messages with chat history",
),
pytest.param(
[
ChatMessage(content="Earliest message", role=MessageRole.USER),
DocumentMessage(content="Document content"),
ChatMessage(content="Latest message", role=MessageRole.USER),
],
[{"message": "Earliest message", "role": "USER"}],
[{"text": "Document content"}],
"Latest message",
id="messages with chat history",
),
],
)
def test_chat(
messages: Sequence[ChatMessage],
expected_chat_history: Optional[List],
expected_documents: Optional[List],
expected_message: str,
):
# Mock the API client.
with mock.patch("llama_index.llms.cohere.base.cohere.Client", autospec=True):
llm = Cohere(api_key="dummy", temperature=0.3)
# Mock the API response.
llm._client.chat.return_value = NonStreamedChatResponse(text="Placeholder reply")
expected = ChatResponse(
message=ChatMessage(role=MessageRole.ASSISTANT, content="Placeholder reply"),
raw=llm._client.chat.return_value.__dict__,
)
actual = llm.chat(messages)
assert expected == actual
# Assert that the mocked API client was called in the expected way.
llm._client.chat.assert_called_once_with(
chat_history=expected_chat_history,
documents=expected_documents,
message=expected_message,
model="command-r",
temperature=0.3,
)
import pytest
from llama_index.core.prompts import MessageRole
from llama_index.llms.cohere import DocumentMessage
from llama_index.llms.cohere.utils import document_message_to_cohere_document
text1 = "birds flying high"
text2 = "sun in the sky"
text3 = "breeze driftin' on by"
text4 = "fish in the sea"
text5 = "river running free"
texts = [text1, text2, text3, text4, text5]
@pytest.mark.parametrize(
"message, expected", # noqa: PT006
[
pytest.param(
DocumentMessage(
role=MessageRole.USER,
content="\n\n".join(
[f"file_path: nina.txt\n\n{text}" for text in texts]
),
additional_kwargs={},
),
[{"file_path": "nina.txt", "text": text} for text in texts],
id="single field, multiple documents",
),
pytest.param(
DocumentMessage(
role=MessageRole.USER,
content="\n\n".join(
[
f"file_path: nina.txt\n\nfile_name: greatest-hits\n\n{text}"
for text in texts
]
),
additional_kwargs={},
),
[
{"file_path": "nina.txt", "file_name": "greatest-hits", "text": text}
for text in texts
],
id="multiple fields (same count), multiple documents",
),
pytest.param(
DocumentMessage(
role=MessageRole.USER,
content="\n\n".join(texts),
additional_kwargs={},
),
[{"text": "\n\n".join(texts)}],
id="no fields (just text), multiple documents",
),
],
)
def test_document_message_to_cohere_document(message, expected):
res = document_message_to_cohere_document(message)
print(res)
assert res == expected