From defdfd8f4c536881754835d285a80595c760bb86 Mon Sep 17 00:00:00 2001
From: Souyama <souyamadebnath@gmail.com>
Date: Tue, 19 Mar 2024 20:56:52 +0530
Subject: [PATCH] Fix: Anthropic LLM merge consecutive messages with same role
 (#12013)

---
 .../llama_index/llms/anthropic/utils.py       | 39 +++++++++++++++++--
 .../llama-index-llms-anthropic/pyproject.toml |  2 +-
 .../llama-index-llms-bedrock/pyproject.toml   |  4 +-
 .../tests/test_bedrock.py                     |  4 +-
 4 files changed, 41 insertions(+), 8 deletions(-)

diff --git a/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/utils.py b/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/utils.py
index c7056d1084..28f5aba9ea 100644
--- a/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-anthropic/llama_index/llms/anthropic/utils.py
@@ -2,6 +2,8 @@ from typing import Dict, Sequence, Tuple
 
 from llama_index.core.base.llms.types import ChatMessage, MessageRole
 
+from anthropic.types import MessageParam, TextBlockParam
+
 HUMAN_PREFIX = "\n\nHuman:"
 ASSISTANT_PREFIX = "\n\nAssistant:"
 
@@ -27,18 +29,49 @@ def anthropic_modelname_to_contextsize(modelname: str) -> int:
     return CLAUDE_MODELS[modelname]
 
 
+def __merge_common_role_msgs(
+    messages: Sequence[MessageParam],
+) -> Sequence[MessageParam]:
+    """Merge consecutive messages that share a role into a single message.
+
+    Contents are joined with ``+``; inputs and their content are not mutated.
+    """
+    merged = []
+    for message in messages:
+        if merged and merged[-1]["role"] == message["role"]:
+            merged[-1]["content"] = merged[-1]["content"] + message["content"]
+        else:
+            merged.append(message.copy())  # copy so merging never edits inputs
+    return merged
+
+
 def messages_to_anthropic_messages(
     messages: Sequence[ChatMessage],
-) -> Tuple[Sequence[ChatMessage], str]:
+) -> Tuple[Sequence[MessageParam], str]:
+    """Convert generic ChatMessages into Anthropic message params.
+
+    Args:
+        messages: Chat messages; SYSTEM ones set the system prompt (last wins).
+
+    Returns:
+        Tuple of:
+        - Non-system messages as text-block params, same-role runs merged
+        - System prompt ("" when there is no system message)
+    """
     anthropic_messages = []
     system_prompt = ""
     for message in messages:
         if message.role == MessageRole.SYSTEM:
             system_prompt = message.content
         else:
-            message = {"role": message.role.value, "content": message.content}
+            message = MessageParam(
+                role=message.role.value,
+                content=[
+                    TextBlockParam(text=message.content, type="text")
+                ],  # TODO: type detect for multimodal
+            )
             anthropic_messages.append(message)
-    return anthropic_messages, system_prompt
+    return __merge_common_role_msgs(anthropic_messages), system_prompt
 
 
 # Function used in bedrock
diff --git a/llama-index-integrations/llms/llama-index-llms-anthropic/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-anthropic/pyproject.toml
index 0122456f04..5f5f9cba97 100644
--- a/llama-index-integrations/llms/llama-index-llms-anthropic/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-anthropic/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-anthropic"
 readme = "README.md"
-version = "0.1.6"
+version = "0.1.7"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-bedrock/pyproject.toml
index efb2570fec..8b42e97e0a 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock/pyproject.toml
@@ -27,12 +27,12 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-bedrock"
 readme = "README.md"
-version = "0.1.4"
+version = "0.1.5"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 llama-index-core = "^0.10.1"
-llama-index-llms-anthropic = "^0.1.6"
+llama-index-llms-anthropic = "^0.1.7"
 boto3 = "^1.34.26"
 
 [tool.poetry.group.dev.dependencies]
diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py b/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py
index 96e2c5d648..6267ed7dc0 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py
@@ -102,10 +102,10 @@ class MockStreamCompletionWithRetry:
         ),
         (
             "anthropic.claude-instant-v1",
-            '{"messages": [{"role": "user", "content": "test prompt"}], "anthropic_version": "bedrock-2023-05-31", '
+            '{"messages": [{"role": "user", "content": [{"text": "test prompt", "type": "text"}]}], "anthropic_version": "bedrock-2023-05-31", '
             '"temperature": 0.1, "max_tokens": 512}',
             '{"content": [{"text": "\\n\\nThis is indeed a test", "type": "text"}]}',
-            '{"messages": [{"role": "user", "content": "test prompt"}], "anthropic_version": "bedrock-2023-05-31", '
+            '{"messages": [{"role": "user", "content": [{"text": "test prompt", "type": "text"}]}], "anthropic_version": "bedrock-2023-05-31", '
             '"temperature": 0.1, "max_tokens": 512}',
         ),
         (
-- 
GitLab