diff --git a/llama-index-core/llama_index/core/prompts/default_prompt_selectors.py b/llama-index-core/llama_index/core/prompts/default_prompt_selectors.py
index 18e2866c7a7431d6d09d6ee29f4c101f281cb4df..23a9d11912ca96edfe7d9266428c376da32d23cb 100644
--- a/llama-index-core/llama_index/core/prompts/default_prompt_selectors.py
+++ b/llama-index-core/llama_index/core/prompts/default_prompt_selectors.py
@@ -14,22 +14,73 @@ from llama_index.core.prompts.default_prompts import (
 )
 from llama_index.core.prompts.utils import is_chat_model
 
+try:
+    from llama_index.llms.cohere import (
+        is_cohere_model,
+        COHERE_QA_TEMPLATE,
+        COHERE_REFINE_TEMPLATE,
+        COHERE_TREE_SUMMARIZE_TEMPLATE,
+        COHERE_REFINE_TABLE_CONTEXT_PROMPT,
+    )
+except ImportError:
+    COHERE_QA_TEMPLATE = None
+    COHERE_REFINE_TEMPLATE = None
+    COHERE_TREE_SUMMARIZE_TEMPLATE = None
+    COHERE_REFINE_TABLE_CONTEXT_PROMPT = None
+
+# Define prompt selectors for Text QA, Tree Summarize, Refine, and Refine Table.
+# Note: Cohere models accept a special argument `documents` for RAG calls. To pass on retrieved documents to the `documents` argument,
+# specialised templates have been defined. The conditionals below ensure that these templates are called by default when a retriever
+# is called with a Cohere model for generator.
+
+# Text QA
+default_text_qa_conditionals = [(is_chat_model, CHAT_TEXT_QA_PROMPT)]
+if COHERE_QA_TEMPLATE is not None:
+    default_text_qa_conditionals = [
+        (is_cohere_model, COHERE_QA_TEMPLATE),
+        (is_chat_model, CHAT_TEXT_QA_PROMPT),
+    ]
+
 DEFAULT_TEXT_QA_PROMPT_SEL = SelectorPromptTemplate(
     default_template=DEFAULT_TEXT_QA_PROMPT,
-    conditionals=[(is_chat_model, CHAT_TEXT_QA_PROMPT)],
+    conditionals=default_text_qa_conditionals,
 )
 
+# Tree Summarize
+default_tree_summarize_conditionals = [(is_chat_model, CHAT_TREE_SUMMARIZE_PROMPT)]
+if COHERE_TREE_SUMMARIZE_TEMPLATE is not None:
+    default_tree_summarize_conditionals = [
+        (is_cohere_model, COHERE_TREE_SUMMARIZE_TEMPLATE),
+        (is_chat_model, CHAT_TREE_SUMMARIZE_PROMPT),
+    ]
+
 DEFAULT_TREE_SUMMARIZE_PROMPT_SEL = SelectorPromptTemplate(
     default_template=DEFAULT_TREE_SUMMARIZE_PROMPT,
-    conditionals=[(is_chat_model, CHAT_TREE_SUMMARIZE_PROMPT)],
+    conditionals=default_tree_summarize_conditionals,
 )
 
+# Refine
+default_refine_conditionals = [(is_chat_model, CHAT_REFINE_PROMPT)]
+if COHERE_REFINE_TEMPLATE is not None:
+    default_refine_conditionals = [
+        (is_cohere_model, COHERE_REFINE_TEMPLATE),
+        (is_chat_model, CHAT_REFINE_PROMPT),
+    ]
+
 DEFAULT_REFINE_PROMPT_SEL = SelectorPromptTemplate(
     default_template=DEFAULT_REFINE_PROMPT,
-    conditionals=[(is_chat_model, CHAT_REFINE_PROMPT)],
+    conditionals=default_refine_conditionals,
 )
 
+# Refine Table Context
+default_refine_table_conditionals = [(is_chat_model, CHAT_REFINE_TABLE_CONTEXT_PROMPT)]
+if COHERE_REFINE_TABLE_CONTEXT_PROMPT is not None:
+    default_refine_table_conditionals = [
+        (is_cohere_model, COHERE_REFINE_TABLE_CONTEXT_PROMPT),
+        (is_chat_model, CHAT_REFINE_TABLE_CONTEXT_PROMPT),
+    ]
+
 DEFAULT_REFINE_TABLE_CONTEXT_PROMPT_SEL = SelectorPromptTemplate(
     default_template=DEFAULT_REFINE_TABLE_CONTEXT_PROMPT,
-    conditionals=[(is_chat_model, CHAT_REFINE_TABLE_CONTEXT_PROMPT)],
+    conditionals=default_refine_table_conditionals,
 )
diff --git a/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/__init__.py b/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/__init__.py
index 872737738c4609bb3a21122e0a6b46fba5106a53..6218333755b1640620808495c9dc59e55485ea3a 100644
--- a/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/__init__.py
+++ b/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/__init__.py
@@ -1,3 +1,19 @@
 from llama_index.llms.cohere.base import Cohere
+from llama_index.llms.cohere.utils import (
+    COHERE_QA_TEMPLATE,
+    COHERE_REFINE_TEMPLATE,
+    COHERE_TREE_SUMMARIZE_TEMPLATE,
+    COHERE_REFINE_TABLE_CONTEXT_PROMPT,
+    DocumentMessage,
+    is_cohere_model,
+)
 
-__all__ = ["Cohere"]
+__all__ = [
+    "COHERE_QA_TEMPLATE",
+    "COHERE_REFINE_TEMPLATE",
+    "COHERE_TREE_SUMMARIZE_TEMPLATE",
+    "COHERE_REFINE_TABLE_CONTEXT_PROMPT",
+    "DocumentMessage",
+    "is_cohere_model",
+    "Cohere",
+]
diff --git a/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/base.py b/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/base.py
index 84e4e11598b7283a3c003b843b279efc88770454..f22b70807ff406932e5507ca1dc61875f2d12570 100644
--- a/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/base.py
@@ -26,6 +26,7 @@ from llama_index.llms.cohere.utils import (
     cohere_modelname_to_contextsize,
     completion_with_retry,
     messages_to_cohere_history,
+    remove_documents_from_messages,
 )
 
 import cohere
@@ -47,7 +48,9 @@ class Cohere(LLM):
     """
 
     model: str = Field(description="The cohere model to use.")
-    temperature: float = Field(description="The temperature to use for sampling.")
+    temperature: float = Field(
+        description="The temperature to use for sampling.", default=None
+    )
     max_retries: int = Field(
         default=10, description="The maximum number of API retries."
     )
@@ -61,9 +64,9 @@ class Cohere(LLM):
 
     def __init__(
         self,
-        model: str = "command",
-        temperature: float = 0.5,
-        max_tokens: int = 512,
+        model: str = "command-r",
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = 8192,
         timeout: Optional[float] = None,
         max_retries: int = 10,
         api_key: Optional[str] = None,
@@ -130,8 +133,10 @@ class Cohere(LLM):
 
     @llm_chat_callback()
     def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
-        history = messages_to_cohere_history(messages[:-1])
         prompt = messages[-1].content
+        remaining, documents = remove_documents_from_messages(messages[:-1])
+        history = messages_to_cohere_history(remaining)
+
         all_kwargs = self._get_all_kwargs(**kwargs)
         if all_kwargs["model"] not in CHAT_MODELS:
             raise ValueError(f"{all_kwargs['model']} not supported for chat")
@@ -147,6 +152,7 @@ class Cohere(LLM):
             chat=True,
             message=prompt,
             chat_history=history,
+            documents=documents,
             **all_kwargs,
         )
         return ChatResponse(
@@ -182,8 +188,10 @@ class Cohere(LLM):
     def stream_chat(
         self, messages: Sequence[ChatMessage], **kwargs: Any
     ) -> ChatResponseGen:
-        history = messages_to_cohere_history(messages[:-1])
         prompt = messages[-1].content
+        remaining, documents = remove_documents_from_messages(messages[:-1])
+        history = messages_to_cohere_history(remaining)
+
         all_kwargs = self._get_all_kwargs(**kwargs)
         all_kwargs["stream"] = True
         if all_kwargs["model"] not in CHAT_MODELS:
@@ -194,6 +202,7 @@ class Cohere(LLM):
             chat=True,
             message=prompt,
             chat_history=history,
+            documents=documents,
             **all_kwargs,
         )
 
diff --git a/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/utils.py b/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/utils.py
index a95b7a5b3f101b8d13ba2f4dc25fd00b30b9aaf1..f85428059ba017c1217ff698ad7d1f30a12b8ad4 100644
--- a/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-cohere/llama_index/llms/cohere/utils.py
@@ -1,7 +1,12 @@
+from collections import Counter
 import logging
-from typing import Any, Callable, Dict, List, Optional, Sequence
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
+import re
 
-from llama_index.core.base.llms.types import ChatMessage
+from llama_index.core.base.llms.base import BaseLLM
+from llama_index.core.prompts import ChatPromptTemplate, ChatMessage, MessageRole
+
+from llama_index.core.prompts.chat_prompts import TEXT_QA_SYSTEM_PROMPT
 from tenacity import (
     before_sleep_log,
     retry,
@@ -30,9 +35,103 @@ REPRESENTATION_MODELS = {
 ALL_AVAILABLE_MODELS = {**COMMAND_MODELS, **GENERATION_MODELS, **REPRESENTATION_MODELS}
 CHAT_MODELS = {**COMMAND_MODELS}
 
+
+def is_cohere_model(llm: BaseLLM) -> bool:
+    from llama_index.llms.cohere import Cohere
+
+    return isinstance(llm, Cohere)
+
+
 logger = logging.getLogger(__name__)
 
 
+# Cohere models accept a special argument `documents` for RAG calls. To pass on retrieved documents to the `documents` argument
+# as intended by the Cohere team, we define:
+# 1. A new ChatMessage class, called DocumentMessage
+# 2. Specialised prompt templates for Text QA, Refine, Tree Summarize, and Refine Table that leverage DocumentMessage
+# These templates are applied by default when a retriever is called with a Cohere LLM via custom logic inside default_prompt_selectors.py.
+# See Cohere.chat for details on how the templates are unpackaged.
+
+
+class DocumentMessage(ChatMessage):
+    role: MessageRole = MessageRole.SYSTEM
+
+
+# Define new templates with DocumentMessage's to leverage Cohere's `documents` argument
+# Text QA
+COHERE_QA_TEMPLATE = ChatPromptTemplate(
+    message_templates=[
+        TEXT_QA_SYSTEM_PROMPT,
+        DocumentMessage(content="{context_str}"),
+        ChatMessage(content="{query_str}", role=MessageRole.USER),
+    ]
+)
+# Refine (based on llama_index.core.chat_prompts::CHAT_REFINE_PROMPT)
+REFINE_SYSTEM_PROMPT = ChatMessage(
+    content=(
+        "You are an expert Q&A system that strictly operates in two modes "
+        "when refining existing answers:\n"
+        "1. **Rewrite** an original answer using the new context.\n"
+        "2. **Repeat** the original answer if the new context isn't useful.\n"
+        "Never reference the original answer or context directly in your answer.\n"
+        "When in doubt, just repeat the original answer.\n"
+    ),
+    role=MessageRole.USER,
+)
+COHERE_REFINE_TEMPLATE = ChatPromptTemplate(
+    message_templates=[
+        REFINE_SYSTEM_PROMPT,
+        DocumentMessage(content="{context_msg}"),
+        ChatMessage(
+            content=(
+                "Query: {query_str}\n"
+                "Original Answer: {existing_answer}\n"
+                "New Answer: "
+            ),
+            role=MessageRole.USER,
+        ),
+    ]
+)
+# Tree summarize (based on llama_index.core.chat_prompts::CHAT_TREE_SUMMARIZE_PROMPT)
+COHERE_TREE_SUMMARIZE_TEMPLATE = ChatPromptTemplate(
+    message_templates=[
+        TEXT_QA_SYSTEM_PROMPT,
+        DocumentMessage(content="{context_str}"),
+        ChatMessage(
+            content=(
+                "Given the information from multiple sources and not prior knowledge, "
+                "answer the query.\n"
+                "Query: {query_str}\n"
+                "Answer: "
+            ),
+            role=MessageRole.USER,
+        ),
+    ]
+)
+# Table context refine (based on llama_index.core.chat_prompts::CHAT_REFINE_TABLE_CONTEXT_PROMPT)
+COHERE_REFINE_TABLE_CONTEXT_PROMPT = ChatPromptTemplate(
+    message_templates=[
+        ChatMessage(content="{query_str}", role=MessageRole.USER),
+        ChatMessage(content="{existing_answer}", role=MessageRole.ASSISTANT),
+        DocumentMessage(content="{context_msg}"),
+        ChatMessage(
+            content=(
+                "We have provided a table schema below. "
+                "---------------------\n"
+                "{schema}\n"
+                "---------------------\n"
+                "We have also provided some context information. "
+                "Given the context information and the table schema, "
+                "refine the original answer to better "
+                "answer the question. "
+                "If the context isn't useful, return the original answer."
+            ),
+            role=MessageRole.USER,
+        ),
+    ]
+)
+
+
 def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]:
     min_seconds = 4
     max_seconds = 10
@@ -119,6 +218,26 @@ def is_chat_model(model: str) -> bool:
     return model in COMMAND_MODELS
 
 
+def remove_documents_from_messages(
+    messages: Sequence[ChatMessage],
+) -> Tuple[Sequence[ChatMessage], Optional[List[Dict[str, str]]]]:
+    """
+    Splits messages into two lists: `remaining` and `documents`.
+    `remaining` contains all messages that aren't of type DocumentMessage (e.g. history, query).
+    `documents` contains the retrieved documents, formatted as expected by Cohere RAG inference calls.
+
+    NOTE: this will mix turns for multi-turn RAG
+    """
+    documents = []
+    remaining = []
+    for msg in messages:
+        if isinstance(msg, DocumentMessage):
+            documents.append(msg)
+        else:
+            remaining.append(msg)
+    return remaining, messages_to_cohere_documents(documents)
+
+
 def messages_to_cohere_history(
     messages: Sequence[ChatMessage],
 ) -> List[Dict[str, Optional[str]]]:
@@ -135,3 +254,61 @@ def messages_to_cohere_history(
         {"role": role_map[message.role], "message": message.content}
         for message in messages
     ]
+
+
+def messages_to_cohere_documents(
+    messages: List[DocumentMessage],
+) -> Optional[List[Dict[str, str]]]:
+    """
+    Splits out individual documents from `messages` in the format expected by Cohere.chat's `documents`.
+    Returns None if `messages` is an empty list for compatibility with co.chat (where `documents` is None by default).
+    """
+    if messages == []:
+        return None
+    documents = []
+    for msg in messages:
+        documents.extend(document_message_to_cohere_document(msg))
+    return documents
+
+
+def document_message_to_cohere_document(message: DocumentMessage) -> Dict:
+    # By construction, single DocumentMessage contains all retrieved documents
+    documents: List[Dict[str, str]] = []
+    # Capture all key: value pairs. They will be unpacked in separate documents, Cohere-style.
+    re_known_keywords = re.compile(
+        r"(file_path|file_name|file_type|file_size|creation_date|last_modified_data|last_accessed_date): (.+)\n+"
+    )
+    # Find most frequent field. We assume that the most frequent field denotes the boundary
+    # between consecutive documents, and break ties by taking whichever field appears first.
+    known_keywords = re.findall(re_known_keywords, message.content)
+    if len(known_keywords) == 0:
+        # Document doesn't contain expected special fields. Return default formatting.
+        return [{"text": message.content}]
+
+    fields_counts = Counter([key for key, _ in known_keywords])
+    most_frequent = fields_counts.most_common()[0][0]
+
+    # Initialise
+    document = None
+    remaining_text = message.content
+    for key, value in known_keywords:
+        if key == most_frequent:
+            # Save current document after extracting text, then reinit `document` to move to next document
+            if document:  # skip first iteration, where document is None
+                # Extract text up until the most_frequent remaining text, then skip to next document
+                index = remaining_text.find(key)
+                document["text"] = remaining_text[:index].strip()
+                documents.append(document)
+            document = {}
+
+        # Catch all special fields. Convert them to key: value pairs.
+        document[key] = value
+        # Store remaining text, behind the (first instance of) current `value`
+        remaining_text = remaining_text[
+            remaining_text.find(value) + len(value) :
+        ].strip()
+
+    # Append last document that's in construction
+    document["text"] = remaining_text.strip()
+    documents.append(document)
+    return documents
diff --git a/llama-index-integrations/llms/llama-index-llms-cohere/tests/test_llms_cohere.py b/llama-index-integrations/llms/llama-index-llms-cohere/tests/test_llms_cohere.py
index f18b9558c399bd58132764752fa2182cb79cc8e2..8097adae624e5926921bbfeb1716519ec8bee51a 100644
--- a/llama-index-integrations/llms/llama-index-llms-cohere/tests/test_llms_cohere.py
+++ b/llama-index-integrations/llms/llama-index-llms-cohere/tests/test_llms_cohere.py
@@ -1,7 +1,83 @@
+from typing import Sequence, Optional, List
+from unittest import mock
+
+import pytest
+from cohere import NonStreamedChatResponse
+
 from llama_index.core.base.llms.base import BaseLLM
-from llama_index.llms.cohere import Cohere
+from llama_index.core.base.llms.types import ChatResponse, ChatMessage, MessageRole
+from llama_index.core.llms.mock import MockLLM
+
+from llama_index.llms.cohere import Cohere, DocumentMessage, is_cohere_model
+
+
+def test_is_cohere():
+    assert is_cohere_model(Cohere(api_key="mario"))
+    assert not is_cohere_model(MockLLM())
 
 
 def test_embedding_class():
     names_of_base_classes = [b.__name__ for b in Cohere.__mro__]
     assert BaseLLM.__name__ in names_of_base_classes
+
+
+@pytest.mark.parametrize(
+    "messages,expected_chat_history,expected_documents,expected_message",  # noqa: PT006
+    [
+        pytest.param(
+            [ChatMessage(content="Hello", role=MessageRole.USER)],
+            [],
+            None,
+            "Hello",
+            id="single user message",
+        ),
+        pytest.param(
+            [
+                ChatMessage(content="Earliest message", role=MessageRole.USER),
+                ChatMessage(content="Latest message", role=MessageRole.USER),
+            ],
+            [{"message": "Earliest message", "role": "USER"}],
+            None,
+            "Latest message",
+            id="messages with chat history",
+        ),
+        pytest.param(
+            [
+                ChatMessage(content="Earliest message", role=MessageRole.USER),
+                DocumentMessage(content="Document content"),
+                ChatMessage(content="Latest message", role=MessageRole.USER),
+            ],
+            [{"message": "Earliest message", "role": "USER"}],
+            [{"text": "Document content"}],
+            "Latest message",
+            id="messages with chat history",
+        ),
+    ],
+)
+def test_chat(
+    messages: Sequence[ChatMessage],
+    expected_chat_history: Optional[List],
+    expected_documents: Optional[List],
+    expected_message: str,
+):
+    # Mock the API client.
+    with mock.patch("llama_index.llms.cohere.base.cohere.Client", autospec=True):
+        llm = Cohere(api_key="dummy", temperature=0.3)
+    # Mock the API response.
+    llm._client.chat.return_value = NonStreamedChatResponse(text="Placeholder reply")
+    expected = ChatResponse(
+        message=ChatMessage(role=MessageRole.ASSISTANT, content="Placeholder reply"),
+        raw=llm._client.chat.return_value.__dict__,
+    )
+
+    actual = llm.chat(messages)
+
+    assert expected == actual
+    # Assert that the mocked API client was called in the expected way.
+    llm._client.chat.assert_called_once_with(
+        chat_history=expected_chat_history,
+        documents=expected_documents,
+        message=expected_message,
+        model="command-r",
+        temperature=0.3,
+    )
diff --git a/llama-index-integrations/llms/llama-index-llms-cohere/tests/test_rag_inference.py b/llama-index-integrations/llms/llama-index-llms-cohere/tests/test_rag_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dbdaf040d27bb4bd29af489d108386c0a539203
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-cohere/tests/test_rag_inference.py
@@ -0,0 +1,60 @@
+import pytest
+
+from llama_index.core.prompts import MessageRole
+from llama_index.llms.cohere import DocumentMessage
+from llama_index.llms.cohere.utils import document_message_to_cohere_document
+
+text1 = "birds flying high"
+text2 = "sun in the sky"
+text3 = "breeze driftin' on by"
+text4 = "fish in the sea"
+text5 = "river running free"
+texts = [text1, text2, text3, text4, text5]
+
+
+@pytest.mark.parametrize(
+    "message, expected",  # noqa: PT006
+    [
+        pytest.param(
+            DocumentMessage(
+                role=MessageRole.USER,
+                content="\n\n".join(
+                    [f"file_path: nina.txt\n\n{text}" for text in texts]
+                ),
+                additional_kwargs={},
+            ),
+            [{"file_path": "nina.txt", "text": text} for text in texts],
+            id="single field, multiple documents",
+        ),
+        pytest.param(
+            DocumentMessage(
+                role=MessageRole.USER,
+                content="\n\n".join(
+                    [
+                        f"file_path: nina.txt\n\nfile_name: greatest-hits\n\n{text}"
+                        for text in texts
+                    ]
+                ),
+                additional_kwargs={},
+            ),
+            [
+                {"file_path": "nina.txt", "file_name": "greatest-hits", "text": text}
+                for text in texts
+            ],
+            id="multiple fields (same count), multiple documents",
+        ),
+        pytest.param(
+            DocumentMessage(
+                role=MessageRole.USER,
+                content="\n\n".join(texts),
+                additional_kwargs={},
+            ),
+            [{"text": "\n\n".join(texts)}],
+            id="no fields (just text), multiple documents",
+        ),
+    ],
+)
+def test_document_message_to_cohere_document(message, expected):
+    res = document_message_to_cohere_document(message)
+    print(res)
+    assert res == expected