From 0e7d6684409bd58c986b23934460d47a723a1903 Mon Sep 17 00:00:00 2001
From: Bryce Freshcorn <26725654+brycecf@users.noreply.github.com>
Date: Fri, 15 Mar 2024 22:36:35 -0400
Subject: [PATCH] Add Claude 3 Sonnet model to AWS Bedrock & update to support
 Messages API - Text Only, Not Multimodal (#11663)

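Moves the Bedrock AnthropicProvider from the legacy Text Completions API
("prompt", "max_tokens_to_sample", "completion") to the Messages API
("messages", "max_tokens", content blocks), and registers the
anthropic.claude-3-sonnet-20240229-v1:0 and
anthropic.claude-3-haiku-20240307-v1:0 model IDs as chat-only,
streaming-capable models. A configured system_prompt is now forwarded as the
Anthropic "system" parameter, and the prompt recorded by the LLM completion
callback events is coerced to str, since the Anthropic provider now produces
a list of message dicts rather than a plain prompt string. Also bumps
llama-index-llms-bedrock to 0.1.4 and its llama-index-llms-anthropic
dependency to ^0.1.6.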
---
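Notes (not part of the commit message): a minimal usage sketch of the new
model ID. It assumes AWS credentials and a default region are already
configured in the environment; the region, prompt, and parameter values below
are illustrative only.

    from llama_index.llms.bedrock import Bedrock

    # Illustrative values; assumes AWS credentials are configured in the environment.
    llm = Bedrock(
        model="anthropic.claude-3-sonnet-20240229-v1:0",
        region_name="us-east-1",
        system_prompt="You are a concise assistant.",  # forwarded as the Anthropic "system" parameter
        max_tokens=512,
    )
    print(llm.complete("Say hello.").text)
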
 .../llama_index/core/llms/callbacks.py        | 12 ++---
 .../llama_index/llms/bedrock/base.py          |  3 ++
 .../llama_index/llms/bedrock/utils.py         | 51 ++++++++++++++++---
 .../llama-index-llms-bedrock/pyproject.toml   |  4 +-
 .../tests/test_bedrock.py                     |  8 +--
 5 files changed, 61 insertions(+), 17 deletions(-)

diff --git a/llama-index-core/llama_index/core/llms/callbacks.py b/llama-index-core/llama_index/core/llms/callbacks.py
index 1d25ce4dc..d63a0327f 100644
--- a/llama-index-core/llama_index/core/llms/callbacks.py
+++ b/llama-index-core/llama_index/core/llms/callbacks.py
@@ -216,7 +216,7 @@ def llm_completion_callback() -> Callable:
                 dispatcher.event(
                     LLMCompletionStartEvent(
                         model_dict=_self.to_dict(),
-                        prompt=args[0],
+                        prompt=str(args[0]),
                         additional_kwargs=kwargs,
                     )
                 )
@@ -238,7 +238,7 @@ def llm_completion_callback() -> Callable:
                         async for x in f_return_val:
                             dispatcher.event(
                                 LLMCompletionEndEvent(
-                                    prompt=args[0],
+                                    prompt=str(args[0]),
                                     response=x,
                                 )
                             )
@@ -266,7 +266,7 @@ def llm_completion_callback() -> Callable:
                     )
                     dispatcher.event(
                         LLMCompletionEndEvent(
-                            prompt=args[0],
+                            prompt=str(args[0]),
                             response=f_return_val,
                         )
                     )
@@ -278,7 +278,7 @@ def llm_completion_callback() -> Callable:
                 dispatcher.event(
                     LLMCompletionStartEvent(
                         model_dict=_self.to_dict(),
-                        prompt=args[0],
+                        prompt=str(args[0]),
                         additional_kwargs=kwargs,
                     )
                 )
@@ -299,7 +299,7 @@ def llm_completion_callback() -> Callable:
                         for x in f_return_val:
                             dispatcher.event(
                                 LLMCompletionEndEvent(
-                                    prompt=args[0],
+                                    prompt=str(args[0]),
                                     response=x,
                                 )
                             )
@@ -327,7 +327,7 @@ def llm_completion_callback() -> Callable:
                     )
                     dispatcher.event(
                         LLMCompletionEndEvent(
-                            prompt=args[0],
+                            prompt=str(args[0]),
                             response=f_return_val,
                         )
                     )
diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/base.py b/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/base.py
index ae9c1b4b7..e102cc698 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/base.py
@@ -27,6 +27,7 @@ from llama_index.core.base.llms.generic_utils import (
 from llama_index.core.llms.llm import LLM
 from llama_index.core.types import BaseOutputParser, PydanticProgramMode
 from llama_index.llms.bedrock.utils import (
+    AnthropicProvider,
     BEDROCK_FOUNDATION_LLMS,
     CHAT_ONLY_MODELS,
     STREAMING_MODELS,
@@ -198,6 +199,8 @@ class Bedrock(LLM):
             "temperature": self.temperature,
             self._provider.max_tokens_key: self.max_tokens,
         }
+        if type(self._provider) is AnthropicProvider and self.system_prompt:
+            base_kwargs["system"] = self.system_prompt
         return {
             **base_kwargs,
             **self.additional_kwargs,
diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/utils.py b/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/utils.py
index df1f3a484..6c66942b9 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/utils.py
@@ -1,12 +1,12 @@
 import logging
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Optional, Sequence
+from typing import Any, Callable, Dict, List, Optional, Sequence
 
 from llama_index.core.base.llms.types import ChatMessage
 from llama_index.core.base.llms.generic_utils import (
     prompt_to_messages,
 )
-from llama_index.llms.anthropic.utils import messages_to_anthropic_prompt
+from llama_index.llms.anthropic.utils import messages_to_anthropic_messages
 from llama_index.llms.bedrock.llama_utils import (
     completion_to_prompt as completion_to_llama_prompt,
 )
@@ -44,6 +44,8 @@ CHAT_ONLY_MODELS = {
     "anthropic.claude-v1": 100000,
     "anthropic.claude-v2": 100000,
     "anthropic.claude-v2:1": 200000,
+    "anthropic.claude-3-sonnet-20240229-v1:0": 200000,
+    "anthropic.claude-3-haiku-20240307-v1:0": 200000,
     "meta.llama2-13b-chat-v1": 2048,
     "meta.llama2-70b-chat-v1": 4096,
 }
@@ -59,6 +61,8 @@ STREAMING_MODELS = {
     "anthropic.claude-v1",
     "anthropic.claude-v2",
     "anthropic.claude-v2:1",
+    "anthropic.claude-3-sonnet-20240229-v1:0",
+    "anthropic.claude-3-haiku-20240307-v1:0",
     "meta.llama2-13b-chat-v1",
 }
 
@@ -107,18 +111,53 @@ class Ai21Provider(Provider):
 
 
 def completion_to_anthopic_prompt(completion: str) -> str:
-    return messages_to_anthropic_prompt(prompt_to_messages(completion))
+    messages, _ = messages_to_anthropic_messages(prompt_to_messages(completion))
+    return messages
+
+
+def _messages_to_anthropic_messages(messages: Sequence[ChatMessage]) -> List[dict]:
+    messages, system_prompt = messages_to_anthropic_messages(messages)
+    if system_prompt:
+        messages = [{"role": "system", "content": system_prompt}, *messages]
+    return messages
 
 
 class AnthropicProvider(Provider):
-    max_tokens_key = "max_tokens_to_sample"
+    max_tokens_key = "max_tokens"
 
     def __init__(self) -> None:
-        self.messages_to_prompt = messages_to_anthropic_prompt
+        self.messages_to_prompt = _messages_to_anthropic_messages
         self.completion_to_prompt = completion_to_anthopic_prompt
 
+    def get_text_from_stream_response(self, response: dict) -> str:
+        if response["type"] == "content_block_delta":
+            return response["delta"]["text"]
+        else:
+            return ""
+
     def get_text_from_response(self, response: dict) -> str:
-        return response["completion"]
+        return response["content"][0]["text"]
+
+    def get_request_body(self, prompt: Sequence[Dict], inference_parameters: dict):
+        if len(prompt) > 0 and prompt[0]["role"] == "system":
+            system_message = prompt[0]["content"]
+            prompt = prompt[1:]
+
+            if (
+                "system" in inference_parameters
+                and inference_parameters["system"] is not None
+            ):
+                inference_parameters["system"] += system_message
+            else:
+                inference_parameters["system"] = system_message
+
+        return {
+            "messages": prompt,
+            "anthropic_version": inference_parameters.get(
+                "anthropic_version", "bedrock-2023-05-31"
+            ),  # Required by AWS.
+            **inference_parameters,
+        }
 
 
 class CohereProvider(Provider):
diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-bedrock/pyproject.toml
index d2320280a..efb2570fe 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock/pyproject.toml
@@ -27,12 +27,12 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-bedrock"
 readme = "README.md"
-version = "0.1.3"
+version = "0.1.4"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 llama-index-core = "^0.10.1"
-llama-index-llms-anthropic = "^0.1.1"
+llama-index-llms-anthropic = "^0.1.6"
 boto3 = "^1.34.26"
 
 [tool.poetry.group.dev.dependencies]
diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py b/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py
index 74b0df855..c5f43bd57 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py
@@ -102,9 +102,11 @@ class MockStreamCompletionWithRetry:
         ),
         (
             "anthropic.claude-instant-v1",
-            '{"prompt": "\\n\\nHuman: test prompt\\n\\nAssistant: ", "temperature": 0.1, "max_tokens_to_sample": 512}',
-            '{"completion": "\\n\\nThis is indeed a test"}',
-            '{"prompt": "\\n\\nHuman: test prompt\\n\\nAssistant: ", "temperature": 0.1, "max_tokens_to_sample": 512}',
+            '{"messages": [{"role": "user", "content": "test prompt"}], "anthropic_version": "bedrock-2023-05-31", '
+            '"temperature": 0.1, "max_tokens": 512}',
+            '{"content": [{"text": "\\n\\nThis is indeed a test", "type": "text"}]}',
+            '{"messages": [{"role": "user", "content": "test prompt"}], "anthropic_version": "bedrock-2023-05-31", '
+            '"temperature": 0.1, "max_tokens": 512}',
         ),
         (
             "meta.llama2-13b-chat-v1",
-- 
GitLab