diff --git a/recipes/quickstart/inference/local_inference/multi_modal_infer.py b/recipes/quickstart/inference/local_inference/multi_modal_infer.py
index 8c11de8ee8c4b603fd77228f915957c748ce8b90..27d45b5f13a6f44fe8c7ee34b4b88a97f0f5a2fb 100644
--- a/recipes/quickstart/inference/local_inference/multi_modal_infer.py
+++ b/recipes/quickstart/inference/local_inference/multi_modal_infer.py
@@ -1,10 +1,11 @@
+import argparse
 import os
 import sys
-import argparse
-from PIL import Image as PIL_Image
+
 import torch
+from accelerate import Accelerator
+from PIL import Image as PIL_Image
 from transformers import MllamaForConditionalGeneration, MllamaProcessor
-from accelerate import Accelerator
 
 accelerator = Accelerator()
 
@@ -14,15 +15,19 @@ device = accelerator.device
 DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
 
 
-def load_model_and_processor(model_name: str, hf_token: str):
+def load_model_and_processor(model_name: str):
     """
     Load the model and processor based on the 11B or 90B model.
     """
-    model = MllamaForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16,use_safetensors=True, device_map=device,
-    token=hf_token)
-    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token,use_safetensors=True)
+    model = MllamaForConditionalGeneration.from_pretrained(
+        model_name,
+        torch_dtype=torch.bfloat16,
+        use_safetensors=True,
+        device_map=device,
+    )
+    processor = MllamaProcessor.from_pretrained(model_name, use_safetensors=True)
 
-    model, processor=accelerator.prepare(model, processor)
+    model, processor = accelerator.prepare(model, processor)
     return model, processor
 
 
@@ -37,37 +42,67 @@ def process_image(image_path: str) -> PIL_Image.Image:
         return PIL_Image.open(f).convert("RGB")
 
 
-def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
+def generate_text_from_image(
+    model, processor, image, prompt_text: str, temperature: float, top_p: float
+):
     """
     Generate text from an image using the model and processor.
     """
     conversation = [
-        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
+        {
+            "role": "user",
+            "content": [{"type": "image"}, {"type": "text", "text": prompt_text}],
+        }
     ]
-    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+    prompt = processor.apply_chat_template(
+        conversation, add_generation_prompt=True, tokenize=False
+    )
     inputs = processor(image, prompt, return_tensors="pt").to(device)
-    output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=512)
-    return processor.decode(output[0])[len(prompt):]
+    output = model.generate(
+        **inputs, temperature=temperature, top_p=top_p, max_new_tokens=512
+    )
+    return processor.decode(output[0])[len(prompt) :]
 
 
-def main(image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str, hf_token: str):
+def main(
+    image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str
+):
     """
-    Call all the functions. 
+    Call all the functions.
     """
-    model, processor = load_model_and_processor(model_name, hf_token)
+    model, processor = load_model_and_processor(model_name)
     image = process_image(image_path)
-    result = generate_text_from_image(model, processor, image, prompt_text, temperature, top_p)
+    result = generate_text_from_image(
+        model, processor, image, prompt_text, temperature, top_p
+    )
     print("Generated Text: " + result)
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Generate text from an image and prompt using the 3.2 MM Llama model.")
+    parser = argparse.ArgumentParser(
+        description="Generate text from an image and prompt using the 3.2 MM Llama model."
+    )
     parser.add_argument("--image_path", type=str, help="Path to the image file")
-    parser.add_argument("--prompt_text", type=str, help="Prompt text to describe the image")
-    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for generation (default: 0.7)")
-    parser.add_argument("--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)")
-    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help=f"Model name (default: '{DEFAULT_MODEL}')")
-    parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face token for authentication")
+    parser.add_argument(
+        "--prompt_text", type=str, help="Prompt text to describe the image"
+    )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=0.7,
+        help="Temperature for generation (default: 0.7)",
+    )
+    parser.add_argument(
+        "--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)"
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default=DEFAULT_MODEL,
+        help=f"Model name (default: '{DEFAULT_MODEL}')",
+    )
 
     args = parser.parse_args()
-    main(args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name, args.hf_token)
\ No newline at end of file
+    main(
+        args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name
+    )
diff --git a/src/llama_recipes/inference/model_utils.py b/src/llama_recipes/inference/model_utils.py
index 2b150eea3a5fb87277c9c3e321bed7c92c5b5737..99f191005fc0c6bf7830058cc8a614cc72953862 100644
--- a/src/llama_recipes/inference/model_utils.py
+++ b/src/llama_recipes/inference/model_utils.py
@@ -1,17 +1,29 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the GNU General Public License version 3.
 
+from warnings import warn
+
+from llama_recipes.configs import quantization_config as QUANT_CONFIG
 from llama_recipes.utils.config_utils import update_config
-from llama_recipes.configs import quantization_config as QUANT_CONFIG
 from peft import PeftModel
-from transformers import AutoModelForCausalLM, LlamaForCausalLM, LlamaConfig
-from warnings import warn
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    LlamaConfig,
+    LlamaForCausalLM,
+    MllamaConfig,
+    MllamaForConditionalGeneration,
+)
+
 
 # Function to load the main model for text generation
 def load_model(model_name, quantization, use_fast_kernels, **kwargs):
     if type(quantization) == type(True):
-        warn("Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.", FutureWarning)
-        quantization = "8bit"
+        warn(
+            "Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.",
+            FutureWarning,
+        )
+        quantization = "8bit"
 
     bnb_config = None
     if quantization:
@@ -23,10 +35,10 @@ def load_model(model_name, quantization, use_fast_kernels, **kwargs):
 
     kwargs = {}
     if bnb_config:
-        kwargs["quantization_config"]=bnb_config
-    kwargs["device_map"]="auto"
-    kwargs["low_cpu_mem_usage"]=True
-    kwargs["attn_implementation"]="sdpa" if use_fast_kernels else None
+        kwargs["quantization_config"] = bnb_config
+    kwargs["device_map"] = "auto"
+    kwargs["low_cpu_mem_usage"] = True
+    kwargs["attn_implementation"] = "sdpa" if use_fast_kernels else None
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
         return_dict=True,
@@ -40,10 +52,16 @@ def load_peft_model(model, peft_model):
     peft_model = PeftModel.from_pretrained(model, peft_model)
     return peft_model
 
+
 # Loading the model from config to load FSDP checkpoints into that
 def load_llama_from_config(config_path):
-    model_config = LlamaConfig.from_pretrained(config_path)
-    model = LlamaForCausalLM(config=model_config)
+    config = AutoConfig.from_pretrained(config_path)
+    if config.model_type == "mllama":
+        model = MllamaForConditionalGeneration(config=config)
+    elif config.model_type == "llama":
+        model = LlamaForCausalLM(config=config)
+    else:
+        raise ValueError(
+            f"Unsupported model type: {config.model_type}, Please use llama or mllama model."
+        )
    return model
-
-    
\ No newline at end of file
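
Usage sketch (not part of the diff): with --hf_token removed, multi_modal_infer.py falls back to the standard Hugging Face credential lookup, e.g. a prior `huggingface-cli login` or an exported HF_TOKEN environment variable. A hypothetical invocation, assuming a local image at ./sample.jpg and a prompt of your choice:

    export HF_TOKEN=<your-token>   # or run: huggingface-cli login
    python recipes/quickstart/inference/local_inference/multi_modal_infer.py \
        --image_path ./sample.jpg \
        --prompt_text "Describe this image." \
        --model_name meta-llama/Llama-3.2-11B-Vision-Instruct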