Commit 6b6bb37e authored by Himanshu Shukla

Fixed the Gradio UI while running the tests; it is working as of this commit.

parent 20dd4740
import argparse
import os
import sys

import gradio as gr
import torch
from accelerate import Accelerator
from peft import PeftModel  # make sure the `peft` library is installed
from PIL import Image as PIL_Image
from transformers import MllamaForConditionalGeneration, MllamaProcessor
# Initialize accelerator
accelerator = Accelerator()
device = accelerator.device
# Constants
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
MAX_OUTPUT_TOKENS = 2048
MAX_IMAGE_SIZE = (1120, 1120)  # NOTE: defined but not currently enforced anywhere below
def load_model_and_processor(model_name: str, hf_token: str = None, finetuning_path: str = None):
    """Load model and processor with optional LoRA adapter"""
    print(f"Loading model: {model_name}")
    model = MllamaForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
        device_map=device,
        token=hf_token
    )
    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token, use_safetensors=True)
    # If a valid fine-tuning path is provided, wrap the base model with the LoRA adapter
    if finetuning_path and os.path.exists(finetuning_path):
        print(f"Loading LoRA adapter from '{finetuning_path}'...")
        model = PeftModel.from_pretrained(
            model,
            finetuning_path,
            is_adapter=True,
            torch_dtype=torch.bfloat16
        )
        print("LoRA adapter merged successfully")

    model, processor = accelerator.prepare(model, processor)
    return model, processor
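# Usage sketch (the paths and token below are placeholders, not from this commit):
# the finetuning_path directory is the one produced by PEFT's save_pretrained(),
# i.e. it contains adapter_model.safetensors and adapter_config.json.
#   model, processor = load_model_and_processor(
#       DEFAULT_MODEL, hf_token="hf_...", finetuning_path="outputs/lora_run"
#   )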
def process_image(image_path: str = None, image=None) -> PIL_Image.Image:
    """Process and validate image input, returning an RGB PIL image"""
    if image is not None:
        return image.convert("RGB")
    if image_path and os.path.exists(image_path):
        return PIL_Image.open(image_path).convert("RGB")
    raise ValueError("No valid image provided")
def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
    """Generate text from an image using the model and processor"""
    conversation = [
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    inputs = processor(image, prompt, return_tensors="pt").to(device)
    output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=MAX_OUTPUT_TOKENS)
    return processor.decode(output[0])[len(prompt):]
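# Minimal end-to-end sketch (hypothetical paths and token; assumes a GPU with
# enough memory for the 11B model in bfloat16):
#   model, processor = load_model_and_processor(DEFAULT_MODEL, hf_token="hf_...")
#   img = process_image(image_path="sample.jpg")
#   print(generate_text_from_image(model, processor, img, "Describe this image.", 0.7, 0.9))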
def gradio_interface(model_name: str, hf_token: str):
"""Create Gradio UI with LoRA support"""
# Initialize model state
current_model = {"model": None, "processor": None}
def load_or_reload_model(enable_lora: bool, lora_path: str = None):
current_model["model"], current_model["processor"] = load_model_and_processor(
model_name,
hf_token,
lora_path if enable_lora else None
)
return "Model loaded successfully" + (" with LoRA" if enable_lora else "")
    # NOTE: the top_k and max_tokens slider values are received here but are not
    # forwarded to generate_text_from_image, which only uses temperature and top_p.
    def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
if image is not None:
try:
processed_image = process_image(image=image)
result = generate_text_from_image(
current_model["model"],
current_model["processor"],
processed_image,
user_prompt,
temperature,
top_p
)
history.append((user_prompt, result))
except Exception as e:
history.append((user_prompt, f"Error: {str(e)}"))
return history
def clear_chat():
return []
with gr.Blocks() as demo:
gr.HTML("<h1 style='text-align: center'>Llama Vision Model Interface</h1>")
with gr.Row():
with gr.Column(scale=1):
# Model loading controls
with gr.Group():
enable_lora = gr.Checkbox(label="Enable LoRA", value=False)
lora_path = gr.Textbox(
label="LoRA Weights Path",
placeholder="Path to LoRA weights folder",
visible=False
)
load_status = gr.Textbox(label="Load Status", interactive=False)
load_button = gr.Button("Load/Reload Model")
# Image and parameter controls
image_input = gr.Image(label="Image", type="pil", image_mode="RGB", height=512, width=512)
temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1)
top_k = gr.Slider(label="Top-k", minimum=1, maximum=100, value=50, step=1)
top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
max_tokens = gr.Slider(label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50)
with gr.Column(scale=2):
chat_history = gr.Chatbot(label="Chat", height=512)
user_prompt = gr.Textbox(
show_label=False,
placeholder="Enter your prompt",
lines=2
)
with gr.Row():
generate_button = gr.Button("Generate")
clear_button = gr.Button("Clear")
# Event handlers
enable_lora.change(
fn=lambda x: gr.update(visible=x),
inputs=[enable_lora],
outputs=[lora_path]
)
load_button.click(
fn=load_or_reload_model,
inputs=[enable_lora, lora_path],
outputs=[load_status]
)
generate_button.click(
fn=describe_image,
inputs=[
image_input, user_prompt, temperature,
top_k, top_p, max_tokens, chat_history
],
outputs=[chat_history]
)
clear_button.click(fn=clear_chat, outputs=[chat_history])
# Initial model load
load_or_reload_model(False)
return demo
def main(args):
"""Main execution flow"""
if args.gradio_ui:
demo = gradio_interface(args.model_name, args.hf_token)
demo.launch()
else:
model, processor = load_model_and_processor(
args.model_name,
args.hf_token,
args.finetuning_path
)
image = process_image(image_path=args.image_path)
result = generate_text_from_image(
model, processor, image,
args.prompt_text,
args.temperature,
args.top_p
)
print("Generated Text:", result)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Multi-modal inference with optional Gradio UI and LoRA support")
    parser.add_argument("--image_path", type=str, help="Path to the input image")
    parser.add_argument("--prompt_text", type=str, help="Prompt text for the image")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling")
    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help="Model name")
    parser.add_argument("--hf_token", type=str, help="Hugging Face API token")
    parser.add_argument("--finetuning_path", type=str, help="Path to LoRA weights")
    parser.add_argument("--gradio_ui", action="store_true", help="Launch Gradio UI")
    args = parser.parse_args()
    main(args)
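# Example invocations (the script filename is assumed for illustration, not taken
# from the commit; <HF_TOKEN> is a placeholder for a real Hugging Face token):
#   python multi_modal_infer.py --gradio_ui --hf_token <HF_TOKEN>
#   python multi_modal_infer.py --image_path ./sample.jpg \
#       --prompt_text "Describe this image" --hf_token <HF_TOKEN>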