Skip to content
Snippets Groups Projects
Unverified Commit 77e80b7f authored by Souyama's avatar Souyama Committed by GitHub
Browse files

Fix: calling AWS Bedrock models (#10443)

parent e39be29c
No related branches found
No related tags found
No related merge requests found
%% Cell type:markdown id: tags:
<a href="https://colab.research.google.com/github/jerryjliu/llama_index/blob/main/docs/examples/embeddings/bedrock.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
%% Cell type:markdown id: tags:
# Bedrock Embeddings
If you're opening this notebook on Colab, you will probably need to install LlamaIndex 🦙.
%% Cell type:code id: tags:
``` python
import os
from llama_index.embeddings import BedrockEmbedding
```
%% Cell type:code id: tags:
``` python
embed_model = BedrockEmbedding.from_credentials(
embed_model = BedrockEmbedding(
aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
aws_session_token=os.getenv("AWS_SESSION_TOKEN"),
aws_region="<aws-region>",
aws_profile="<aws-profile>",
region_name="<aws-region>",
profile_name="<aws-profile>",
)
```
%% Cell type:code id: tags:
``` python
embedding = embed_model.get_text_embedding("hello world")
print(embedding)
```
%% Cell type:markdown id: tags:
## List supported models
To see which Amazon Bedrock embedding models LlamaIndex supports, call `BedrockEmbedding.list_supported_models()` as follows.
%% Cell type:code id: tags:
``` python
from llama_index.embeddings import BedrockEmbedding
import json
supported_models = BedrockEmbedding.list_supported_models()
print(json.dumps(supported_models, indent=2))
```
%% Output
{
"amazon": [
"amazon.titan-embed-text-v1",
"amazon.titan-embed-g1-text-02",
"cohere.embed-english-v3",
"cohere.embed-multilingual-v3"
],
"cohere": [
"amazon.titan-embed-text-v1",
"amazon.titan-embed-g1-text-02",
"cohere.embed-english-v3",
"cohere.embed-multilingual-v3"
]
}
%% Cell type:markdown id: tags:
## Provider: Amazon
Amazon Bedrock Titan embeddings.
%% Cell type:code id: tags:
``` python
from llama_index.embeddings import BedrockEmbedding
model = BedrockEmbedding().from_credentials(
model_name="amazon.titan-embed-g1-text-02"
)
model = BedrockEmbedding(model_name="amazon.titan-embed-g1-text-02")
embeddings = model.get_text_embedding("hello world")
print(embeddings)
```
%% Cell type:markdown id: tags:
## Provider: Cohere
### cohere.embed-english-v3
%% Cell type:markdown id: tags:
Embed text for search
%% Cell type:code id: tags:
``` python
model = BedrockEmbedding().from_credentials(
model_name="cohere.embed-english-v3"
)
coherePayload = {
"texts": ["This is a test document", "This is another test document"],
"input_type": "search_document",
"truncate": "NONE",
}
embeddings = model.get_text_embedding(coherePayload)
model = BedrockEmbedding(model_name="cohere.embed-english-v3")
coherePayload = ["This is a test document", "This is another test document"]
embeddings = model.get_text_embedding_batch(coherePayload)
print(embeddings)
```
%% Cell type:markdown id: tags:
Embed query for question answering
%% Cell type:code id: tags:
``` python
# The model is selected with the `model` argument (the field the class uses as modelId).
model = BedrockEmbedding(model="cohere.embed-english-v3")
coherePayload = "What is gravity?"
# Call the public query API rather than the underscore-prefixed internal hook.
embeddings = model.get_query_embedding(coherePayload)
print(embeddings)
```
%% Cell type:markdown id: tags:
### Multilingual Embeddings from Cohere
%% Cell type:code id: tags:
``` python
model = BedrockEmbedding().from_credentials(
model_name="cohere.embed-multilingual-v3"
)
coherePayload = {
"texts": [
"This is a test document",
"తెలుగు అనేది ద్రావిడ భాషల కుటుంబానికి చెందిన భాష.",
],
"input_type": "search_document",
"truncate": "NONE",
}
embeddings = model.get_text_embedding(coherePayload)
model = BedrockEmbedding(model_name="cohere.embed-multilingual-v3")
coherePayload = [
"This is a test document",
"తెలుగు అనేది ద్రావిడ భాషల కుటుంబానికి చెందిన భాష.",
]
embeddings = model.get_text_embedding_batch(coherePayload)
print(embeddings)
```
......
This diff is collapsed.
import json
import os
import warnings
from enum import Enum
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence
from llama_index.bridge.pydantic import PrivateAttr
from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks.base import CallbackManager
from llama_index.constants import DEFAULT_EMBED_BATCH_SIZE
from llama_index.core.embeddings.base import BaseEmbedding, Embedding
from llama_index.core.llms.types import ChatMessage
from llama_index.types import BaseOutputParser, PydanticProgramMode
class PROVIDERS(str, Enum):
......@@ -33,169 +32,140 @@ PROVIDER_SPECIFIC_IDENTIFIERS = {
class BedrockEmbedding(BaseEmbedding):
model: str = Field(description="The modelId of the Bedrock model to use.")
profile_name: Optional[str] = Field(
description="The name of aws profile to use. If not given, then the default profile is used.",
exclude=True,
)
aws_access_key_id: Optional[str] = Field(
description="AWS Access Key ID to use", exclude=True
)
aws_secret_access_key: Optional[str] = Field(
description="AWS Secret Access Key to use", exclude=True
)
aws_session_token: Optional[str] = Field(
description="AWS Session Token to use", exclude=True
)
region_name: Optional[str] = Field(
description="AWS region name to use. Uses region configured in AWS CLI if not passed",
exclude=True,
)
botocore_session: Optional[Any] = Field(
description="Use this Botocore session instead of creating a new default one.",
exclude=True,
)
botocore_config: Optional[Any] = Field(
description="Custom configuration object to use instead of the default generated one.",
exclude=True,
)
max_retries: int = Field(
default=10, description="The maximum number of API retries.", gt=0
)
timeout: float = Field(
default=60.0,
description="The timeout for the Bedrock API request in seconds. It will be used for both connect and read timeouts.",
)
additional_kwargs: Dict[str, Any] = Field(
default_factory=dict, description="Additional kwargs for the bedrock client."
)
_client: Any = PrivateAttr()
_verbose: bool = PrivateAttr()
# Build the embedding wrapper around an optional, already-created Bedrock client.
#
# Args:
#   model_name: model id forwarded to the base class; defaults to Models.TITAN_EMBEDDING.
#   client: client object stored on self._client and also forwarded to the base class;
#       defaults to None.
#   embed_batch_size: forwarded to the base class; defaults to DEFAULT_EMBED_BATCH_SIZE.
#   callback_manager: optional CallbackManager forwarded to the base class.
#   verbose: stored on self._verbose; it is not forwarded to the base class.
def __init__(
self,
model_name: str = Models.TITAN_EMBEDDING,
client: Any = None,
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
callback_manager: Optional[CallbackManager] = None,
verbose: bool = False,
):
# The private attributes are assigned before the base-class initializer runs.
self._client = client
self._verbose = verbose
# Hand the remaining settings to the base-class initializer.
super().__init__(
model_name=model_name,
client=client,
embed_batch_size=embed_batch_size,
callback_manager=callback_manager,
)
@staticmethod
def list_supported_models() -> Dict[str, List[str]]:
    """Return the supported Bedrock embedding model ids grouped by provider.

    Returns:
        A dict keyed by provider name (the values of ``PROVIDERS``). Each
        entry lists only the ``Models`` ids that belong to that provider.
    """
    # Model ids have the form "<provider>.<model>", so the text before the
    # first dot identifies the provider. Without this filter every provider
    # would be mapped to the full model list.
    return {
        provider.value: [
            m.value for m in Models if m.value.split(".")[0] == provider.value
        ]
        for provider in PROVIDERS
    }
@classmethod
def class_name(cls) -> str:
    """Return the identifier string for this class, ``"BedrockEmbedding"``."""
    return "BedrockEmbedding"
def set_credentials(
    self,
    aws_region: Optional[str] = None,
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_session_token: Optional[str] = None,
    aws_profile: Optional[str] = None,
) -> None:
    """Create a boto3 session and store a Bedrock client on ``self._client``.

    Every argument left as ``None`` (except ``aws_profile``) falls back to
    the matching environment variable: ``AWS_REGION``, ``AWS_ACCESS_KEY_ID``,
    ``AWS_SECRET_ACCESS_KEY`` and ``AWS_SESSION_TOKEN``.

    A value that is still missing only triggers a warning. It is passed to
    ``boto3.Session`` as ``None`` so boto3 can resolve it itself (for example
    from the given profile). Explicit keys and a session token are therefore
    not mandatory.

    Args:
        aws_region: AWS region name.
        aws_access_key_id: AWS access key ID.
        aws_secret_access_key: AWS secret access key.
        aws_session_token: AWS session token.
        aws_profile: Name of the AWS profile to use.

    Raises:
        ImportError: If ``boto3`` is not installed.
    """
    aws_region = aws_region or os.getenv("AWS_REGION")
    aws_access_key_id = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
    aws_secret_access_key = aws_secret_access_key or os.getenv(
        "AWS_SECRET_ACCESS_KEY"
    )
    aws_session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN")

    # Missing values are reported but never asserted on: `assert` disappears
    # under `python -O`, and none of these values is required by boto3.
    if aws_region is None:
        warnings.warn(
            "AWS_REGION not found. Set environment variable AWS_REGION or set aws_region"
        )
    if aws_access_key_id is None:
        warnings.warn(
            "AWS_ACCESS_KEY_ID not found. Set environment variable AWS_ACCESS_KEY_ID or set aws_access_key_id"
        )
    if aws_secret_access_key is None:
        warnings.warn(
            "AWS_SECRET_ACCESS_KEY not found. Set environment variable AWS_SECRET_ACCESS_KEY or set aws_secret_access_key"
        )
    if aws_session_token is None:
        warnings.warn(
            "AWS_SESSION_TOKEN not found. Set environment variable AWS_SESSION_TOKEN or set aws_session_token"
        )

    # Keep the try body to the import alone so that errors raised while
    # building the session are not mistaken for a missing package.
    try:
        import boto3
    except ImportError as e:
        raise ImportError(
            "boto3 package not found, install with 'pip install boto3'"
        ) from e

    session = boto3.Session(
        profile_name=aws_profile,
        region_name=aws_region,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        aws_session_token=aws_session_token,
    )

    # Use the "bedrock-runtime" service when the installed boto3 offers it,
    # otherwise fall back to "bedrock".
    if "bedrock-runtime" in session.get_available_services():
        self._client = session.client("bedrock-runtime")
    else:
        self._client = session.client("bedrock")
@classmethod
def from_credentials(
cls,
model_name: str = Models.TITAN_EMBEDDING,
aws_region: Optional[str] = None,
model: str = Models.TITAN_EMBEDDING,
profile_name: Optional[str] = None,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_profile: Optional[str] = None,
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
region_name: Optional[str] = None,
client: Optional[Any] = None,
botocore_session: Optional[Any] = None,
botocore_config: Optional[Any] = None,
additional_kwargs: Optional[Dict[str, Any]] = None,
max_retries: int = 10,
timeout: float = 60.0,
callback_manager: Optional[CallbackManager] = None,
verbose: bool = False,
) -> "BedrockEmbedding":
"""
Instantiate using AWS credentials.
Args:
model_name (str) : Name of the model
aws_access_key_id (str): AWS access key ID
aws_secret_access_key (str): AWS secret access key
aws_session_token (str): AWS session token
aws_region (str): AWS region where the service is located
aws_profile (str): AWS profile, when None, default profile is chosen automatically
Example:
.. code-block:: python
from llama_index.embeddings import BedrockEmbedding
# Define the model name
model_name = "your_model_name"
embeddings = BedrockEmbedding.from_credentials(
model_name,
aws_access_key_id,
aws_secret_access_key,
aws_session_token,
aws_region,
aws_profile,
)
# base class
system_prompt: Optional[str] = None,
messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None,
completion_to_prompt: Optional[Callable[[str], str]] = None,
pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
output_parser: Optional[BaseOutputParser] = None,
**kwargs: Any,
):
additional_kwargs = additional_kwargs or {}
"""
session_kwargs = {
"profile_name": aws_profile,
"region_name": aws_region,
"profile_name": profile_name,
"region_name": region_name,
"aws_access_key_id": aws_access_key_id,
"aws_secret_access_key": aws_secret_access_key,
"aws_session_token": aws_session_token,
"botocore_session": botocore_session,
}
config = None
try:
import boto3
from botocore.config import Config
config = (
Config(
retries={"max_attempts": max_retries, "mode": "standard"},
connect_timeout=timeout,
read_timeout=timeout,
)
if botocore_config is None
else botocore_config
)
session = boto3.Session(**session_kwargs)
except ImportError:
raise ImportError(
"boto3 package not found, install with" "'pip install boto3'"
)
if "bedrock-runtime" in session.get_available_services():
client = session.client("bedrock-runtime")
# Prior to general availability, custom boto3 wheel files were
# distributed that used the bedrock service to invokeModel.
# This check prevents any services still using those wheel files
# from breaking
if client is not None:
self._client = client
elif "bedrock-runtime" in session.get_available_services():
self._client = session.client("bedrock-runtime", config=config)
else:
client = session.client("bedrock")
return cls(
client=client,
model_name=model_name,
embed_batch_size=embed_batch_size,
self._client = session.client("bedrock", config=config)
super().__init__(
model=model,
max_retries=max_retries,
timeout=timeout,
botocore_config=config,
profile_name=profile_name,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name=region_name,
botocore_session=botocore_session,
additional_kwargs=additional_kwargs,
callback_manager=callback_manager,
verbose=verbose,
system_prompt=system_prompt,
messages_to_prompt=messages_to_prompt,
completion_to_prompt=completion_to_prompt,
pydantic_program_mode=pydantic_program_mode,
output_parser=output_parser,
**kwargs,
)
@staticmethod
def list_supported_models() -> Dict[str, List[str]]:
    """Return the supported Bedrock embedding model ids grouped by provider.

    Returns:
        A dict keyed by provider name (the values of ``PROVIDERS``). Each
        entry lists only the ``Models`` ids that belong to that provider.
    """
    # Model ids have the form "<provider>.<model>", so the text before the
    # first dot identifies the provider. Without this filter every provider
    # would be mapped to the full model list.
    return {
        provider.value: [
            m.value for m in Models if m.value.split(".")[0] == provider.value
        ]
        for provider in PROVIDERS
    }
@classmethod
def class_name(cls) -> str:
    """Return the identifier string for this class, ``"BedrockEmbedding"``."""
    return "BedrockEmbedding"
def _get_embedding(self, payload: str, type: Literal["text", "query"]) -> Embedding:
if self._client is None:
self.set_credentials()
......@@ -203,12 +173,12 @@ class BedrockEmbedding(BaseEmbedding):
if self._client is None:
raise ValueError("Client not set")
provider = self.model_name.split(".")[0]
provider = self.model.split(".")[0]
request_body = self._get_request_body(provider, payload, type)
response = self._client.invoke_model(
body=request_body,
modelId=self.model_name,
modelId=self.model,
accept="application/json",
contentType="application/json",
)
......@@ -244,8 +214,6 @@ class BedrockEmbedding(BaseEmbedding):
}
"""
if self._verbose:
print("provider: ", provider, PROVIDERS.AMAZON)
if provider == PROVIDERS.AMAZON:
request_body = json.dumps({"inputText": payload})
elif provider == PROVIDERS.COHERE:
......
......@@ -3,6 +3,9 @@ from typing import Any, Callable, Dict, Optional, Sequence
from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.constants import (
DEFAULT_TEMPERATURE,
)
from llama_index.core.llms.types import (
ChatMessage,
ChatResponse,
......@@ -41,16 +44,33 @@ class Bedrock(LLM):
profile_name: Optional[str] = Field(
description="The name of aws profile to use. If not given, then the default profile is used."
)
aws_access_key_id: Optional[str] = Field(description="AWS Access Key ID to use")
aws_access_key_id: Optional[str] = Field(
description="AWS Access Key ID to use", exclude=True
)
aws_secret_access_key: Optional[str] = Field(
description="AWS Secret Access Key to use"
description="AWS Secret Access Key to use", exclude=True
)
aws_session_token: Optional[str] = Field(
description="AWS Session Token to use", exclude=True
)
region_name: Optional[str] = Field(
description="AWS region name to use. Uses region configured in AWS CLI if not passed",
exclude=True,
)
aws_session_token: Optional[str] = Field(description="AWS Session Token to use")
aws_region_name: Optional[str] = Field(
description="AWS region name to use. Uses region configured in AWS CLI if not passed"
botocore_session: Optional[Any] = Field(
description="Use this Botocore session instead of creating a new default one.",
exclude=True,
)
botocore_config: Optional[Any] = Field(
description="Custom configuration object to use instead of the default generated one.",
exclude=True,
)
max_retries: int = Field(
default=10, description="The maximum number of API retries."
default=10, description="The maximum number of API retries.", gt=0
)
timeout: float = Field(
default=60.0,
description="The timeout for the Bedrock API request in seconds. It will be used for both connect and read timeouts.",
)
additional_kwargs: Dict[str, Any] = Field(
default_factory=dict,
......@@ -64,16 +84,19 @@ class Bedrock(LLM):
def __init__(
self,
model: str,
temperature: Optional[float] = 0.5,
temperature: Optional[float] = DEFAULT_TEMPERATURE,
max_tokens: Optional[int] = 512,
context_size: Optional[int] = None,
profile_name: Optional[str] = None,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region_name: Optional[str] = None,
timeout: Optional[float] = None,
region_name: Optional[str] = None,
botocore_session: Optional[Any] = None,
client: Optional[Any] = None,
timeout: Optional[float] = 60.0,
max_retries: Optional[int] = 10,
botocore_config: Optional[Any] = None,
additional_kwargs: Optional[Dict[str, Any]] = None,
callback_manager: Optional[CallbackManager] = None,
system_prompt: Optional[str] = None,
......@@ -81,6 +104,7 @@ class Bedrock(LLM):
completion_to_prompt: Optional[Callable[[str], str]] = None,
pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
output_parser: Optional[BaseOutputParser] = None,
**kwargs: Any,
) -> None:
if context_size is None and model not in BEDROCK_FOUNDATION_LLMS:
raise ValueError(
......@@ -88,39 +112,45 @@ class Bedrock(LLM):
"model provided refers to a non-foundation model."
" Please specify the context_size"
)
session_kwargs = {
"profile_name": profile_name,
"region_name": region_name,
"aws_access_key_id": aws_access_key_id,
"aws_secret_access_key": aws_secret_access_key,
"aws_session_token": aws_session_token,
"botocore_session": botocore_session,
}
config = None
try:
import boto3
import botocore
from botocore.config import Config
except Exception as e:
raise ImportError(
"You must install the `boto3` package to use Bedrock."
"Please `pip install boto3`"
) from e
try:
if not profile_name and aws_access_key_id:
session = boto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
region_name=aws_region_name,
config = (
Config(
retries={"max_attempts": max_retries, "mode": "standard"},
connect_timeout=timeout,
read_timeout=timeout,
)
else:
session = boto3.Session(profile_name=profile_name)
# Prior to general availability, custom boto3 wheel files were
# distributed that used the bedrock service to invokeModel.
# This check prevents any services still using those wheel files
# from breaking
if "bedrock-runtime" in session.get_available_services():
self._client = session.client("bedrock-runtime")
else:
self._client = session.client("bedrock")
except botocore.exceptions.NoRegionError as e:
raise ValueError(
"If default region is not set in AWS CLI, you must provide"
" the region_name argument to llama_index.llms.Bedrock"
if botocore_config is None
else botocore_config
)
session = boto3.Session(**session_kwargs)
except ImportError:
raise ImportError(
"boto3 package not found, install with" "'pip install boto3'"
)
# Prior to general availability, custom boto3 wheel files were
# distributed that used the bedrock service to invokeModel.
# This check prevents any services still using those wheel files
# from breaking
if client is not None:
self._client = client
elif "bedrock-runtime" in session.get_available_services():
self._client = session.client("bedrock-runtime", config=config)
else:
self._client = session.client("bedrock", config=config)
additional_kwargs = additional_kwargs or {}
callback_manager = callback_manager or CallbackManager([])
......@@ -138,6 +168,7 @@ class Bedrock(LLM):
profile_name=profile_name,
timeout=timeout,
max_retries=max_retries,
botocore_config=config,
additional_kwargs=additional_kwargs,
callback_manager=callback_manager,
system_prompt=system_prompt,
......
......@@ -34,7 +34,7 @@ class TestBedrockEmbedding(TestCase):
)
bedrock_embedding = BedrockEmbedding(
model_name=Models.TITAN_EMBEDDING,
model=Models.TITAN_EMBEDDING,
client=self.bedrock_client,
)
......@@ -63,7 +63,7 @@ class TestBedrockEmbedding(TestCase):
)
bedrock_embedding = BedrockEmbedding(
model_name=Models.COHERE_EMBED_ENGLISH_V3,
model=Models.COHERE_EMBED_ENGLISH_V3,
client=self.bedrock_client,
)
......
......@@ -58,7 +58,7 @@ class MockStreamCompletionWithRetry:
) -> dict:
assert json.loads(request_body) == {
"inputText": self.expected_prompt,
"textGenerationConfig": {"maxTokenCount": 512, "temperature": 0.5},
"textGenerationConfig": {"maxTokenCount": 512, "temperature": 0.1},
}
return {
"ResponseMetadata": {
......@@ -84,27 +84,27 @@ class MockStreamCompletionWithRetry:
[
(
"amazon.titan-text-express-v1",
'{"inputText": "test prompt", "textGenerationConfig": {"temperature": 0.5, "maxTokenCount": 512}}',
'{"inputText": "test prompt", "textGenerationConfig": {"temperature": 0.1, "maxTokenCount": 512}}',
'{"inputTextTokenCount": 3, "results": [{"tokenCount": 14, "outputText": "\\n\\nThis is indeed a test", "completionReason": "FINISH"}]}',
'{"inputText": "user: test prompt\\nassistant: ", "textGenerationConfig": {"temperature": 0.5, "maxTokenCount": 512}}',
'{"inputText": "user: test prompt\\nassistant: ", "textGenerationConfig": {"temperature": 0.1, "maxTokenCount": 512}}',
),
(
"ai21.j2-grande-instruct",
'{"prompt": "test prompt", "temperature": 0.5, "maxTokens": 512}',
'{"prompt": "test prompt", "temperature": 0.1, "maxTokens": 512}',
'{"completions": [{"data": {"text": "\\n\\nThis is indeed a test"}}]}',
'{"prompt": "user: test prompt\\nassistant: ", "temperature": 0.5, "maxTokens": 512}',
'{"prompt": "user: test prompt\\nassistant: ", "temperature": 0.1, "maxTokens": 512}',
),
(
"cohere.command-text-v14",
'{"prompt": "test prompt", "temperature": 0.5, "max_tokens": 512}',
'{"prompt": "test prompt", "temperature": 0.1, "max_tokens": 512}',
'{"generations": [{"text": "\\n\\nThis is indeed a test"}]}',
'{"prompt": "user: test prompt\\nassistant: ", "temperature": 0.5, "max_tokens": 512}',
'{"prompt": "user: test prompt\\nassistant: ", "temperature": 0.1, "max_tokens": 512}',
),
(
"anthropic.claude-instant-v1",
'{"prompt": "\\n\\nHuman: test prompt\\n\\nAssistant: ", "temperature": 0.5, "max_tokens_to_sample": 512}',
'{"prompt": "\\n\\nHuman: test prompt\\n\\nAssistant: ", "temperature": 0.1, "max_tokens_to_sample": 512}',
'{"completion": "\\n\\nThis is indeed a test"}',
'{"prompt": "\\n\\nHuman: test prompt\\n\\nAssistant: ", "temperature": 0.5, "max_tokens_to_sample": 512}',
'{"prompt": "\\n\\nHuman: test prompt\\n\\nAssistant: ", "temperature": 0.1, "max_tokens_to_sample": 512}',
),
(
"meta.llama2-13b-chat-v1",
......@@ -112,13 +112,13 @@ class MockStreamCompletionWithRetry:
"honest assistant. Always answer as helpfully as possible and follow "
"ALL given instructions. Do not speculate or make up information. Do "
"not reference any given instructions or context. \\n<</SYS>>\\n\\n "
'test prompt [/INST]", "temperature": 0.5, "max_gen_len": 512}',
'test prompt [/INST]", "temperature": 0.1, "max_gen_len": 512}',
'{"generation": "\\n\\nThis is indeed a test"}',
'{"prompt": "<s> [INST] <<SYS>>\\n You are a helpful, respectful and '
"honest assistant. Always answer as helpfully as possible and follow "
"ALL given instructions. Do not speculate or make up information. Do "
"not reference any given instructions or context. \\n<</SYS>>\\n\\n "
'test prompt [/INST]", "temperature": 0.5, "max_gen_len": 512}',
'test prompt [/INST]", "temperature": 0.1, "max_gen_len": 512}',
),
],
)
......@@ -128,7 +128,7 @@ def test_model_basic(
llm = Bedrock(
model=model,
profile_name=None,
aws_region_name="us-east-1",
region_name="us-east-1",
aws_access_key_id="test",
)
......@@ -168,7 +168,7 @@ def test_model_streaming(monkeypatch: MonkeyPatch) -> None:
llm = Bedrock(
model="amazon.titan-text-express-v1",
profile_name=None,
aws_region_name="us-east-1",
region_name="us-east-1",
aws_access_key_id="test",
)
test_prompt = "test prompt"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment