From 578e9697fe5839bb5b8c002c445abd4596ccaa49 Mon Sep 17 00:00:00 2001 From: Bryce Freshcorn <26725654+brycecf@users.noreply.github.com> Date: Fri, 8 Nov 2024 18:32:18 -0500 Subject: [PATCH] Added docstrings and unit tests for `core.multimodal` (#16872) --- .../core/multi_modal_llms/generic_utils.py | 107 ++++++++++++-- llama-index-core/poetry.lock | 15 +- llama-index-core/pyproject.toml | 1 + llama-index-core/tests/multi_modal_llms/BUILD | 3 + .../test_base_multi_modal_llm_metadata.py | 30 ++++ .../multi_modal_llms/test_generic_utils.py | 131 ++++++++++++++++++ 6 files changed, 270 insertions(+), 17 deletions(-) create mode 100644 llama-index-core/tests/multi_modal_llms/BUILD create mode 100644 llama-index-core/tests/multi_modal_llms/test_base_multi_modal_llm_metadata.py create mode 100644 llama-index-core/tests/multi_modal_llms/test_generic_utils.py diff --git a/llama-index-core/llama_index/core/multi_modal_llms/generic_utils.py b/llama-index-core/llama_index/core/multi_modal_llms/generic_utils.py index 9f1d5b843b..8cba775c0e 100644 --- a/llama-index-core/llama_index/core/multi_modal_llms/generic_utils.py +++ b/llama-index-core/llama_index/core/multi_modal_llms/generic_utils.py @@ -1,6 +1,7 @@ import base64 +import filetype import logging -from typing import List, Sequence +from typing import List, Sequence, Optional import requests @@ -10,42 +11,118 @@ logger = logging.getLogger(__name__) def load_image_urls(image_urls: List[str]) -> List[ImageDocument]: - # load remote image urls into image documents - image_documents = [] - for i in range(len(image_urls)): - new_image_document = ImageDocument(image_url=image_urls[i]) - image_documents.append(new_image_document) - return image_documents + """Convert a list of image URLs into ImageDocument objects. + + Args: + image_urls (List[str]): List of strings containing valid image URLs. + + Returns: + List[ImageDocument]: List of ImageDocument objects. 
+ """ + return [ImageDocument(image_url=url) for url in image_urls] -# Function to encode the image to base64 content def encode_image(image_path: str) -> str: + """Create base64 representation of an image. + + Args: + image_path (str): Path to the image file + + Returns: + str: Base64 encoded string of the image + + Raises: + FileNotFoundError: If the `image_path` doesn't exist. + IOError: If there's an error reading the file. + """ with open(image_path, "rb") as image_file: return base64.b64encode(image_file.read()).decode("utf-8") -# Supporting Ollama like Multi-Modal images base64 encoding def image_documents_to_base64( image_documents: Sequence[ImageDocument], ) -> List[str]: + """Convert ImageDocument objects to base64-encoded strings. + + Args: + image_documents (Sequence[ImageDocument]): Sequence of + ImageDocument objects + + Returns: + List[str]: List of base64-encoded image strings + """ image_encodings = [] - # encode image documents to base64 + + # Encode image documents to base64 for image_document in image_documents: - if image_document.image: + if image_document.image: # This field is already base64-encoded image_encodings.append(image_document.image) - elif image_document.image_path: + elif ( + image_document.image_path + ): # This field is a path to the image, which is then encoded. image_encodings.append(encode_image(image_document.image_path)) elif ( "file_path" in image_document.metadata and image_document.metadata["file_path"] != "" - ): + ): # Alternative path to the image, which is then encoded. image_encodings.append(encode_image(image_document.metadata["file_path"])) - elif image_document.image_url: + elif image_document.image_url: # Image can also be pulled from the URL. 
response = requests.get(image_document.image_url) try: image_encodings.append( base64.b64encode(response.content).decode("utf-8") ) except Exception as e: - logger.warning(f"Cannot encode the image url-> {e}") + logger.warning(f"Cannot encode the image pulled from URL -> {e}") return image_encodings + + +def infer_image_mimetype_from_file_path(image_file_path: str) -> str: + """Infer the MIME type of an image file based on its file extension. + + Currently only supports the following types of images: + * image/jpeg + * image/png + * image/gif + * image/webp + + Args: + image_file_path (str): Path to the image file. + + Returns: + str: MIME type of the image: image/jpeg, image/png, image/gif, or image/webp. + """ + # Get the file extension + file_extension = image_file_path.split(".")[-1].lower() + + # Map file extensions to mimetypes + if file_extension == "jpg" or file_extension == "jpeg": + return "image/jpeg" + elif file_extension == "png": + return "image/png" + elif file_extension == "gif": + return "image/gif" + elif file_extension == "webp": + return "image/webp" + + # If the file extension is not recognized + return "image/jpeg" + + +def infer_image_mimetype_from_base64(base64_string: str) -> Optional[str]: + """Infer the MIME type of an image from the base64 encoding. + + Args: + base64_string (str): Base64-encoded string of the image. + + Returns: + Optional[str]: MIME type of the image as detected by `filetype`, or None if it cannot be determined. 
+ """ + # Decode the base64 string + decoded_data = base64.b64decode(base64_string) + + # Use filetype to guess the MIME type + kind = filetype.guess(decoded_data) + + # Return the MIME type if detected, otherwise return None + return kind.mime if kind is not None else None diff --git a/llama-index-core/poetry.lock b/llama-index-core/poetry.lock index 277ca7c3ea..5342731dfb 100644 --- a/llama-index-core/poetry.lock +++ b/llama-index-core/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -1239,6 +1239,17 @@ docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1 testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] typing = ["typing-extensions (>=4.8)"] +[[package]] +name = "filetype" +version = "1.2.0" +description = "Infer file type and MIME type of any file/buffer. No external dependencies." 
+optional = false +python-versions = "*" +files = [ + {file = "filetype-1.2.0-py2.py3-none-any.whl", hash = "sha256:7ce71b6880181241cf7ac8697a2f1eb6a8bd9b429f7ad6d27b8db9ba5f1c2d25"}, + {file = "filetype-1.2.0.tar.gz", hash = "sha256:66b56cd6474bf41d8c54660347d37afcc3f7d1970648de365c102ef77548aadb"}, +] + [[package]] name = "fqdn" version = "1.5.1" @@ -5775,4 +5786,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "c8f7cb4bf41616a6010eaf8acb80fc7b339cf8a3e2b12a7da15880b83a05e6ff" +content-hash = "a53c1909592a3c732f0958bfdf25e38e611cec90ea2a64e16ed8becd570c2bca" diff --git a/llama-index-core/pyproject.toml b/llama-index-core/pyproject.toml index b329818e88..ca090b1a3d 100644 --- a/llama-index-core/pyproject.toml +++ b/llama-index-core/pyproject.toml @@ -71,6 +71,7 @@ pillow = ">=9.0.0" PyYAML = ">=6.0.1" wrapt = "*" pydantic = ">=2.7.0,<3.0.0" +filetype = "^1.2.0" # Used for multi-modal MIME utils [tool.poetry.group.dev.dependencies] black = {extras = ["jupyter"], version = ">=23.7.0,<=24.3.0"} diff --git a/llama-index-core/tests/multi_modal_llms/BUILD b/llama-index-core/tests/multi_modal_llms/BUILD new file mode 100644 index 0000000000..57341b1358 --- /dev/null +++ b/llama-index-core/tests/multi_modal_llms/BUILD @@ -0,0 +1,3 @@ +python_tests( + name="tests", +) diff --git a/llama-index-core/tests/multi_modal_llms/test_base_multi_modal_llm_metadata.py b/llama-index-core/tests/multi_modal_llms/test_base_multi_modal_llm_metadata.py new file mode 100644 index 0000000000..7850c8d50d --- /dev/null +++ b/llama-index-core/tests/multi_modal_llms/test_base_multi_modal_llm_metadata.py @@ -0,0 +1,30 @@ +"""Unit tests for `llama_index.core.multi_modal_llms.MultiModalLLMMetadata`.""" + +from llama_index.core.multi_modal_llms import MultiModalLLMMetadata + + +class TestMultiModalLLMMetadata: + def test_default_values(self): + metadata = MultiModalLLMMetadata() + assert metadata.model_name == "unknown" + assert 
metadata.is_chat_model is False + assert metadata.is_function_calling_model is False + assert metadata.context_window is not None + assert metadata.num_output is not None + assert metadata.num_input_files is not None + + def test_custom_values(self): + metadata = MultiModalLLMMetadata( + model_name="test-model", + context_window=2048, + num_output=512, + num_input_files=5, + is_function_calling_model=True, + is_chat_model=True, + ) + assert metadata.model_name == "test-model" + assert metadata.context_window == 2048 + assert metadata.num_output == 512 + assert metadata.num_input_files == 5 + assert metadata.is_function_calling_model is True + assert metadata.is_chat_model is True diff --git a/llama-index-core/tests/multi_modal_llms/test_generic_utils.py b/llama-index-core/tests/multi_modal_llms/test_generic_utils.py new file mode 100644 index 0000000000..20759c9049 --- /dev/null +++ b/llama-index-core/tests/multi_modal_llms/test_generic_utils.py @@ -0,0 +1,131 @@ +"""Unit tests for `llama_index.core.multi_modal_llms.generic_utils`.""" + +import pytest +import base64 +from unittest.mock import mock_open, patch, MagicMock + +from llama_index.core.schema import ImageDocument + +from llama_index.core.multi_modal_llms.generic_utils import ( + load_image_urls, + encode_image, + image_documents_to_base64, + infer_image_mimetype_from_base64, + infer_image_mimetype_from_file_path, +) + +# Expected values +EXP_IMAGE_URLS = ["http://example.com/image1.jpg"] +EXP_BASE64 = "SGVsbG8gV29ybGQ=" # "Hello World" in base64 +EXP_BINARY = b"Hello World" + + +@pytest.fixture() +def mock_successful_response(): + mock_response = MagicMock() + mock_response.content = EXP_BINARY + return mock_response + + +def test_load_image_urls(): + """Test loading image URLs into ImageDocument objects.""" + result = load_image_urls(EXP_IMAGE_URLS) + + assert len(result) == len(EXP_IMAGE_URLS) + assert all(isinstance(doc, ImageDocument) for doc in result) + assert all(doc.image_url == url for doc, url in 
zip(result, EXP_IMAGE_URLS)) + + +def test_load_image_urls_with_empty_list(): + """Test loading an empty list of URLs.""" + result = load_image_urls([]) + assert result == [] + + +def test_encode_image(): + """Test successful image encoding.""" + with patch("builtins.open", mock_open(read_data=EXP_BINARY)): + result = encode_image("fake_image.jpg") + + assert result == EXP_BASE64 + + +def test_image_documents_to_base64_multiple_sources(): + """Test converting multiple ImageDocuments with different source types.""" + documents = [ + ImageDocument(image=EXP_BASE64), + ImageDocument(image_path="test.jpg"), + ImageDocument(metadata={"file_path": "test.jpg"}), + ImageDocument(image_url=EXP_IMAGE_URLS[0]), + ] + with patch("requests.get") as mock_get: + mock_get.return_value.content = EXP_BINARY + with patch("builtins.open", mock_open(read_data=EXP_BINARY)): + result = image_documents_to_base64(documents) + + assert len(result) == 4 + assert all(encoding == EXP_BASE64 for encoding in result) + + +def test_image_documents_to_base64_failed_url(): + """Test handling of failed URL requests.""" + document = ImageDocument(image_url=EXP_IMAGE_URLS[0]) + with patch("requests.get"): + result = image_documents_to_base64([document]) + + assert result == [] + + +def test_image_documents_to_base64_empty_sequence(): + """Test handling of empty sequence of documents.""" + result = image_documents_to_base64([]) + assert result == [] + + +def test_image_documents_to_base64_invalid_metadata(): + """Test handling of document with invalid metadata path.""" + document = ImageDocument(metadata={"file_path": ""}) + result = image_documents_to_base64([document]) + assert result == [] + + +def test_complete_workflow(): + """Test the complete workflow from URL to base64 encoding.""" + documents = load_image_urls(EXP_IMAGE_URLS) + with patch("requests.get") as mock_get: + mock_get.return_value.content = EXP_BINARY + result = image_documents_to_base64(documents) + + assert len(result) == 
len(EXP_IMAGE_URLS) + assert isinstance(result[0], str) + assert base64.b64decode(result[0]) == EXP_BINARY + + +def test_infer_image_mimetype_from_base64(): + # Create a minimal valid PNG in base64 + base64_png = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAACklEQVR4nGMAAQAABQABDQottAAAAABJRU5ErkJggg==" + + result = infer_image_mimetype_from_base64(base64_png) + assert result == "image/png" + + # Valid, meaningless base64 + result = infer_image_mimetype_from_base64("lEQVR4nGMAAQAABQABDQ") + assert result is None + + +def test_infer_image_mimetype_from_file_path(): + # JPG/JPEG + assert infer_image_mimetype_from_file_path("image.jpg") == "image/jpeg" + assert infer_image_mimetype_from_file_path("image.jpeg") == "image/jpeg" + + # PNG + assert infer_image_mimetype_from_file_path("image.png") == "image/png" + + # GIF + assert infer_image_mimetype_from_file_path("image.gif") == "image/gif" + + # WEBP + assert infer_image_mimetype_from_file_path("image.webp") == "image/webp" + + # Catch-all default + assert infer_image_mimetype_from_file_path("image.asf32") == "image/jpeg" -- GitLab