Unverified Commit 0e7d6684 authored by Bryce Freshcorn, committed by GitHub

Add Claude 3 Sonnet model to AWS Bedrock & update to support Messages API - Text Only, Not Multimodal (#11663)
parent cf5f75a2
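
For orientation (not part of the diff itself), a minimal usage sketch of the newly supported model through the Bedrock LLM class; the profile and region values are placeholders and working AWS credentials are assumed:

from llama_index.llms.bedrock import Bedrock

# Claude 3 Sonnet on Bedrock now goes through the Anthropic Messages API (text only).
llm = Bedrock(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    profile_name="my-aws-profile",  # placeholder AWS profile
    region_name="us-east-1",  # placeholder region
    temperature=0.1,
    max_tokens=512,
)
print(llm.complete("Describe Amazon Bedrock in one sentence.").text)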
@@ -216,7 +216,7 @@ def llm_completion_callback() -> Callable:
             dispatcher.event(
                 LLMCompletionStartEvent(
                     model_dict=_self.to_dict(),
-                    prompt=args[0],
+                    prompt=str(args[0]),
                     additional_kwargs=kwargs,
                 )
             )
@@ -238,7 +238,7 @@ def llm_completion_callback() -> Callable:
                 async for x in f_return_val:
                     dispatcher.event(
                         LLMCompletionEndEvent(
-                            prompt=args[0],
+                            prompt=str(args[0]),
                             response=x,
                         )
                     )
@@ -266,7 +266,7 @@ def llm_completion_callback() -> Callable:
             )
             dispatcher.event(
                 LLMCompletionEndEvent(
-                    prompt=args[0],
+                    prompt=str(args[0]),
                     response=f_return_val,
                 )
             )
@@ -278,7 +278,7 @@ def llm_completion_callback() -> Callable:
             dispatcher.event(
                 LLMCompletionStartEvent(
                     model_dict=_self.to_dict(),
-                    prompt=args[0],
+                    prompt=str(args[0]),
                     additional_kwargs=kwargs,
                 )
             )
@@ -299,7 +299,7 @@ def llm_completion_callback() -> Callable:
                 for x in f_return_val:
                     dispatcher.event(
                         LLMCompletionEndEvent(
-                            prompt=args[0],
+                            prompt=str(args[0]),
                            response=x,
                         )
                     )
@@ -327,7 +327,7 @@ def llm_completion_callback() -> Callable:
             )
             dispatcher.event(
                 LLMCompletionEndEvent(
-                    prompt=args[0],
+                    prompt=str(args[0]),
                     response=f_return_val,
                 )
             )
...
@@ -27,6 +27,7 @@ from llama_index.core.base.llms.generic_utils import (
 from llama_index.core.llms.llm import LLM
 from llama_index.core.types import BaseOutputParser, PydanticProgramMode
 from llama_index.llms.bedrock.utils import (
+    AnthropicProvider,
     BEDROCK_FOUNDATION_LLMS,
     CHAT_ONLY_MODELS,
     STREAMING_MODELS,
@@ -198,6 +199,8 @@ class Bedrock(LLM):
             "temperature": self.temperature,
             self._provider.max_tokens_key: self.max_tokens,
         }
+        if type(self._provider) is AnthropicProvider and self.system_prompt:
+            base_kwargs["system"] = self.system_prompt
         return {
             **base_kwargs,
             **self.additional_kwargs,
...
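
The bedrock.py hunk above forwards a configured system_prompt to the Anthropic provider as the Messages API "system" field instead of folding it into the prompt text. A hedged sketch of how that path is exercised from the caller's side (placeholder profile; system_prompt is the standard field on the base LLM class):

from llama_index.llms.bedrock import Bedrock

llm = Bedrock(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    profile_name="my-aws-profile",  # placeholder AWS profile
    system_prompt="You are a terse assistant who answers in one sentence.",
)
# With an Anthropic model, the system_prompt above is sent as the request's
# "system" field rather than being prepended to the user prompt.
print(llm.complete("What is a vector store?").text)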
 import logging
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Optional, Sequence
+from typing import Any, Callable, Dict, List, Optional, Sequence
 from llama_index.core.base.llms.types import ChatMessage
 from llama_index.core.base.llms.generic_utils import (
     prompt_to_messages,
 )
-from llama_index.llms.anthropic.utils import messages_to_anthropic_prompt
+from llama_index.llms.anthropic.utils import messages_to_anthropic_messages
 from llama_index.llms.bedrock.llama_utils import (
     completion_to_prompt as completion_to_llama_prompt,
 )
@@ -44,6 +44,8 @@ CHAT_ONLY_MODELS = {
     "anthropic.claude-v1": 100000,
     "anthropic.claude-v2": 100000,
     "anthropic.claude-v2:1": 200000,
+    "anthropic.claude-3-sonnet-20240229-v1:0": 200000,
+    "anthropic.claude-3-haiku-20240307-v1:0": 200000,
     "meta.llama2-13b-chat-v1": 2048,
     "meta.llama2-70b-chat-v1": 4096,
 }
@@ -59,6 +61,8 @@ STREAMING_MODELS = {
     "anthropic.claude-v1",
     "anthropic.claude-v2",
     "anthropic.claude-v2:1",
+    "anthropic.claude-3-sonnet-20240229-v1:0",
+    "anthropic.claude-3-haiku-20240307-v1:0",
     "meta.llama2-13b-chat-v1",
 }
@@ -107,18 +111,53 @@ class Ai21Provider(Provider):
 def completion_to_anthopic_prompt(completion: str) -> str:
-    return messages_to_anthropic_prompt(prompt_to_messages(completion))
+    messages, _ = messages_to_anthropic_messages(prompt_to_messages(completion))
+    return messages
+
+
+def _messages_to_anthropic_messages(messages: Sequence[ChatMessage]) -> List[dict]:
+    messages, system_prompt = messages_to_anthropic_messages(messages)
+    if system_prompt:
+        messages = [{"role": "system", "content": system_prompt}, *messages]
+    return messages
 
 
 class AnthropicProvider(Provider):
-    max_tokens_key = "max_tokens_to_sample"
+    max_tokens_key = "max_tokens"
 
     def __init__(self) -> None:
-        self.messages_to_prompt = messages_to_anthropic_prompt
+        self.messages_to_prompt = _messages_to_anthropic_messages
         self.completion_to_prompt = completion_to_anthopic_prompt
 
+    def get_text_from_stream_response(self, response: dict) -> str:
+        if response["type"] == "content_block_delta":
+            return response["delta"]["text"]
+        else:
+            return ""
+
     def get_text_from_response(self, response: dict) -> str:
-        return response["completion"]
+        return response["content"][0]["text"]
+
+    def get_request_body(self, prompt: Sequence[Dict], inference_parameters: dict):
+        if len(prompt) > 0 and prompt[0]["role"] == "system":
+            system_message = prompt[0]["content"]
+            prompt = prompt[1:]
+
+            if (
+                "system" in inference_parameters
+                and inference_parameters["system"] is not None
+            ):
+                inference_parameters["system"] += system_message
+            else:
+                inference_parameters["system"] = system_message
+
+        return {
+            "messages": prompt,
+            "anthropic_version": inference_parameters.get(
+                "anthropic_version", "bedrock-2023-05-31"
+            ),  # Required by AWS.
+            **inference_parameters,
+        }
 
 
 class CohereProvider(Provider):
...
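
To make the new request shape concrete, the body that get_request_body assembles for a simple exchange looks roughly like the following (illustrative values, mirroring the updated test expectations further down; the "system" key only appears when a system message was supplied):

# Illustrative payload only; keys and defaults are taken from the hunk above.
request_body = {
    "messages": [{"role": "user", "content": "test prompt"}],
    "anthropic_version": "bedrock-2023-05-31",  # required by AWS Bedrock
    "system": "You are a helpful assistant.",  # only when a system message exists
    "temperature": 0.1,
    "max_tokens": 512,  # the key is now "max_tokens", not "max_tokens_to_sample"
}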
@@ -27,12 +27,12 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-bedrock"
 readme = "README.md"
-version = "0.1.3"
+version = "0.1.4"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
 llama-index-core = "^0.10.1"
-llama-index-llms-anthropic = "^0.1.1"
+llama-index-llms-anthropic = "^0.1.6"
 boto3 = "^1.34.26"
 
 [tool.poetry.group.dev.dependencies]
...
@@ -102,9 +102,11 @@ class MockStreamCompletionWithRetry:
     ),
     (
         "anthropic.claude-instant-v1",
-        '{"prompt": "\\n\\nHuman: test prompt\\n\\nAssistant: ", "temperature": 0.1, "max_tokens_to_sample": 512}',
-        '{"completion": "\\n\\nThis is indeed a test"}',
-        '{"prompt": "\\n\\nHuman: test prompt\\n\\nAssistant: ", "temperature": 0.1, "max_tokens_to_sample": 512}',
+        '{"messages": [{"role": "user", "content": "test prompt"}], "anthropic_version": "bedrock-2023-05-31", '
+        '"temperature": 0.1, "max_tokens": 512}',
+        '{"content": [{"text": "\\n\\nThis is indeed a test", "type": "text"}]}',
+        '{"messages": [{"role": "user", "content": "test prompt"}], "anthropic_version": "bedrock-2023-05-31", '
+        '"temperature": 0.1, "max_tokens": 512}',
     ),
     (
         "meta.llama2-13b-chat-v1",
...