Skip to content
Snippets Groups Projects
Unverified Commit 2a648073 authored by Bryce Freshcorn's avatar Bryce Freshcorn Committed by GitHub
Browse files

New function for `core.multi_modal_llms.generic_utils` & Anthropic LLM model updates (#16896)

parent 91146042
No related branches found
No related tags found
No related merge requests found
import base64 import base64
import filetype import filetype
import logging import logging
from typing import List, Sequence, Optional from typing import List, Optional, Sequence
import requests import requests
...@@ -91,6 +91,7 @@ def infer_image_mimetype_from_file_path(image_file_path: str) -> str: ...@@ -91,6 +91,7 @@ def infer_image_mimetype_from_file_path(image_file_path: str) -> str:
Returns: Returns:
str: MIME type of the image: image/jpeg, image/png, image/gif, or image/webp. str: MIME type of the image: image/jpeg, image/png, image/gif, or image/webp.
Defaults to `image/jpeg`.
""" """
# Get the file extension # Get the file extension
file_extension = image_file_path.split(".")[-1].lower() file_extension = image_file_path.split(".")[-1].lower()
...@@ -117,6 +118,7 @@ def infer_image_mimetype_from_base64(base64_string: str) -> Optional[str]: ...@@ -117,6 +118,7 @@ def infer_image_mimetype_from_base64(base64_string: str) -> Optional[str]:
Returns: Returns:
Optional[str]: MIME type of the image: image/jpeg, image/png, image/gif, or image/webp. Optional[str]: MIME type of the image: image/jpeg, image/png, image/gif, or image/webp.
`None` if the MIME type cannot be inferred.
""" """
# Decode the base64 string # Decode the base64 string
decoded_data = base64.b64decode(base64_string) decoded_data = base64.b64decode(base64_string)
...@@ -126,3 +128,28 @@ def infer_image_mimetype_from_base64(base64_string: str) -> Optional[str]: ...@@ -126,3 +128,28 @@ def infer_image_mimetype_from_base64(base64_string: str) -> Optional[str]:
# Return the MIME type if detected, otherwise return None # Return the MIME type if detected, otherwise return None
return kind.mime if kind is not None else None return kind.mime if kind is not None else None
def set_base64_and_mimetype_for_image_docs(
    image_documents: Sequence[ImageDocument],
) -> Sequence[ImageDocument]:
    """Set the base64 and mimetype fields for the image documents.

    Each document is updated in place. Its ``image`` field is overwritten with
    the base64 string produced by ``image_documents_to_base64``, and its
    ``image_mimetype`` is resolved in this order:

    1. the type detected from the base64 content,
    2. the type inferred from ``image_path`` (only if step 1 found nothing
       and a path is set),
    3. ``image/jpeg`` as the final default.

    Args:
        image_documents (Sequence[ImageDocument]): Sequence of ImageDocument objects.

    Returns:
        Sequence[ImageDocument]: The same ImageDocuments, with base64 and
            detected mimetypes set.
    """
    base64_strings = image_documents_to_base64(image_documents)
    for image_doc, base64_str in zip(image_documents, base64_strings):
        image_doc.image = base64_str
        mimetype = infer_image_mimetype_from_base64(image_doc.image)
        if not mimetype and image_doc.image_path:
            mimetype = infer_image_mimetype_from_file_path(image_doc.image_path)
        # Only fall back to the default when nothing was detected; a mimetype
        # successfully sniffed from the content must not be overwritten.
        image_doc.image_mimetype = mimetype or "image/jpeg"
    return image_documents
...@@ -12,6 +12,7 @@ from llama_index.core.multi_modal_llms.generic_utils import ( ...@@ -12,6 +12,7 @@ from llama_index.core.multi_modal_llms.generic_utils import (
image_documents_to_base64, image_documents_to_base64,
infer_image_mimetype_from_base64, infer_image_mimetype_from_base64,
infer_image_mimetype_from_file_path, infer_image_mimetype_from_file_path,
set_base64_and_mimetype_for_image_docs,
) )
# Expected values # Expected values
...@@ -102,6 +103,7 @@ def test_complete_workflow(): ...@@ -102,6 +103,7 @@ def test_complete_workflow():
def test_infer_image_mimetype_from_base64(): def test_infer_image_mimetype_from_base64():
"""Test inferring image mimetype from base64-encoded data."""
# Create a minimal valid PNG in base64 # Create a minimal valid PNG in base64
base64_png = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAACklEQVR4nGMAAQAABQABDQottAAAAABJRU5ErkJggg==" base64_png = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAACklEQVR4nGMAAQAABQABDQottAAAAABJRU5ErkJggg=="
...@@ -109,11 +111,12 @@ def test_infer_image_mimetype_from_base64(): ...@@ -109,11 +111,12 @@ def test_infer_image_mimetype_from_base64():
assert result == "image/png" assert result == "image/png"
# Valid, meaningless base64 # Valid, meaningless base64
result = infer_image_mimetype_from_base64("lEQVR4nGMAAQAABQABDQ") result = infer_image_mimetype_from_base64(EXP_BASE64)
assert result is None assert result is None
def test_infer_image_mimetype_from_file_path(): def test_infer_image_mimetype_from_file_path():
"""Test inferring image mimetype from file extensions."""
# JPG/JPEG # JPG/JPEG
assert infer_image_mimetype_from_file_path("image.jpg") == "image/jpeg" assert infer_image_mimetype_from_file_path("image.jpg") == "image/jpeg"
assert infer_image_mimetype_from_file_path("image.jpeg") == "image/jpeg" assert infer_image_mimetype_from_file_path("image.jpeg") == "image/jpeg"
...@@ -127,5 +130,24 @@ def test_infer_image_mimetype_from_file_path(): ...@@ -127,5 +130,24 @@ def test_infer_image_mimetype_from_file_path():
# WEBP # WEBP
assert infer_image_mimetype_from_file_path("image.webp") == "image/webp" assert infer_image_mimetype_from_file_path("image.webp") == "image/webp"
# Catch-all default # Catch-all defaults
assert infer_image_mimetype_from_file_path("image.asf32") == "image/jpeg" assert infer_image_mimetype_from_file_path("image.asf32") == "image/jpeg"
assert infer_image_mimetype_from_file_path("") == "image/jpeg"
def test_set_base64_and_mimetype_for_image_docs():
    """Check that base64 data and mimetypes are populated on ImageDocuments."""
    # One document carries inline base64 data, the other only a file path.
    inline_doc = ImageDocument(image=EXP_BASE64)
    path_doc = ImageDocument(image_path="test.asdf")
    docs = [inline_doc, path_doc]

    # Stub out both network and file access so no real I/O happens.
    patched_open = patch("builtins.open", mock_open(read_data=EXP_BINARY))
    with patch("requests.get") as fake_get, patched_open:
        fake_get.return_value.content = EXP_BINARY
        processed = set_base64_and_mimetype_for_image_docs(docs)

    assert len(processed) == 2
    first_doc = processed[0]
    second_doc = processed[1]
    assert first_doc.image == EXP_BASE64
    assert first_doc.image_mimetype == "image/jpeg"
    assert second_doc.image_mimetype == "image/jpeg"
"""
Utility functions for the Anthropic SDK LLM integration.
"""
from typing import Dict, Sequence, Tuple from typing import Dict, Sequence, Tuple
from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole
...@@ -13,33 +17,56 @@ from anthropic.types.beta.prompt_caching import ( ...@@ -13,33 +17,56 @@ from anthropic.types.beta.prompt_caching import (
HUMAN_PREFIX = "\n\nHuman:" HUMAN_PREFIX = "\n\nHuman:"
ASSISTANT_PREFIX = "\n\nAssistant:" ASSISTANT_PREFIX = "\n\nAssistant:"
CLAUDE_MODELS: Dict[str, int] = { # AWS Bedrock Anthropic identifiers
BEDROCK_INFERENCE_PROFILE_CLAUDE_MODELS: Dict[str, int] = {
"anthropic.claude-3-haiku-20240307-v1:0": 200000,
"anthropic.claude-3-sonnet-20240229-v1:0": 200000,
"anthropic.claude-3-opus-20240229-v1:0": 200000,
"anthropic.claude-3-5-sonnet-20240620-v1:0": 200000,
"anthropic.claude-3-5-sonnet-20241022-v2:0": 200000,
"anthropic.claude-3-5-haiku-20241022-v1:0": 200000,
}
BEDROCK_CLAUDE_MODELS: Dict[str, int] = {
"anthropic.claude-instant-v1": 100000,
"anthropic.claude-v2": 100000,
"anthropic.claude-v2:1": 200000,
}
# GCP Vertex AI Anthropic identifiers
VERTEX_CLAUDE_MODELS: Dict[str, int] = {
"claude-3-opus@20240229": 200000,
"claude-3-sonnet@20240229": 200000,
"claude-3-haiku@20240307": 200000,
"claude-3-5-sonnet@20240620": 200000,
"claude-3-5-sonnet-v2@20241022": 200000,
"claude-3-5-haiku@20241022": 200000,
}
# Anthropic API/SDK identifiers
ANTHROPIC_MODELS: Dict[str, int] = {
"claude-instant-1": 100000, "claude-instant-1": 100000,
"claude-instant-1.2": 100000, "claude-instant-1.2": 100000,
"claude-2": 100000, "claude-2": 100000,
"claude-2.0": 100000, "claude-2.0": 100000,
"claude-2.1": 200000, "claude-2.1": 200000,
"claude-3-opus-latest": 180000, "claude-3-opus-latest": 200000,
"claude-3-opus-20240229": 180000, "claude-3-opus-20240229": 200000,
"claude-3-opus@20240229": 180000, # Alternate name for Vertex AI "claude-3-sonnet-latest": 200000,
"anthropic.claude-3-opus-20240229-v1:0": 180000, # Alternate name for Bedrock "claude-3-sonnet-20240229": 200000,
"claude-3-sonnet-latest": 180000, "claude-3-haiku-latest": 200000,
"claude-3-sonnet-20240229": 180000, "claude-3-haiku-20240307": 200000,
"claude-3-sonnet@20240229": 180000, # Alternate name for Vertex AI "claude-3-5-sonnet-latest": 200000,
"anthropic.claude-3-sonnet-20240229-v1:0": 180000, # Alternate name for Bedrock "claude-3-5-sonnet-20240620": 200000,
"claude-3-haiku-latest": 180000, "claude-3-5-sonnet-20241022": 200000,
"claude-3-haiku-20240307": 180000, "claude-3-5-haiku-20241022": 200000,
"claude-3-haiku@20240307": 180000, # Alternate name for Vertex AI }
"anthropic.claude-3-haiku-20240307-v1:0": 180000, # Alternate name for Bedrock
"claude-3-5-sonnet-latest": 180000, # All provider Anthropic identifiers
"claude-3-5-sonnet-20240620": 180000, CLAUDE_MODELS: Dict[str, int] = {
"claude-3-5-sonnet-20241022": 180000, **BEDROCK_INFERENCE_PROFILE_CLAUDE_MODELS,
"claude-3-5-sonnet-v2@20241022": 180000, # Alternate name for Vertex AI **BEDROCK_CLAUDE_MODELS,
"anthropic.claude-3-5-sonnet-20241022-v2:0": 180000, # Alternate name for Bedrock **VERTEX_CLAUDE_MODELS,
"claude-3-5-sonnet@20240620": 180000, # Alternate name for Vertex AI **ANTHROPIC_MODELS,
"claude-3-5-haiku-20241022": 180000,
"claude-3-5-haiku@20241022": 180000, # Alternate name for Vertex AI
"anthropic.claude-3-5-haiku-20241022-v1:0": 180000, # Alternate name for Bedrock
} }
...@@ -48,6 +75,19 @@ def is_function_calling_model(modelname: str) -> bool: ...@@ -48,6 +75,19 @@ def is_function_calling_model(modelname: str) -> bool:
def anthropic_modelname_to_contextsize(modelname: str) -> int: def anthropic_modelname_to_contextsize(modelname: str) -> int:
"""Get the context size for an Anthropic model.
Args:
modelname (str): Anthropic model name.
Returns:
int: Context size for the specific model.
"""
for model, context_size in BEDROCK_INFERENCE_PROFILE_CLAUDE_MODELS.items():
# Only US & EU inference profiles are currently supported by AWS
CLAUDE_MODELS[f"us.{model}"] = context_size
CLAUDE_MODELS[f"eu.{model}"] = context_size
if modelname not in CLAUDE_MODELS: if modelname not in CLAUDE_MODELS:
raise ValueError( raise ValueError(
f"Unknown model: {modelname}. Please provide a valid Anthropic model name." f"Unknown model: {modelname}. Please provide a valid Anthropic model name."
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment