From 0e7d6684409bd58c986b23934460d47a723a1903 Mon Sep 17 00:00:00 2001
From: Bryce Freshcorn <26725654+brycecf@users.noreply.github.com>
Date: Fri, 15 Mar 2024 22:36:35 -0400
Subject: [PATCH] Add Claude 3 Sonnet model to AWS Bedrock & update to support
 Messages API - Text Only, Not Multimodal (#11663)

---
 .../llama_index/core/llms/callbacks.py        | 12 ++---
 .../llama_index/llms/bedrock/base.py          |  3 ++
 .../llama_index/llms/bedrock/utils.py         | 51 ++++++++++++++++---
 .../llama-index-llms-bedrock/pyproject.toml   |  4 +-
 .../tests/test_bedrock.py                     |  8 +--
 5 files changed, 61 insertions(+), 17 deletions(-)

diff --git a/llama-index-core/llama_index/core/llms/callbacks.py b/llama-index-core/llama_index/core/llms/callbacks.py
index 1d25ce4dc..d63a0327f 100644
--- a/llama-index-core/llama_index/core/llms/callbacks.py
+++ b/llama-index-core/llama_index/core/llms/callbacks.py
@@ -216,7 +216,7 @@ def llm_completion_callback() -> Callable:
             dispatcher.event(
                 LLMCompletionStartEvent(
                     model_dict=_self.to_dict(),
-                    prompt=args[0],
+                    prompt=str(args[0]),
                     additional_kwargs=kwargs,
                 )
             )
@@ -238,7 +238,7 @@ def llm_completion_callback() -> Callable:
                 async for x in f_return_val:
                     dispatcher.event(
                         LLMCompletionEndEvent(
-                            prompt=args[0],
+                            prompt=str(args[0]),
                             response=x,
                         )
                     )
@@ -266,7 +266,7 @@ def llm_completion_callback() -> Callable:
             )
             dispatcher.event(
                 LLMCompletionEndEvent(
-                    prompt=args[0],
+                    prompt=str(args[0]),
                     response=f_return_val,
                 )
             )
@@ -278,7 +278,7 @@ def llm_completion_callback() -> Callable:
             dispatcher.event(
                 LLMCompletionStartEvent(
                     model_dict=_self.to_dict(),
-                    prompt=args[0],
+                    prompt=str(args[0]),
                     additional_kwargs=kwargs,
                 )
             )
@@ -299,7 +299,7 @@ def llm_completion_callback() -> Callable:
                 for x in f_return_val:
                     dispatcher.event(
                         LLMCompletionEndEvent(
-                            prompt=args[0],
+                            prompt=str(args[0]),
                             response=x,
                         )
                     )
@@ -327,7 +327,7 @@ def llm_completion_callback() -> Callable:
             )
             dispatcher.event(
                 LLMCompletionEndEvent(
-                    prompt=args[0],
+                    prompt=str(args[0]),
                     response=f_return_val,
                 )
             )
diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/base.py b/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/base.py
index ae9c1b4b7..e102cc698 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/base.py
@@ -27,6 +27,7 @@ from llama_index.core.base.llms.generic_utils import (
 from llama_index.core.llms.llm import LLM
 from llama_index.core.types import BaseOutputParser, PydanticProgramMode
 from llama_index.llms.bedrock.utils import (
+    AnthropicProvider,
     BEDROCK_FOUNDATION_LLMS,
     CHAT_ONLY_MODELS,
     STREAMING_MODELS,
@@ -198,6 +199,8 @@ class Bedrock(LLM):
             "temperature": self.temperature,
             self._provider.max_tokens_key: self.max_tokens,
         }
+        if type(self._provider) is AnthropicProvider and self.system_prompt:
+            base_kwargs["system"] = self.system_prompt
         return {
             **base_kwargs,
             **self.additional_kwargs,
diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/utils.py b/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/utils.py
index df1f3a484..6c66942b9 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/utils.py
@@ -1,12 +1,12 @@
 import logging
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Optional, Sequence
+from typing import Any, Callable, Dict, List, Optional, Sequence
 
 from llama_index.core.base.llms.types import ChatMessage
 from llama_index.core.base.llms.generic_utils import (
     prompt_to_messages,
 )
-from llama_index.llms.anthropic.utils import messages_to_anthropic_prompt
+from llama_index.llms.anthropic.utils import messages_to_anthropic_messages
 from llama_index.llms.bedrock.llama_utils import (
     completion_to_prompt as completion_to_llama_prompt,
 )
@@ -44,6 +44,8 @@ CHAT_ONLY_MODELS = {
     "anthropic.claude-v1": 100000,
     "anthropic.claude-v2": 100000,
     "anthropic.claude-v2:1": 200000,
+    "anthropic.claude-3-sonnet-20240229-v1:0": 200000,
+    "anthropic.claude-3-haiku-20240307-v1:0": 200000,
     "meta.llama2-13b-chat-v1": 2048,
     "meta.llama2-70b-chat-v1": 4096,
 }
@@ -59,6 +61,8 @@ STREAMING_MODELS = {
     "anthropic.claude-v1",
     "anthropic.claude-v2",
     "anthropic.claude-v2:1",
+    "anthropic.claude-3-sonnet-20240229-v1:0",
+    "anthropic.claude-3-haiku-20240307-v1:0",
     "meta.llama2-13b-chat-v1",
 }
 
@@ -107,18 +111,53 @@ class Ai21Provider(Provider):
 
 
 def completion_to_anthopic_prompt(completion: str) -> str:
-    return messages_to_anthropic_prompt(prompt_to_messages(completion))
+    messages, _ = messages_to_anthropic_messages(prompt_to_messages(completion))
+    return messages
+
+
+def _messages_to_anthropic_messages(messages: Sequence[ChatMessage]) -> List[dict]:
+    messages, system_prompt = messages_to_anthropic_messages(messages)
+    if system_prompt:
+        messages = [{"role": "system", "content": system_prompt}, *messages]
+    return messages
 
 
 class AnthropicProvider(Provider):
-    max_tokens_key = "max_tokens_to_sample"
+    max_tokens_key = "max_tokens"
 
     def __init__(self) -> None:
-        self.messages_to_prompt = messages_to_anthropic_prompt
+        self.messages_to_prompt = _messages_to_anthropic_messages
         self.completion_to_prompt = completion_to_anthopic_prompt
 
+    def get_text_from_stream_response(self, response: dict) -> str:
+        if response["type"] == "content_block_delta":
+            return response["delta"]["text"]
+        else:
+            return ""
+
     def get_text_from_response(self, response: dict) -> str:
-        return response["completion"]
+        return response["content"][0]["text"]
+
+    def get_request_body(self, prompt: Sequence[Dict], inference_parameters: dict):
+        if len(prompt) > 0 and prompt[0]["role"] == "system":
+            system_message = prompt[0]["content"]
+            prompt = prompt[1:]
+
+        if (
+            "system" in inference_parameters
+            and inference_parameters["system"] is not None
+        ):
+            inference_parameters["system"] += system_message
+        else:
+            inference_parameters["system"] = system_message
+
+        return {
+            "messages": prompt,
+            "anthropic_version": inference_parameters.get(
+                "anthropic_version", "bedrock-2023-05-31"
+            ),  # Required by AWS.
+            **inference_parameters,
+        }
 
 
 class CohereProvider(Provider):
diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-bedrock/pyproject.toml
index d2320280a..efb2570fe 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock/pyproject.toml
@@ -27,12 +27,12 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-bedrock"
 readme = "README.md"
-version = "0.1.3"
+version = "0.1.4"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 llama-index-core = "^0.10.1"
-llama-index-llms-anthropic = "^0.1.1"
+llama-index-llms-anthropic = "^0.1.6"
 boto3 = "^1.34.26"
 
 [tool.poetry.group.dev.dependencies]
diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py b/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py
index 74b0df855..c5f43bd57 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py
@@ -102,9 +102,11 @@ class MockStreamCompletionWithRetry:
         ),
         (
             "anthropic.claude-instant-v1",
-            '{"prompt": "\\n\\nHuman: test prompt\\n\\nAssistant: ", "temperature": 0.1, "max_tokens_to_sample": 512}',
-            '{"completion": "\\n\\nThis is indeed a test"}',
-            '{"prompt": "\\n\\nHuman: test prompt\\n\\nAssistant: ", "temperature": 0.1, "max_tokens_to_sample": 512}',
+            '{"messages": [{"role": "user", "content": "test prompt"}], "anthropic_version": "bedrock-2023-05-31", '
+            '"temperature": 0.1, "max_tokens": 512}',
+            '{"content": [{"text": "\\n\\nThis is indeed a test", "type": "text"}]}',
+            '{"messages": [{"role": "user", "content": "test prompt"}], "anthropic_version": "bedrock-2023-05-31", '
+            '"temperature": 0.1, "max_tokens": 512}',
         ),
         (
             "meta.llama2-13b-chat-v1",
-- 
GitLab
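
Reviewer note, not part of the patch: the updated test fixtures above double as a reference for the new Messages API wire format. Below is a minimal Python sketch of the request body the reworked AnthropicProvider.get_request_body is expected to produce for a plain user prompt, reusing the same values as the test_bedrock.py expectations; it is an illustration, not repository code.

    import json

    # Mirrors the expected request string in the updated test_bedrock.py fixture.
    messages = [{"role": "user", "content": "test prompt"}]
    body = {
        "messages": messages,
        # Required by AWS; get_request_body falls back to this default.
        "anthropic_version": "bedrock-2023-05-31",
        "temperature": 0.1,
        # The Messages API uses "max_tokens", replacing the legacy
        # "max_tokens_to_sample" key of the old completions endpoint.
        "max_tokens": 512,
    }
    print(json.dumps(body))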
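
Usage sketch (hypothetical): calling the newly added Claude 3 Sonnet model through this integration. It assumes AWS credentials are configured and that model access is granted in the chosen region; the region, system prompt, and user prompt are placeholder values, not part of the change.

    from llama_index.llms.bedrock import Bedrock

    llm = Bedrock(
        model="anthropic.claude-3-sonnet-20240229-v1:0",  # new CHAT_ONLY_MODELS entry
        region_name="us-east-1",  # assumption: model access enabled here
        temperature=0.1,
        max_tokens=512,
        # Forwarded to the "system" inference parameter by the base.py change.
        system_prompt="You are a concise assistant.",
    )
    print(llm.complete("test prompt").text)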