From 578e9697fe5839bb5b8c002c445abd4596ccaa49 Mon Sep 17 00:00:00 2001
From: Bryce Freshcorn <26725654+brycecf@users.noreply.github.com>
Date: Fri, 8 Nov 2024 18:32:18 -0500
Subject: [PATCH] Added docstrings and unit tests for `core.multimodal`
 (#16872)

---
 .../core/multi_modal_llms/generic_utils.py    | 107 ++++++++++++--
 llama-index-core/poetry.lock                  |  15 +-
 llama-index-core/pyproject.toml               |   1 +
 llama-index-core/tests/multi_modal_llms/BUILD |   3 +
 .../test_base_multi_modal_llm_metadata.py     |  30 ++++
 .../multi_modal_llms/test_generic_utils.py    | 131 ++++++++++++++++++
 6 files changed, 270 insertions(+), 17 deletions(-)
 create mode 100644 llama-index-core/tests/multi_modal_llms/BUILD
 create mode 100644 llama-index-core/tests/multi_modal_llms/test_base_multi_modal_llm_metadata.py
 create mode 100644 llama-index-core/tests/multi_modal_llms/test_generic_utils.py

diff --git a/llama-index-core/llama_index/core/multi_modal_llms/generic_utils.py b/llama-index-core/llama_index/core/multi_modal_llms/generic_utils.py
index 9f1d5b843b..8cba775c0e 100644
--- a/llama-index-core/llama_index/core/multi_modal_llms/generic_utils.py
+++ b/llama-index-core/llama_index/core/multi_modal_llms/generic_utils.py
@@ -1,6 +1,7 @@
 import base64
+import filetype
 import logging
-from typing import List, Sequence
+from typing import List, Sequence, Optional
 
 import requests
 
@@ -10,42 +11,118 @@ logger = logging.getLogger(__name__)
 
 
 def load_image_urls(image_urls: List[str]) -> List[ImageDocument]:
-    # load remote image urls into image documents
-    image_documents = []
-    for i in range(len(image_urls)):
-        new_image_document = ImageDocument(image_url=image_urls[i])
-        image_documents.append(new_image_document)
-    return image_documents
+    """Convert a list of image URLs into ImageDocument objects.
+
+    Args:
+        image_urls (List[str]): List of strings containing valid image URLs.
+
+    Returns:
+        List[ImageDocument]: List of ImageDocument objects.
+    """
+    return [ImageDocument(image_url=url) for url in image_urls]
 
 
-# Function to encode the image to base64 content
 def encode_image(image_path: str) -> str:
+    """Create base64 representation of an image.
+
+    Args:
+        image_path (str): Path to the image file
+
+    Returns:
+        str: Base64 encoded string of the image
+
+    Raises:
+        FileNotFoundError: If the `image_path` doesn't exist.
+        IOError: If there's an error reading the file.
+    """
     with open(image_path, "rb") as image_file:
         return base64.b64encode(image_file.read()).decode("utf-8")
 
 
-# Supporting Ollama like Multi-Modal images base64 encoding
 def image_documents_to_base64(
     image_documents: Sequence[ImageDocument],
 ) -> List[str]:
+    """Convert ImageDocument objects to base64-encoded strings.
+
+    Args:
+        image_documents (Sequence[ImageDocument]: Sequence of
+            ImageDocument objects
+
+    Returns:
+        List[str]: List of base64-encoded image strings
+    """
     image_encodings = []
-    # encode image documents to base64
+
+    # Encode image documents to base64
     for image_document in image_documents:
-        if image_document.image:
+        if image_document.image:  # This field is already base64-encoded
             image_encodings.append(image_document.image)
-        elif image_document.image_path:
+        elif (
+            image_document.image_path
+        ):  # This field is a path to the image, which is then encoded.
             image_encodings.append(encode_image(image_document.image_path))
         elif (
             "file_path" in image_document.metadata
             and image_document.metadata["file_path"] != ""
-        ):
+        ):  # Alternative path to the image, which is then encoded.
             image_encodings.append(encode_image(image_document.metadata["file_path"]))
-        elif image_document.image_url:
+        elif image_document.image_url:  # Image can also be pulled from the URL.
             response = requests.get(image_document.image_url)
             try:
                 image_encodings.append(
                     base64.b64encode(response.content).decode("utf-8")
                 )
             except Exception as e:
-                logger.warning(f"Cannot encode the image url-> {e}")
+                logger.warning(f"Cannot encode the image pulled from URL -> {e}")
     return image_encodings
+
+
+def infer_image_mimetype_from_file_path(image_file_path: str) -> str:
+    """Infer the MIME of an image file based on its file extension.
+
+    Currently only supports the following types of images:
+        * image/jpeg
+        * image/png
+        * image/gif
+        * image/webp
+
+    Args:
+        image_file_path (str): Path to the image file.
+
+    Returns:
+        str: MIME type of the image: image/jpeg, image/png, image/gif, or image/webp.
+    """
+    # Get the file extension
+    file_extension = image_file_path.split(".")[-1].lower()
+
+    # Map file extensions to mimetypes
+    if file_extension == "jpg" or file_extension == "jpeg":
+        return "image/jpeg"
+    elif file_extension == "png":
+        return "image/png"
+    elif file_extension == "gif":
+        return "image/gif"
+    elif file_extension == "webp":
+        return "image/webp"
+
+    # If the file extension is not recognized
+    return "image/jpeg"
+
+
+def infer_image_mimetype_from_base64(base64_string: str) -> Optional[str]:
+    """Infer the MIME of an image from the base64 encoding.
+
+    Args:
+        base64_string (str): Base64-encoded string of the image.
+
+    Returns:
+        Optional[str]: MIME type of the image: image/jpeg, image/png, image/gif, or image/webp.
+    """
+    # Decode the base64 string
+    decoded_data = base64.b64decode(base64_string)
+
+    # Use filetype to guess the MIME type
+    kind = filetype.guess(decoded_data)
+
+    # Return the MIME type if detected, otherwise return None
+    return kind.mime if kind is not None else None
diff --git a/llama-index-core/poetry.lock b/llama-index-core/poetry.lock
index 277ca7c3ea..5342731dfb 100644
--- a/llama-index-core/poetry.lock
+++ b/llama-index-core/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
 
 [[package]]
 name = "aiohappyeyeballs"
@@ -1239,6 +1239,17 @@ docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1
 testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"]
 typing = ["typing-extensions (>=4.8)"]
 
+[[package]]
+name = "filetype"
+version = "1.2.0"
+description = "Infer file type and MIME type of any file/buffer. No external dependencies."
+optional = false
+python-versions = "*"
+files = [
+    {file = "filetype-1.2.0-py2.py3-none-any.whl", hash = "sha256:7ce71b6880181241cf7ac8697a2f1eb6a8bd9b429f7ad6d27b8db9ba5f1c2d25"},
+    {file = "filetype-1.2.0.tar.gz", hash = "sha256:66b56cd6474bf41d8c54660347d37afcc3f7d1970648de365c102ef77548aadb"},
+]
+
 [[package]]
 name = "fqdn"
 version = "1.5.1"
@@ -5775,4 +5786,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "c8f7cb4bf41616a6010eaf8acb80fc7b339cf8a3e2b12a7da15880b83a05e6ff"
+content-hash = "a53c1909592a3c732f0958bfdf25e38e611cec90ea2a64e16ed8becd570c2bca"
diff --git a/llama-index-core/pyproject.toml b/llama-index-core/pyproject.toml
index b329818e88..ca090b1a3d 100644
--- a/llama-index-core/pyproject.toml
+++ b/llama-index-core/pyproject.toml
@@ -71,6 +71,7 @@ pillow = ">=9.0.0"
 PyYAML = ">=6.0.1"
 wrapt = "*"
 pydantic = ">=2.7.0,<3.0.0"
+filetype = "^1.2.0"  # Used for multi-modal MIME utils
 
 [tool.poetry.group.dev.dependencies]
 black = {extras = ["jupyter"], version = ">=23.7.0,<=24.3.0"}
diff --git a/llama-index-core/tests/multi_modal_llms/BUILD b/llama-index-core/tests/multi_modal_llms/BUILD
new file mode 100644
index 0000000000..57341b1358
--- /dev/null
+++ b/llama-index-core/tests/multi_modal_llms/BUILD
@@ -0,0 +1,3 @@
+python_tests(
+    name="tests",
+)
diff --git a/llama-index-core/tests/multi_modal_llms/test_base_multi_modal_llm_metadata.py b/llama-index-core/tests/multi_modal_llms/test_base_multi_modal_llm_metadata.py
new file mode 100644
index 0000000000..7850c8d50d
--- /dev/null
+++ b/llama-index-core/tests/multi_modal_llms/test_base_multi_modal_llm_metadata.py
@@ -0,0 +1,30 @@
+"""Unit tests for `llama_index.core.multi_modal_llms.MultiModalLLMMetadata`."""
+
+from llama_index.core.multi_modal_llms import MultiModalLLMMetadata
+
+
+class TestMultiModalLLMMetadata:
+    def test_default_values(self):
+        metadata = MultiModalLLMMetadata()
+        assert metadata.model_name == "unknown"
+        assert metadata.is_chat_model is False
+        assert metadata.is_function_calling_model is False
+        assert metadata.context_window is not None
+        assert metadata.num_output is not None
+        assert metadata.num_input_files is not None
+
+    def test_custom_values(self):
+        metadata = MultiModalLLMMetadata(
+            model_name="test-model",
+            context_window=2048,
+            num_output=512,
+            num_input_files=5,
+            is_function_calling_model=True,
+            is_chat_model=True,
+        )
+        assert metadata.model_name == "test-model"
+        assert metadata.context_window == 2048
+        assert metadata.num_output == 512
+        assert metadata.num_input_files == 5
+        assert metadata.is_function_calling_model is True
+        assert metadata.is_chat_model is True
diff --git a/llama-index-core/tests/multi_modal_llms/test_generic_utils.py b/llama-index-core/tests/multi_modal_llms/test_generic_utils.py
new file mode 100644
index 0000000000..20759c9049
--- /dev/null
+++ b/llama-index-core/tests/multi_modal_llms/test_generic_utils.py
@@ -0,0 +1,131 @@
+"""Unit tests for `llama_index.core.multi_modal_llms.generic_utils`."""
+
+import pytest
+import base64
+from unittest.mock import mock_open, patch, MagicMock
+
+from llama_index.core.schema import ImageDocument
+
+from llama_index.core.multi_modal_llms.generic_utils import (
+    load_image_urls,
+    encode_image,
+    image_documents_to_base64,
+    infer_image_mimetype_from_base64,
+    infer_image_mimetype_from_file_path,
+)
+
+# Expected values
+EXP_IMAGE_URLS = ["http://example.com/image1.jpg"]
+EXP_BASE64 = "SGVsbG8gV29ybGQ="  # "Hello World" in base64
+EXP_BINARY = b"Hello World"
+
+
+@pytest.fixture()
+def mock_successful_response():
+    mock_response = MagicMock()
+    mock_response.content = EXP_BINARY
+    return mock_response
+
+
+def test_load_image_urls():
+    """Test loading image URLs into ImageDocument objects."""
+    result = load_image_urls(EXP_IMAGE_URLS)
+
+    assert len(result) == len(EXP_IMAGE_URLS)
+    assert all(isinstance(doc, ImageDocument) for doc in result)
+    assert all(doc.image_url == url for doc, url in zip(result, EXP_IMAGE_URLS))
+
+
+def test_load_image_urls_with_empty_list():
+    """Test loading an empty list of URLs."""
+    result = load_image_urls([])
+    assert result == []
+
+
+def test_encode_image():
+    """Test successful image encoding."""
+    with patch("builtins.open", mock_open(read_data=EXP_BINARY)):
+        result = encode_image("fake_image.jpg")
+
+    assert result == EXP_BASE64
+
+
+def test_image_documents_to_base64_multiple_sources():
+    """Test converting multiple ImageDocuments with different source types."""
+    documents = [
+        ImageDocument(image=EXP_BASE64),
+        ImageDocument(image_path="test.jpg"),
+        ImageDocument(metadata={"file_path": "test.jpg"}),
+        ImageDocument(image_url=EXP_IMAGE_URLS[0]),
+    ]
+    with patch("requests.get") as mock_get:
+        mock_get.return_value.content = EXP_BINARY
+        with patch("builtins.open", mock_open(read_data=EXP_BINARY)):
+            result = image_documents_to_base64(documents)
+
+    assert len(result) == 4
+    assert all(encoding == EXP_BASE64 for encoding in result)
+
+
+def test_image_documents_to_base64_failed_url():
+    """Test handling of failed URL requests."""
+    document = ImageDocument(image_url=EXP_IMAGE_URLS[0])
+    with patch("requests.get"):
+        result = image_documents_to_base64([document])
+
+    assert result == []
+
+
+def test_image_documents_to_base64_empty_sequence():
+    """Test handling of empty sequence of documents."""
+    result = image_documents_to_base64([])
+    assert result == []
+
+
+def test_image_documents_to_base64_invalid_metadata():
+    """Test handling of document with invalid metadata path."""
+    document = ImageDocument(metadata={"file_path": ""})
+    result = image_documents_to_base64([document])
+    assert result == []
+
+
+def test_complete_workflow():
+    """Test the complete workflow from URL to base64 encoding."""
+    documents = load_image_urls(EXP_IMAGE_URLS)
+    with patch("requests.get") as mock_get:
+        mock_get.return_value.content = EXP_BINARY
+        result = image_documents_to_base64(documents)
+
+    assert len(result) == len(EXP_IMAGE_URLS)
+    assert isinstance(result[0], str)
+    assert base64.b64decode(result[0]) == EXP_BINARY
+
+
+def test_infer_image_mimetype_from_base64():
+    # Create a minimal valid PNG in base64
+    base64_png = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAACklEQVR4nGMAAQAABQABDQottAAAAABJRU5ErkJggg=="
+
+    result = infer_image_mimetype_from_base64(base64_png)
+    assert result == "image/png"
+
+    # Valid, meaningless base64
+    result = infer_image_mimetype_from_base64("lEQVR4nGMAAQAABQABDQ")
+    assert result is None
+
+
+def test_infer_image_mimetype_from_file_path():
+    # JPG/JPEG
+    assert infer_image_mimetype_from_file_path("image.jpg") == "image/jpeg"
+    assert infer_image_mimetype_from_file_path("image.jpeg") == "image/jpeg"
+
+    # PNG
+    assert infer_image_mimetype_from_file_path("image.png") == "image/png"
+
+    # GIF
+    assert infer_image_mimetype_from_file_path("image.gif") == "image/gif"
+
+    # WEBP
+    assert infer_image_mimetype_from_file_path("image.webp") == "image/webp"
+
+    # Catch-all default
+    assert infer_image_mimetype_from_file_path("image.asf32") == "image/jpeg"
-- 
GitLab