diff --git a/recipes/quickstart/inference/local_inference/multi_modal_infer.py b/recipes/quickstart/inference/local_inference/multi_modal_infer.py
index 5459f2ced28e422aed55d6679f3dd04827c45d0b..a92482c3c884f5392c08926e973e84347cf557e4 100644
--- a/recipes/quickstart/inference/local_inference/multi_modal_infer.py
+++ b/recipes/quickstart/inference/local_inference/multi_modal_infer.py
@@ -1,117 +1,191 @@
+import argparse
 import os
 import sys
-import argparse
-from PIL import Image as PIL_Image
 import torch
-from transformers import MllamaForConditionalGeneration, MllamaProcessor
 from accelerate import Accelerator
-from peft import PeftModel # Make sure to install the `peft` library
+from PIL import Image as PIL_Image
+from transformers import MllamaForConditionalGeneration, MllamaProcessor
+from peft import PeftModel
+import gradio as gr
 
+# Initialize accelerator
 accelerator = Accelerator()
 device = accelerator.device
 
 # Constants
 DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+MAX_OUTPUT_TOKENS = 2048
+MAX_IMAGE_SIZE = (1120, 1120)
 
-
-def load_model_and_processor(model_name: str, hf_token: str, finetuning_path: str = None):
-    """
-    Load the model and processor, and optionally load adapter weights if specified
-    """
-    # Load pre-trained model and processor
+def load_model_and_processor(model_name: str, hf_token: str = None, finetuning_path: str = None):
+    """Load model and processor with optional LoRA adapter"""
+    print(f"Loading model: {model_name}")
     model = MllamaForConditionalGeneration.from_pretrained(
         model_name,
         torch_dtype=torch.bfloat16,
         use_safetensors=True,
         device_map=device,
         token=hf_token
     )
-    processor = MllamaProcessor.from_pretrained(
-        model_name,
-        token=hf_token,
-        use_safetensors=True
-    )
+    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token, use_safetensors=True)
 
-    # If a finetuning path is provided, load the adapter model
     if finetuning_path and os.path.exists(finetuning_path):
-        adapter_weights_path = os.path.join(finetuning_path, "adapter_model.safetensors")
-        adapter_config_path = os.path.join(finetuning_path, "adapter_config.json")
-
-        if os.path.exists(adapter_weights_path) and os.path.exists(adapter_config_path):
-            print(f"Loading adapter from '{finetuning_path}'...")
-            # Load the model with adapters using `peft`
-            model = PeftModel.from_pretrained(
-                model,
-                finetuning_path,  # This should be the folder containing the adapter files
-                is_adapter=True,
-                torch_dtype=torch.bfloat16
-            )
-
-            print("Adapter merged successfully with the pre-trained model.")
-        else:
-            print(f"Adapter files not found in '{finetuning_path}'. Using pre-trained model only.")
-    else:
-        print(f"No fine-tuned weights or adapters found in '{finetuning_path}'. Using pre-trained model only.")
-
-    # Prepare the model and processor for accelerated training
-    model, processor = accelerator.prepare(model, processor)
+        print(f"Loading LoRA adapter from '{finetuning_path}'...")
+        model = PeftModel.from_pretrained(
+            model,
+            finetuning_path,
+            is_trainable=False,
+            torch_dtype=torch.bfloat16
+        )
+        print("LoRA adapter loaded successfully")
+
+    model, processor = accelerator.prepare(model, processor)
     return model, processor
 
-
-def process_image(image_path: str) -> PIL_Image.Image:
-    """
-    Open and convert an image from the specified path.
- """ - if not os.path.exists(image_path): - print(f"The image file '{image_path}' does not exist.") - sys.exit(1) - with open(image_path, "rb") as f: - return PIL_Image.open(f).convert("RGB") - +def process_image(image_path: str = None, image = None) -> PIL_Image.Image: + """Process and validate image input""" + if image is not None: + return image.convert("RGB") + if image_path and os.path.exists(image_path): + return PIL_Image.open(image_path).convert("RGB") + raise ValueError("No valid image provided") def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float): - """ - Generate text from an image using the model and processor. - """ + """Generate text from image using model""" conversation = [ {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]} ] prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False) inputs = processor(image, prompt, return_tensors="pt").to(device) - output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=2048) + output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=MAX_OUTPUT_TOKENS) return processor.decode(output[0])[len(prompt):] - -def main(image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str, hf_token: str, finetuning_path: str = None): - """ - Call all the functions and optionally merge adapter weights from a specified path. - """ - model, processor = load_model_and_processor(model_name, hf_token, finetuning_path) - image = process_image(image_path) - result = generate_text_from_image(model, processor, image, prompt_text, temperature, top_p) - print("Generated Text: " + result) - +def gradio_interface(model_name: str, hf_token: str): + """Create Gradio UI with LoRA support""" + # Initialize model state + current_model = {"model": None, "processor": None} + + def load_or_reload_model(enable_lora: bool, lora_path: str = None): + current_model["model"], current_model["processor"] = load_model_and_processor( + model_name, + hf_token, + lora_path if enable_lora else None + ) + return "Model loaded successfully" + (" with LoRA" if enable_lora else "") + + def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history): + if image is not None: + try: + processed_image = process_image(image=image) + result = generate_text_from_image( + current_model["model"], + current_model["processor"], + processed_image, + user_prompt, + temperature, + top_p + ) + history.append((user_prompt, result)) + except Exception as e: + history.append((user_prompt, f"Error: {str(e)}")) + return history + + def clear_chat(): + return [] + + with gr.Blocks() as demo: + gr.HTML("<h1 style='text-align: center'>Llama Vision Model Interface</h1>") + + with gr.Row(): + with gr.Column(scale=1): + # Model loading controls + with gr.Group(): + enable_lora = gr.Checkbox(label="Enable LoRA", value=False) + lora_path = gr.Textbox( + label="LoRA Weights Path", + placeholder="Path to LoRA weights folder", + visible=False + ) + load_status = gr.Textbox(label="Load Status", interactive=False) + load_button = gr.Button("Load/Reload Model") + + # Image and parameter controls + image_input = gr.Image(label="Image", type="pil", image_mode="RGB", height=512, width=512) + temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1) + top_k = gr.Slider(label="Top-k", minimum=1, maximum=100, value=50, step=1) + top_p = gr.Slider(label="Top-p", 
+                top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
+                max_tokens = gr.Slider(label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50)
+
+            with gr.Column(scale=2):
+                chat_history = gr.Chatbot(label="Chat", height=512)
+                user_prompt = gr.Textbox(
+                    show_label=False,
+                    placeholder="Enter your prompt",
+                    lines=2
+                )
+
+                with gr.Row():
+                    generate_button = gr.Button("Generate")
+                    clear_button = gr.Button("Clear")
+
+        # Event handlers
+        enable_lora.change(
+            fn=lambda x: gr.update(visible=x),
+            inputs=[enable_lora],
+            outputs=[lora_path]
+        )
+
+        load_button.click(
+            fn=load_or_reload_model,
+            inputs=[enable_lora, lora_path],
+            outputs=[load_status]
+        )
+
+        generate_button.click(
+            fn=describe_image,
+            inputs=[
+                image_input, user_prompt, temperature,
+                top_k, top_p, max_tokens, chat_history
+            ],
+            outputs=[chat_history]
+        )
+
+        clear_button.click(fn=clear_chat, outputs=[chat_history])
+
+        # Initial model load
+        load_or_reload_model(False)
+    return demo
+
+def main(args):
+    """Main execution flow"""
+    if args.gradio_ui:
+        demo = gradio_interface(args.model_name, args.hf_token)
+        demo.launch()
+    else:
+        model, processor = load_model_and_processor(
+            args.model_name,
+            args.hf_token,
+            args.finetuning_path
+        )
+        image = process_image(image_path=args.image_path)
+        result = generate_text_from_image(
+            model, processor, image,
+            args.prompt_text,
+            args.temperature,
+            args.top_p
+        )
+        print("Generated Text:", result)
 
 if __name__ == "__main__":
-    # Example usage with argparse (optional)
-    parser = argparse.ArgumentParser(description="Generate text from an image using a fine-tuned model with adapters.")
-    parser.add_argument("--image_path", type=str, required=True, help="Path to the input image.")
-    parser.add_argument("--prompt_text", type=str, required=True, help="Prompt text for the image.")
-    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature.")
-    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling.")
-    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help="Pre-trained model name.")
-    parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face API token.")
-    parser.add_argument("--finetuning_path", type=str, help="Path to the fine-tuning weights (adapters).")
+    parser = argparse.ArgumentParser(description="Multi-modal inference with optional Gradio UI and LoRA support")
+    parser.add_argument("--image_path", type=str, help="Path to the input image")
+    parser.add_argument("--prompt_text", type=str, help="Prompt text for the image")
+    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
+    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling")
+    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help="Model name")
+    parser.add_argument("--hf_token", type=str, help="Hugging Face API token")
+    parser.add_argument("--finetuning_path", type=str, help="Path to LoRA weights")
+    parser.add_argument("--gradio_ui", action="store_true", help="Launch Gradio UI")
     args = parser.parse_args()
-
-    main(
-        image_path=args.image_path,
-        prompt_text=args.prompt_text,
-        temperature=args.temperature,
-        top_p=args.top_p,
-        model_name=args.model_name,
-        hf_token=args.hf_token,
-        finetuning_path=args.finetuning_path
-    )
+    main(args)
\ No newline at end of file
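
Example usage (a sketch, not part of the patch; the image path, prompt text, and token are placeholders, and the flags mirror the argparse definitions above):

    # One-shot CLI inference
    python multi_modal_infer.py --image_path "path/to/image.jpg" --prompt_text "Describe this image" --hf_token "<your-hf-token>"

    # Interactive Gradio UI; LoRA weights can be enabled and loaded from the left-hand controls
    python multi_modal_infer.py --gradio_ui --hf_token "<your-hf-token>"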