diff --git a/recipes/quickstart/inference/local_inference/README.md b/recipes/quickstart/inference/local_inference/README.md
index d50dd806dbe7432d48703043ab7e1d9d592c31e1..0bf2ad9d792c97d1505f2da048642a0bf8cd3954 100644
--- a/recipes/quickstart/inference/local_inference/README.md
+++ b/recipes/quickstart/inference/local_inference/README.md
@@ -10,6 +10,18 @@ The way to run this would be:
 ```
 python multi_modal_infer.py --image_path PATH_TO_IMAGE --prompt_text "Describe this image" --temperature 0.5 --top_p 0.8 --model_name "meta-llama/Llama-3.2-11B-Vision-Instruct"
 ```
+---
+## Multi-Modal Inference Using Gradio UI
+For multi-modal inference with a Gradio UI we have added [multi_modal_infer_gradio_UI.py](multi_modal_infer_gradio_UI.py), which uses the gradio and transformers libraries.
+
+### Steps to Run
+
+The way to run this would be:
+- Ensure you have access to the Llama 3.2 vision models, then run the command below:
+
+```
+python multi_modal_infer_gradio_UI.py --hf_token <your hf_token here>
+```
 
 ## Text-only Inference
 For local inference we have provided an [inference script](inference.py). Depending on the type of finetuning performed during training the [inference script](inference.py) takes different arguments.
diff --git a/recipes/quickstart/inference/local_inference/multi_modal_infer_gradio_UI.py b/recipes/quickstart/inference/local_inference/multi_modal_infer_gradio_UI.py
new file mode 100644
index 0000000000000000000000000000000000000000..5119ac7c30c26f36f732db2025dfe0423c3d6971
--- /dev/null
+++ b/recipes/quickstart/inference/local_inference/multi_modal_infer_gradio_UI.py
@@ -0,0 +1,158 @@
+import gradio as gr
+import torch
+import os
+from PIL import Image
+from accelerate import Accelerator
+from transformers import MllamaForConditionalGeneration, AutoProcessor
+import argparse
+
+# Parse the command line arguments
+parser = argparse.ArgumentParser(description="Run Gradio app with Hugging Face model")
+parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face authentication token")
+args = parser.parse_args()
+
+# Hugging Face token
+hf_token = args.hf_token
+
+# Initialize Accelerator
+accelerate = Accelerator()
+device = accelerate.device
+
+# Set memory management for PyTorch
+os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'  # or adjust size as needed
+
+# Model ID
+model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+
+# Load model with the Hugging Face token
+model = MllamaForConditionalGeneration.from_pretrained(
+    model_id,
+    torch_dtype=torch.bfloat16,
+    device_map=device,
+    token=hf_token  # Pass the Hugging Face token here
+)
+
+# Load the processor
+processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
+
+# Visual theme
+visual_theme = gr.themes.Default()  # Default, Soft or Monochrome
+
+# Constants
+MAX_OUTPUT_TOKENS = 2048
+MAX_IMAGE_SIZE = (1120, 1120)
+
+# Function to process the image and generate a description
+def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
+    # Initialize cleaned_output variable
+    cleaned_output = ""
+
+    if image is not None:
+        # Resize image if necessary
+        image = image.resize(MAX_IMAGE_SIZE)
+        prompt = f"<|image|><|begin_of_text|>{user_prompt} Answer:"
+        # Preprocess the image and prompt
+        inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
+    else:
+        # Text-only input if no image is provided
+        prompt = f"<|begin_of_text|>{user_prompt} Answer:"
+        # Preprocess the prompt only (no image)
+        inputs = processor(text=prompt, return_tensors="pt").to(device)
+
+    # Generate output with model
+    output = model.generate(
+        **inputs,
+        max_new_tokens=min(max_tokens, MAX_OUTPUT_TOKENS),
+        do_sample=True,  # Sample so the temperature/top-k/top-p settings take effect
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p
+    )
+
+    # Decode the raw output
+    raw_output = processor.decode(output[0])
+
+    # Clean up the output to remove system tokens
+    cleaned_output = raw_output.replace("<|image|><|begin_of_text|>", "").strip().replace(" Answer:", "")
+
+    # Ensure the prompt is not repeated in the output
+    if cleaned_output.startswith(user_prompt):
+        cleaned_output = cleaned_output[len(user_prompt):].strip()
+
+    # Append the new conversation to the history
+    history.append((user_prompt, cleaned_output))
+
+    return history
+
+
+# Function to clear the chat history
+def clear_chat():
+    return []
+
+# Gradio Interface
+def gradio_interface():
+    with gr.Blocks(visual_theme) as demo:
+        gr.HTML(
+        """
+        <h1 style='text-align: center'>
+        meta-llama/Llama-3.2-11B-Vision-Instruct
+        </h1>
+        """)
+        with gr.Row():
+            # Left column with image and parameter inputs
+            with gr.Column(scale=1):
+                image_input = gr.Image(
+                    label="Image",
+                    type="pil",
+                    image_mode="RGB",
+                    height=512,  # Set the height
+                    width=512  # Set the width
+                )
+
+                # Parameter sliders
+                temperature = gr.Slider(
+                    label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1, interactive=True)
+                top_k = gr.Slider(
+                    label="Top-k", minimum=1, maximum=100, value=50, step=1, interactive=True)
+                top_p = gr.Slider(
+                    label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1, interactive=True)
+                max_tokens = gr.Slider(
+                    label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50, interactive=True)
+
+            # Right column with the chat interface
+            with gr.Column(scale=2):
+                chat_history = gr.Chatbot(label="Chat", height=512)
+
+                # User input box for prompt
+                user_prompt = gr.Textbox(
+                    show_label=False,
+                    container=False,
+                    placeholder="Enter your prompt",
+                    lines=2
+                )
+
+                # Generate and Clear buttons
+                with gr.Row():
+                    generate_button = gr.Button("Generate")
+                    clear_button = gr.Button("Clear")
+
+                # Define the action for the generate button
+                generate_button.click(
+                    fn=describe_image,
+                    inputs=[image_input, user_prompt, temperature, top_k, top_p, max_tokens, chat_history],
+                    outputs=[chat_history]
+                )
+
+                # Define the action for the clear button
+                clear_button.click(
+                    fn=clear_chat,
+                    inputs=[],
+                    outputs=[chat_history]
+                )
+
+    return demo
+
+# Launch the interface
+demo = gradio_interface()
+# demo.launch(server_name="0.0.0.0", server_port=12003)
+demo.launch()
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 67c5afa3e9a1cc2007e6cac613f4230e1a8f1d96..a5d218b8aa40ae75ebf9e49ca27dd73d91788cb2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -25,3 +25,5 @@ faiss-gpu; python_version < '3.11'
 unstructured[pdf]
 sentence_transformers
 codeshield
+gradio
+markupsafe==2.0.1
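
For reviewers who want to sanity-check model access and generation without launching the Gradio UI, a minimal command-line sketch using the same model, processor, and prompt format as the script above might look like the following. It is not part of this PR: the image path is a placeholder, and it assumes a CUDA GPU, transformers with Mllama support (>= 4.45), and that you are already authenticated to Hugging Face (e.g. via `huggingface-cli login`).

```python
# Minimal non-UI smoke test mirroring the calls made in multi_modal_infer_gradio_UI.py.
# Assumptions: gated Llama 3.2 vision model access, prior `huggingface-cli login`,
# and a local image at the placeholder path below.
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
model = MllamaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id)

image = Image.open("test_image.jpg")  # placeholder path, replace with a real image
prompt = "<|image|><|begin_of_text|>Describe this image. Answer:"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)

# Greedy decode is enough for a smoke test; the UI adds sampling controls on top.
output = model.generate(**inputs, max_new_tokens=128)
print(processor.decode(output[0], skip_special_tokens=True))
```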