Commit 0e2703c5 authored by Himanshu Shukla


Added complete inferencing functionality of 1. terminal inferencing, 2. gradio inferencing, 3. checkpoint inferencing in UI/CLI
parent 6b1c0d58
## Hugging Face setup
**Important Note**: Before running the inference, you'll need your Hugging Face access token, which you can get at your Settings page [here](https://huggingface.co/settings/tokens). Then run `huggingface-cli login` and copy and paste your Hugging Face access token to complete the login, so that the scripts can download Hugging Face models if needed.
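
If you prefer to authenticate from Python rather than the CLI prompt, the `huggingface_hub` login helper has the same effect; a minimal sketch (the token string is a placeholder):

```python
# Programmatic alternative to `huggingface-cli login`; caches the token locally
# so the scripts in this folder can download gated models.
from huggingface_hub import login

login(token="hf_xxx")  # placeholder: paste the token from your Settings page
```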

## Multimodal Inference and CLI inference with or without PEFT LoRA weights

### Model Overview
- Base model: `meta-llama/Llama-3.2-11B-Vision-Instruct`
- Uses the PEFT library (v0.13.1) for efficient fine-tuning
- Supports vision-language tasks with instruction-following capabilities

### Features in `multi_modal_infer.py`
All functionality has been consolidated into a single file with three main modes; the CLI commands are listed below, and a programmatic sketch follows them.

### Steps to run
1. **Basic Inference**
```bash
python multi_modal_infer.py \
    --image_path "path/to/image.jpg" \
    --prompt_text "Describe this image" \
    --model_name "meta-llama/Llama-3.2-11B-Vision-Instruct" \
    --hf_token "your_token"
```

2. **Gradio UI Mode**
```bash
python multi_modal_infer.py \
    --model_name "meta-llama/Llama-3.2-11B-Vision-Instruct" \
    --hf_token "your_token" \
    --gradio_ui
```

3. **LoRA Fine-tuning Integration**
```bash
python multi_modal_infer.py \
    --image_path "path/to/image.jpg" \
    --prompt_text "Describe this image" \
    --model_name "meta-llama/Llama-3.2-11B-Vision-Instruct" \
    --hf_token "your_token" \
    --finetuning_path "path/to/lora/weights"
```
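
The same three modes can also be driven programmatically by importing the helpers defined in `multi_modal_infer.py` (the full script appears later in this commit). A minimal sketch, assuming the script is importable from your working directory and your token has access to the gated model:

```python
# Minimal sketch: reuse the helper functions defined in multi_modal_infer.py.
from multi_modal_infer import load_model_and_processor, process_image, generate_text_from_image

model, processor = load_model_and_processor(
    "meta-llama/Llama-3.2-11B-Vision-Instruct",
    hf_token="your_token",   # or None if you have already run `huggingface-cli login`
    finetuning_path=None,    # set to a LoRA weights directory to load an adapter
)
image = process_image("path/to/image.jpg")
answer = generate_text_from_image(
    model, processor, image,
    prompt_text="Describe this image",
    temperature=0.7,
    top_p=0.9,
)
print(answer)
```

Setting `finetuning_path` here mirrors the `--finetuning_path` flag: the LoRA adapter is applied before generation.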
## Text-only Inference
For local inference we have provided an [inference script](inference.py). Depending on the type of finetuning performed during training, the [inference script](inference.py) takes different arguments.
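
For orientation only, the sketch below shows what bare-bones local text-only generation looks like with plain `transformers`; it is not the folder's [inference script](inference.py), and the checkpoint directory is a placeholder for wherever your fine-tuned weights were saved:

```python
# Not the repo's inference.py: a minimal transformers example of text-only generation
# from a locally saved fine-tuned checkpoint (the directory path is a placeholder).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "path/to/training_config.output_dir"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.bfloat16, device_map="auto")

inputs = tokenizer("Summarize what LoRA fine-tuning does.", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=200, do_sample=True, temperature=0.7, top_p=0.9)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```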
## Inference on large models like Meta Llama 405B
The FP8 quantized variants of Meta Llama (i.e. meta-llama/Meta-Llama-3.1-405B-FP8 and meta-llama/Meta-Llama-3.1-405B-Instruct-FP8) can be executed on a single node with 8x80GB H100 using the scripts located in this folder.
To run the unquantized Meta Llama 405B variants (i.e. meta-llama/Meta-Llama-3.1-405B and meta-llama/Meta-Llama-3.1-405B-Instruct) we need to use a multi-node setup for inference. The llama-recipes inference script currently does not allow multi-node inference. To run this model you can use vLLM with pipeline and tensor parallelism as shown in [this example](../../../3p_integrations/vllm/README.md).
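
As a rough sketch of what tensor plus pipeline parallelism looks like in vLLM's offline API (the parallel sizes are placeholders; follow the linked example for the actual multi-node recipe):

```python
# Illustrative only: vLLM offline inference with tensor and pipeline parallelism.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3.1-405B-Instruct",
    tensor_parallel_size=8,     # GPUs per node (placeholder)
    pipeline_parallel_size=2,   # pipeline stages across nodes (placeholder)
)
sampling = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=256)
outputs = llm.generate(["Explain FP8 quantization in one sentence."], sampling)
print(outputs[0].outputs[0].text)
```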
### Inference with LoRA checkpoints
After fine-tuning the model, you can use the `code-merge-inference.py` script to generate text from images. The script supports merging PEFT adapter weights from a specified path.
#### Usage
To run the inference script, use the following command:
```bash
python code-merge-inference.py \
--image_path "path/to/your/image.png" \
--prompt_text "Your prompt text here" \
--temperature 1 \
--top_p 0.5 \
--model_name "meta-llama/Llama-3.2-11B-Vision-Instruct" \
--hf_token "your_hugging_face_token" \
--finetuning_path "path/to/your/finetuned/model"
```
#### Script Details
The `code-merge-inference.py` script performs the following steps:
1. **Load Model and Processor**: Loads the pre-trained model and processor, and optionally loads PEFT adapter weights if specified.
2. **Process Image**: Opens and converts the input image.
3. **Generate Text**: Generates text from the image using the model and processor.
For more details, refer to the `code-merge-inference.py` script.
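
If you only need the adapter-merging step on its own, the PEFT API can fold the LoRA weights into the base model directly; a minimal sketch with placeholder paths (this mirrors the adapter-loading step of `code-merge-inference.py`, not the full script):

```python
# Standalone adapter merge with PEFT; paths and model names are placeholders.
import torch
from transformers import MllamaForConditionalGeneration
from peft import PeftModel

base = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model = PeftModel.from_pretrained(base, "path/to/your/finetuned/model")  # LoRA adapter directory
merged = model.merge_and_unload()               # fold the adapter weights into the base model
merged.save_pretrained("path/to/merged-model")  # optionally persist a single merged checkpoint
```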

# multi_modal_infer.py
import argparse
import os
import sys
import torch
from accelerate import Accelerator
from PIL import Image as PIL_Image
from transformers import MllamaForConditionalGeneration, MllamaProcessor
from peft import PeftModel
import gradio as gr
# Initialize accelerator
accelerator = Accelerator()
device = accelerator.device
# Constants
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
MAX_OUTPUT_TOKENS = 2048
MAX_IMAGE_SIZE = (1120, 1120)
def load_model_and_processor(model_name: str, hf_token: str = None, finetuning_path: str = None):
    """Load model and processor with optional LoRA adapter"""
    model = MllamaForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
        device_map=device,
        token=hf_token
    )
    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token, use_safetensors=True)

    # If a fine-tuning path is given, wrap the base model with the PEFT adapter weights.
    if finetuning_path and os.path.exists(finetuning_path):
        print(f"Loading adapter from '{finetuning_path}'...")
        model = PeftModel.from_pretrained(
            model,
            finetuning_path,
            is_adapter=True,
            torch_dtype=torch.bfloat16
        )
        print("Adapter loaded successfully")

    model, processor = accelerator.prepare(model, processor)
    return model, processor
def process_image(image_path: str) -> PIL_Image.Image:
    """Process and validate image input"""
    if not os.path.exists(image_path):
        print(f"Image file '{image_path}' does not exist.")
        sys.exit(1)
    return PIL_Image.open(image_path).convert("RGB")
def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
    """Generate text from image using model"""
    conversation = [
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    inputs = processor(image, prompt, return_tensors="pt").to(device)
    output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=MAX_OUTPUT_TOKENS)
    # Strip the echoed prompt from the decoded output before returning the generated text.
    return processor.decode(output[0])[len(prompt):]
def gradio_interface(model, processor):
    """Create Gradio UI"""
    def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
        # Note: top_k and max_tokens are collected by the UI but not forwarded to generation here.
        if image is not None:
            image = image.resize(MAX_IMAGE_SIZE)
            result = generate_text_from_image(model, processor, image, user_prompt, temperature, top_p)
            history.append((user_prompt, result))
        return history

    def clear_chat():
        return []

    with gr.Blocks() as demo:
        gr.HTML("<h1 style='text-align: center'>Llama Vision Model Interface</h1>")
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(label="Image", type="pil", image_mode="RGB", height=512, width=512)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1)
                top_k = gr.Slider(label="Top-k", minimum=1, maximum=100, value=50, step=1)
                top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
                max_tokens = gr.Slider(label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50)
            with gr.Column(scale=2):
                chat_history = gr.Chatbot(label="Chat", height=512)
                user_prompt = gr.Textbox(show_label=False, placeholder="Enter your prompt", lines=2)
                with gr.Row():
                    generate_button = gr.Button("Generate")
                    clear_button = gr.Button("Clear")
                generate_button.click(
                    fn=describe_image,
                    inputs=[image_input, user_prompt, temperature, top_k, top_p, max_tokens, chat_history],
                    outputs=[chat_history]
                )
                clear_button.click(fn=clear_chat, outputs=[chat_history])
    return demo
def main(args):
    """Main execution flow"""
    model, processor = load_model_and_processor(
        args.model_name,
        args.hf_token,
        args.finetuning_path
    )

    if args.gradio_ui:
        demo = gradio_interface(model, processor)
        demo.launch()
    else:
        image = process_image(args.image_path)
        result = generate_text_from_image(
            model, processor, image, args.prompt_text, args.temperature, args.top_p
        )
        print("Generated Text:", result)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Multi-modal inference with optional Gradio UI and LoRA support")
    parser.add_argument("--image_path", type=str, help="Path to the input image")
    parser.add_argument("--prompt_text", type=str, help="Prompt text for the image")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling")
    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help="Model name")
    parser.add_argument("--hf_token", type=str, help="Hugging Face API token")
    parser.add_argument("--finetuning_path", type=str, help="Path to LoRA weights")
    parser.add_argument("--gradio_ui", action="store_true", help="Launch Gradio UI")
    args = parser.parse_args()
    main(args)

# multi_modal_infer_gradio_UI.py
import gradio as gr
import torch
import os
from PIL import Image
from accelerate import Accelerator
from transformers import MllamaForConditionalGeneration, AutoProcessor
import argparse # Import argparse
# Parse the command line arguments
parser = argparse.ArgumentParser(description="Run Gradio app with Hugging Face model")
parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face authentication token")
args = parser.parse_args()
# Hugging Face token
hf_token = args.hf_token
# Initialize Accelerator
accelerate = Accelerator()
device = accelerate.device
# Set memory management for PyTorch
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128' # or adjust size as needed
# Model ID
model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
# Load model with the Hugging Face token
model = MllamaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
device_map=device,
use_auth_token=hf_token # Pass the Hugging Face token here
)
# Load the processor
processor = AutoProcessor.from_pretrained(model_id, use_auth_token=hf_token)
# Visual theme
visual_theme = gr.themes.Default() # Default, Soft or Monochrome
# Constants
MAX_OUTPUT_TOKENS = 2048
MAX_IMAGE_SIZE = (1120, 1120)
# Function to process the image and generate a description
def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
    # Initialize cleaned_output variable
    cleaned_output = ""

    if image is not None:
        # Resize image if necessary
        image = image.resize(MAX_IMAGE_SIZE)

        prompt = f"<|image|><|begin_of_text|>{user_prompt} Answer:"
        # Preprocess the image and prompt
        inputs = processor(image, prompt, return_tensors="pt").to(device)
    else:
        # Text-only input if no image is provided
        prompt = f"<|begin_of_text|>{user_prompt} Answer:"
        # Preprocess the prompt only; pass it as text explicitly so the processor
        # does not treat it as an image input
        inputs = processor(text=prompt, return_tensors="pt").to(device)

    # Generate output with model
    output = model.generate(
        **inputs,
        max_new_tokens=min(max_tokens, MAX_OUTPUT_TOKENS),
        temperature=temperature,
        top_k=top_k,
        top_p=top_p
    )

    # Decode the raw output
    raw_output = processor.decode(output[0])
    # Clean up the output to remove system tokens
    cleaned_output = raw_output.replace("<|image|><|begin_of_text|>", "").strip().replace(" Answer:", "")

    # Ensure the prompt is not repeated in the output
    if cleaned_output.startswith(user_prompt):
        cleaned_output = cleaned_output[len(user_prompt):].strip()

    # Append the new conversation to the history
    history.append((user_prompt, cleaned_output))

    return history
# Function to clear the chat history
def clear_chat():
    return []
# Gradio Interface
def gradio_interface():
    with gr.Blocks(visual_theme) as demo:
        gr.HTML(
        """
        <h1 style='text-align: center'>
        meta-llama/Llama-3.2-11B-Vision-Instruct
        </h1>
        """)
        with gr.Row():
            # Left column with image and parameter inputs
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="Image",
                    type="pil",
                    image_mode="RGB",
                    height=512,  # Set the height
                    width=512  # Set the width
                )

                # Parameter sliders
                temperature = gr.Slider(
                    label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1, interactive=True)
                top_k = gr.Slider(
                    label="Top-k", minimum=1, maximum=100, value=50, step=1, interactive=True)
                top_p = gr.Slider(
                    label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1, interactive=True)
                max_tokens = gr.Slider(
                    label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50, interactive=True)

            # Right column with the chat interface
            with gr.Column(scale=2):
                chat_history = gr.Chatbot(label="Chat", height=512)

                # User input box for prompt
                user_prompt = gr.Textbox(
                    show_label=False,
                    container=False,
                    placeholder="Enter your prompt",
                    lines=2
                )

                # Generate and Clear buttons
                with gr.Row():
                    generate_button = gr.Button("Generate")
                    clear_button = gr.Button("Clear")

                # Define the action for the generate button
                generate_button.click(
                    fn=describe_image,
                    inputs=[image_input, user_prompt, temperature, top_k, top_p, max_tokens, chat_history],
                    outputs=[chat_history]
                )

                # Define the action for the clear button
                clear_button.click(
                    fn=clear_chat,
                    inputs=[],
                    outputs=[chat_history]
                )

    return demo
# Launch the interface
demo = gradio_interface()
# demo.launch(server_name="0.0.0.0", server_port=12003)
demo.launch()