diff --git a/llama-index-integrations/readers/llama-index-readers-file/llama_index/readers/file/image_vision_llm/base.py b/llama-index-integrations/readers/llama-index-readers-file/llama_index/readers/file/image_vision_llm/base.py
index af9e162e4daa0e076c43c2e22224f89202eb64c9..643ab0b0da969fe6fc75b9b3b749ba031a7c8045 100644
--- a/llama-index-integrations/readers/llama-index-readers-file/llama_index/readers/file/image_vision_llm/base.py
+++ b/llama-index-integrations/readers/llama-index-readers-file/llama_index/readers/file/image_vision_llm/base.py
@@ -33,8 +33,15 @@ class ImageVisionLLMReader(BaseReader):
                     "`pip install torch transformers sentencepiece Pillow`"
                 )
 
+            self._torch = torch
+            self._torch_imported = True
+
             device = infer_torch_device()
-            dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+            dtype = (
+                self._torch.float16
+                if self._torch.cuda.is_available()
+                else self._torch.float32
+            )
             processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
             model = Blip2ForConditionalGeneration.from_pretrained(
                 "Salesforce/blip2-opt-2.7b", torch_dtype=dtype
@@ -46,6 +53,9 @@ class ImageVisionLLMReader(BaseReader):
                 "dtype": dtype,
             }
 
+        # Try to import PyTorch in order to run inference efficiently.
+        self._import_torch()
+
         self._parser_config = parser_config
         self._keep_image = keep_image
         self._prompt = prompt
@@ -79,7 +89,16 @@ class ImageVisionLLMReader(BaseReader):
 
         inputs = processor(image, self._prompt, return_tensors="pt").to(device, dtype)
 
-        out = model.generate(**inputs)
+        if self._torch_imported:
+            # Gradients are not needed during inference. If PyTorch is
+            # installed, we can instruct it to not track the gradients.
+            # This reduces GPU memory usage and improves inference efficiency.
+            with self._torch.no_grad():
+                out = model.generate(**inputs)
+        else:
+            # Fallback to less efficient behavior if PyTorch is not installed.
+            out = model.generate(**inputs)
+
         text_str = processor.decode(out[0], skip_special_tokens=True)
 
         return [
@@ -90,3 +109,14 @@ class ImageVisionLLMReader(BaseReader):
                 metadata=extra_info or {},
             )
         ]
+
+    def _import_torch(self) -> None:
+        self._torch = None
+
+        try:
+            import torch
+
+            self._torch = torch
+            self._torch_imported = True
+        except ImportError:
+            self._torch_imported = False
diff --git a/llama-index-integrations/readers/llama-index-readers-file/pyproject.toml b/llama-index-integrations/readers/llama-index-readers-file/pyproject.toml
index c1856ea73b5631c0cea4f0fe05fa3fbe3d90329a..f044380cb44c7919a2281a1846334df1da98c83c 100644
--- a/llama-index-integrations/readers/llama-index-readers-file/pyproject.toml
+++ b/llama-index-integrations/readers/llama-index-readers-file/pyproject.toml
@@ -51,7 +51,7 @@ license = "MIT"
 maintainers = ["FarisHijazi", "Haowjy", "ephe-meral", "hursh-desai", "iamarunbrahma", "jon-chuang", "mmaatouk", "ravi03071991", "sangwongenip", "thejessezhang"]
 name = "llama-index-readers-file"
 readme = "README.md"
-version = "0.4.5"
+version = "0.4.6"
 
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
diff --git a/llama-index-integrations/readers/llama-index-readers-file/tests/test_image_vision_llm.py b/llama-index-integrations/readers/llama-index-readers-file/tests/test_image_vision_llm.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2d7ac00ef54e96d836ea7f37767eee3403f4522
--- /dev/null
+++ b/llama-index-integrations/readers/llama-index-readers-file/tests/test_image_vision_llm.py
@@ -0,0 +1,275 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+
+import builtins
+
+import pytest
+from unittest import mock
+
+from typing import Dict, List
+from types import ModuleType
+
+try:
+    import torch
+    import numpy as np
+    from PIL import Image
+except ImportError:
+    torch = None
+    np = None
+    Image = None
+
+from llama_index.readers.file.image_vision_llm.base import ImageVisionLLMReader
+
+
+# Fixture to create a temporary 16x16 pixel image file
+@pytest.fixture()
+def test_16x16_png_image_file(tmp_path) -> str:
+    # Create a checkerboard pattern (alternating 0 and 255)
+    img_array = np.zeros((16, 16), dtype=np.uint8)
+    img_array[::2, ::2] = 255  # Set even rows and columns to white
+    img_array[1::2, 1::2] = 255  # Set odd rows and columns to white
+
+    # Convert numpy array to PIL Image
+    img = Image.fromarray(img_array, mode="L")  # 'L' mode is for grayscale
+
+    file_path = tmp_path / "test_image_16x16.png"
+    img.save(file_path)
+
+    return file_path
+
+
+class TokenizerFake:
+    """
+    This double fakes the `Blip2Processor` tokenizer object so as to
+    avoid having to instantiate the actual tokenizer for these tests.
+    """
+
+    def __call__(self, img, prompt, return_tensors) -> TokenizerFake:
+        """This is just a stub for the purposes of the test,
+        so we just return the instance itself.
+        """
+        return self
+
+    def to(self, device, dtype) -> Dict[str, list]:
+        """
+        The output is the tokenized version of the prompt
+        "Question: describe what you see in this image. Answer:"
+        It should be of type `transformers.image_processing_base.BatchFeature`
+        with `torch.Tensor` typed values for `"input_ids"`, `"attention_mask"`,
+        and `"pixel_values"` keys. However, we will fake them as lists of
+        integers where values are needed (`None` elsewhere) in order
+        to not require `torch` or `numpy` imports.
+        """
+        return {
+            "input_ids": [
+                [2, 45641, 35, 6190, 99, 47, 192, 11, 42, 2274, 4, 31652, 35]
+            ],
+            "attention_mask": [[None]],
+            "pixel_values": [[[[None]]]],
+        }
+
+    def decode(
+        self, tokens: Dict[str, List[int]], skip_special_tokens: bool = True
+    ) -> str:
+        """
+        We return the known expected decoded response for the
+        `test_16x16_png_image_file` fixture and the default prompt
+        of the `ImageVisionLLMReader` class.
+        """
+        return "Question: describe what you see in this image. Answer: a black and white checkered pattern"
+
+
+class ModelFake:
+    """
+    This double fakes the `Blip2ForConditionalGeneration` model object
+    in order to avoid having to download checkpoints for these tests.
+    """
+
+    def generate(self, **kwargs) -> list:
+        """
+        The output is the tokenized version of the prompt
+        "Question: describe what you see in this image. \
+        Answer: a black and white checkered pattern"
+        It should be of type `torch.Tensor`. However, we will fake it as a
+        list of integers in order to not require `torch` or `numpy` imports.
+        """
+        return [
+            [
+                2,
+                45641,
+                35,
+                6190,
+                99,
+                47,
+                192,
+                11,
+                42,
+                2274,
+                4,
+                31652,
+                35,
+                10,
+                909,
+                8,
+                1104,
+                5851,
+                438,
+                20093,
+                6184,
+                50118,
+            ]
+        ]
+
+    def to(self, device) -> None:
+        """This is just a dummy method for the purposes of the test (it
+        needs to be defined, but is not used). Hence, we return nothing.
+        """
+
+
+@contextmanager
+def _get_custom_import(torch_installed: bool):
+    """
+    Simulate absence of PyTorch installation depending on the input flag.
+
+    Args:
+        torch_installed (bool): Flag indicating whether or not PyTorch is installed.
+
+    Returns:
+        Context manager that patches `builtins.__import__` according to the flag.
+    """
+    # Store the original __import__ function
+    original_import = builtins.__import__
+
+    def _custom_import(module_name: str, *args, **kwargs) -> ModuleType:
+        """
+        If `torch_installed` is False, act as if PyTorch is not installed.
+        """
+        if module_name == "torch" and not torch_installed:
+            raise ImportError("No module named 'torch'")
+
+        return original_import(module_name, *args, **kwargs)
+
+    try:
+        # Replace the built-in __import__ function
+        builtins.__import__ = _custom_import
+
+        yield
+    except Exception:
+        # Restore the original import function
+        builtins.__import__ = original_import
+
+        raise
+    finally:
+        # Restore the original import function
+        builtins.__import__ = original_import
+
+
+@pytest.mark.skipif(
+    Image is None,
+    reason="PIL not installed",
+)
+@pytest.mark.parametrize(
+    "torch_installed",
+    [
+        pytest.param(
+            False,
+            id="torch_not_installed",
+        ),
+        pytest.param(
+            True,
+            id="torch_installed",
+        ),
+    ],
+)
+def test_image_vision_llm_reader_load_data_with_parser_config(
+    torch_installed: bool, test_16x16_png_image_file: str
+):
+    """
+    We use doubles (mocks and fakes) for the model and the tokenizer objects
+    in order to avoid having to download checkpoints as part of tests, while
+    still covering all essential `ImageVisionLLMReader` class functionality.
+    """
+    with mock.patch(
+        "transformers.Blip2ForConditionalGeneration.from_pretrained",
+        return_value=ModelFake(),
+    ) as model, mock.patch(
+        "transformers.Blip2Processor.from_pretrained",
+        return_value=TokenizerFake(),
+    ) as processor:
+        parser_config = {
+            "processor": processor(),
+            "model": model(),
+            "device": "auto",  # not used (placeholder)
+            "dtype": float,  # not used (placeholder)
+        }
+
+        if torch_installed:
+            image_vision_llm_reader = ImageVisionLLMReader(
+                parser_config=parser_config, keep_image=True
+            )
+            assert image_vision_llm_reader._torch_imported
+        else:
+            with _get_custom_import(torch_installed=False):
+                image_vision_llm_reader = ImageVisionLLMReader(
+                    parser_config=parser_config, keep_image=True
+                )
+                assert not image_vision_llm_reader._torch_imported
+
+        result = image_vision_llm_reader.load_data(file=test_16x16_png_image_file)[0]
+        assert (
+            result.text
+            == "Question: describe what you see in this image. Answer: a black and white checkered pattern"
+        )
+
+
+@pytest.mark.skipif(
+    Image is None,
+    reason="PIL not installed",
+)
+@pytest.mark.parametrize(
+    "torch_installed",
+    [
+        pytest.param(
+            False,
+            id="torch_not_installed",
+        ),
+        pytest.param(
+            True,
+            id="torch_installed",
+        ),
+    ],
+)
+def test_image_vision_llm_reader_load_data_wo_parser_config(
+    torch_installed: bool, test_16x16_png_image_file: str
+):
+    """
+    We use doubles (mocks and fakes) for the model and the tokenizer objects
+    in order to avoid having to download checkpoints as part of tests, while
+    still covering most of the `ImageVisionLLMReader` class functionality.
+    """
+    with mock.patch(
+        "transformers.Blip2ForConditionalGeneration.from_pretrained",
+        return_value=ModelFake(),
+    ), mock.patch(
+        "transformers.Blip2Processor.from_pretrained",
+        return_value=TokenizerFake(),
+    ):
+        if torch_installed:
+            image_vision_llm_reader = ImageVisionLLMReader()
+            result = image_vision_llm_reader.load_data(file=test_16x16_png_image_file)[
+                0
+            ]
+            assert (
+                result.text
+                == "Question: describe what you see in this image. Answer: a black and white checkered pattern"
+            )
+        else:
+            with _get_custom_import(torch_installed=False):
+                with pytest.raises(ImportError) as excinfo:
+                    image_vision_llm_reader = ImageVisionLLMReader()
+
+                assert (
+                    str(excinfo.value)
+                    == "Please install extra dependencies that are required for the ImageCaptionReader: `pip install torch transformers sentencepiece Pillow`"
+                )
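
The sketch below restates the pattern the patch introduces, outside the reader class: PyTorch is imported lazily, and model.generate() is wrapped in torch.no_grad() only when the import succeeds. It is an illustrative, self-contained example; run_generate and TORCH_IMPORTED are hypothetical names for this sketch, not part of the llama-index API.

try:
    import torch

    TORCH_IMPORTED = True
except ImportError:
    torch = None
    TORCH_IMPORTED = False


def run_generate(model, **inputs):
    """Call `model.generate`, skipping gradient tracking when torch is available."""
    if TORCH_IMPORTED:
        # Gradients are never needed for inference; disabling them saves GPU memory.
        with torch.no_grad():
            return model.generate(**inputs)
    # Fallback when PyTorch is not installed (e.g. a fully mocked test environment).
    return model.generate(**inputs)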