diff --git a/recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py b/recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py
index 19ce2262b89effaf3ae561bb5d73b7508ef37194..f5948e151d9669ee0c89b7f30adcaca3af96d064 100644
--- a/recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py
+++ b/recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py
@@ -3,46 +3,60 @@
 
 
 import copy
-from datasets import load_dataset
 import itertools
+
 import torch
+from datasets import load_dataset
+
 
 # check system prompt token seq or user prompt token seq is in the current token list
-def check_header(targets,seq):
-    for i in range(len(seq)-3):
-        if seq[i:i+3] in targets:
+def check_header(targets, seq):
+    for i in range(len(seq) - 3):
+        if seq[i : i + 3] in targets:
             return True
     return False
-def replace_target(target,seq):
-    for i in range(len(seq)-3):
-        if seq[i:i+3] == target:
-            seq[i],seq[i+1],seq[i+2] = -100,-100,-100
+
+
+def replace_target(target, seq):
+    for i in range(len(seq) - 3):
+        if seq[i : i + 3] == target:
+            seq[i], seq[i + 1], seq[i + 2] = -100, -100, -100
     return seq
+
+
 def tokenize_dialogs(dialogs, images, processor):
     text_prompt = processor.apply_chat_template(dialogs)
-    batch = processor(images=images, text=text_prompt,padding = True, return_tensors="pt")
+    text_prompt = [prompt.replace('<|begin_of_text|>','') for prompt in text_prompt]
+    batch = processor(
+        images=images,
+        text=text_prompt,
+        padding=True,
+        return_tensors="pt",
+    )
     label_list = []
     for i in range(len(batch["input_ids"])):
         dialog_tokens = batch["input_ids"][i].tolist()
         labels = copy.copy(dialog_tokens)
-        eot_indices = [i for i,n in enumerate(labels) if n == 128009]
+        eot_indices = [i for i, n in enumerate(labels) if n == 128009]
         last_idx = 0
         # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
         # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
-        prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
+        prompt_header_seqs = [[128006, 9125, 128007], [128006, 882, 128007]]
         for n, idx in enumerate(eot_indices):
-            current_seq = labels[last_idx:idx+1]
-            if check_header(prompt_header_seqs,current_seq):
+            current_seq = labels[last_idx : idx + 1]
+            if check_header(prompt_header_seqs, current_seq):
                 # found prompt header, indicating that this seq should be masked
-                labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
+                labels[last_idx : idx + 1] = [-100] * (idx - last_idx + 1)
             else:
-                last_idx = idx+1
+                last_idx = idx + 1
         # Mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
         assistant_header_seq = [128006, 78191, 128007]
-        labels = replace_target(assistant_header_seq,labels)
-        # Mask the padding token and image token 128256 
+        labels = replace_target(assistant_header_seq, labels)
+        # Mask the padding token and image token 128256
         for i in range(len(labels)):
-            if labels[i] == processor.tokenizer.pad_token_id or labels[i] == 128256: # 128256 is image token index
+            if (
+                labels[i] == processor.tokenizer.pad_token_id or labels[i] == 128256
+            ):  # 128256 is image token index
                 labels[i] = -100
         label_list.append(labels)
     batch["labels"] = torch.tensor(label_list)
@@ -52,39 +66,75 @@ def tokenize_dialogs(dialogs, images, processor):
 def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
     # load_dataset will return DatasetDict that contains all the data in the train set
     dataset_dict = load_dataset("HuggingFaceM4/the_cauldron", name="ocrvqa")
-    dataset = dataset_dict['train']
+    dataset = dataset_dict["train"]
     # Comment out the following line to use the full dataset, for quick testing only use 2000 samples
     dataset = dataset.select(range(2000))
-    dataset = dataset.train_test_split(test_size=1-split_ratio, shuffle=True, seed=42)[split]
+    dataset = dataset.train_test_split(
+        test_size=1 - split_ratio, shuffle=True, seed=42
+    )[split]
     return dataset
+
 
 class OCRVQADataCollator:
     def __init__(self, processor):
         self.processor = processor
-        self.processor.tokenizer.padding_side = "right" # during training, one always uses padding on the right
+        self.processor.tokenizer.padding_side = (
+            "right"  # during training, one always uses padding on the right
+        )
+
     def __call__(self, samples):
-        dialogs,images = [],[]
+        dialogs, images = [], []
        for sample in samples:
-            image_list,sample_list = sample["images"],sample["texts"]
+            image_list, sample_list = sample["images"], sample["texts"]
             if len(image_list) > 1:
                 raise ValueError("Only support one image per sample")
-            image = image_list[0].convert("RGB") # only use the first image
+            image = image_list[0].convert("RGB")  # only use the first image
             dialog = []
             for sample_dict in sample_list:
                 if not dialog:
                     # only append image to the first sentence
                     dialog += [
-                    {"role":"user","content":[{"type": "image"},{"type": "text", "text": sample_dict["user"].strip()}]},
-                    {"role":"assistant","content":[{"type": "text", "text": sample_dict["assistant"].strip()}]}
-                ]
-
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "image"},
+                                {"type": "text", "text": sample_dict["user"].strip()},
+                            ],
+                        },
+                        {
+                            "role": "assistant",
+                            "content": [
+                                {
+                                    "type": "text",
+                                    "text": sample_dict["assistant"].strip(),
+                                }
+                            ],
+                        },
+                    ]
+
                 else:
                     dialog += [
-                    {"role":"user","content":[{"type": "text", "text": sample_dict["user"].strip()}]},
-                    {"role":"assistant","content":[{"type": "text", "text": sample_dict["assistant"].strip()}]}
-                ]
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "text", "text": sample_dict["user"].strip()}
+                            ],
+                        },
+                        {
+                            "role": "assistant",
+                            "content": [
+                                {
+                                    "type": "text",
+                                    "text": sample_dict["assistant"].strip(),
+                                }
+                            ],
+                        },
+                    ]
             dialogs.append(dialog)
             images.append([image])
-        return tokenize_dialogs(dialogs,images, self.processor)
+        return tokenize_dialogs(dialogs, images, self.processor)
+
+
 def get_data_collator(processor):
     return OCRVQADataCollator(processor)
+
diff --git a/recipes/quickstart/inference/local_inference/multi_modal_infer.py b/recipes/quickstart/inference/local_inference/multi_modal_infer.py
index 071dc868345acb642cb4d1bb19707ae5e08b8683..a7f9089c3a47f6c66fe9e24bf5a0c6e06db77ed4 100644
--- a/recipes/quickstart/inference/local_inference/multi_modal_infer.py
+++ b/recipes/quickstart/inference/local_inference/multi_modal_infer.py
@@ -1,13 +1,15 @@
 import argparse
 import os
 import sys
+
+import gradio as gr
 import torch
 from accelerate import Accelerator
+from huggingface_hub import HfFolder
+from peft import PeftModel
 from PIL import Image as PIL_Image
 from transformers import MllamaForConditionalGeneration, MllamaProcessor
-from peft import PeftModel
-import gradio as gr
-from huggingface_hub import HfFolder
+
 # Initialize accelerator
 accelerator = Accelerator()
 device = accelerator.device
@@ -43,24 +45,24 @@ def load_model_and_processor(model_name: str, finetuning_path: str = None):
         torch_dtype=torch.bfloat16,
         use_safetensors=True,
         device_map=device,
-        token=hf_token
+        token=hf_token,
+    )
+    processor = MllamaProcessor.from_pretrained(
+        model_name, token=hf_token, use_safetensors=True
     )
-    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token, use_safetensors=True)
 
     if finetuning_path and os.path.exists(finetuning_path):
         print(f"Loading LoRA adapter from '{finetuning_path}'...")
         model = PeftModel.from_pretrained(
-            model,
-            finetuning_path,
-            is_adapter=True,
-            torch_dtype=torch.bfloat16
+            model, finetuning_path, is_adapter=True, torch_dtype=torch.bfloat16
         )
         print("LoRA adapter merged successfully")
-
+
     model, processor = accelerator.prepare(model, processor)
     return model, processor
 
-def process_image(image_path: str = None, image = None) -> PIL_Image.Image:
+
+def process_image(image_path: str = None, image=None) -> PIL_Image.Image:
     """Process and validate image input"""
     if image is not None:
         return image.convert("RGB")
@@ -68,29 +70,44 @@ def process_image(image_path: str = None, image = None) -> PIL_Image.Image:
         return PIL_Image.open(image_path).convert("RGB")
     raise ValueError("No valid image provided")
 
-def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
+
+def generate_text_from_image(
+    model, processor, image, prompt_text: str, temperature: float, top_p: float
+):
     """Generate text from image using model"""
     conversation = [
-        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
+        {
+            "role": "user",
+            "content": [{"type": "image"}, {"type": "text", "text": prompt_text}],
+        }
     ]
-    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
-    inputs = processor(image, prompt, return_tensors="pt").to(device)
-    output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=MAX_OUTPUT_TOKENS)
-    return processor.decode(output[0])[len(prompt):]
+    prompt = processor.apply_chat_template(
+        conversation, add_generation_prompt=True, tokenize=False
+    )
+    inputs = processor(
+        image, prompt, text_kwargs={"add_special_tokens": False}, return_tensors="pt"
+    ).to(device)
+    print("Input Prompt:\n", processor.tokenizer.decode(inputs.input_ids[0]))
+    output = model.generate(
+        **inputs, temperature=temperature, top_p=top_p, max_new_tokens=MAX_OUTPUT_TOKENS
+    )
+    return processor.decode(output[0])[len(prompt) :]
+
 
 def gradio_interface(model_name: str):
     """Create Gradio UI with LoRA support"""
     # Initialize model state
     current_model = {"model": None, "processor": None}
-
+
     def load_or_reload_model(enable_lora: bool, lora_path: str = None):
         current_model["model"], current_model["processor"] = load_model_and_processor(
-            model_name,
-            lora_path if enable_lora else None
+            model_name, lora_path if enable_lora else None
         )
         return "Model loaded successfully" + (" with LoRA" if enable_lora else "")
 
-    def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
+    def describe_image(
+        image, user_prompt, temperature, top_k, top_p, max_tokens, history
+    ):
         if image is not None:
             try:
                 processed_image = process_image(image=image)
@@ -100,7 +117,7 @@ def gradio_interface(model_name: str):
                     processed_image,
                     user_prompt,
                     temperature,
-                    top_p
+                    top_p,
                 )
                 history.append((user_prompt, result))
             except Exception as e:
@@ -112,7 +129,7 @@ def gradio_interface(model_name: str):
 
     with gr.Blocks() as demo:
         gr.HTML("<h1 style='text-align: center'>Llama Vision Model Interface</h1>")
-
+
         with gr.Row():
             with gr.Column(scale=1):
                 # Model loading controls
@@ -121,58 +138,74 @@ def gradio_interface(model_name: str):
                 lora_path = gr.Textbox(
                     label="LoRA Weights Path",
                     placeholder="Path to LoRA weights folder",
-                    visible=False
+                    visible=False,
                 )
                 load_status = gr.Textbox(label="Load Status", interactive=False)
                 load_button = gr.Button("Load/Reload Model")
 
                 # Image and parameter controls
-                image_input = gr.Image(label="Image", type="pil", image_mode="RGB", height=512, width=512)
-                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1)
-                top_k = gr.Slider(label="Top-k", minimum=1, maximum=100, value=50, step=1)
-                top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
-                max_tokens = gr.Slider(label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50)
+                image_input = gr.Image(
+                    label="Image", type="pil", image_mode="RGB", height=512, width=512
+                )
+                temperature = gr.Slider(
+                    label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1
+                )
+                top_k = gr.Slider(
+                    label="Top-k", minimum=1, maximum=100, value=50, step=1
+                )
+                top_p = gr.Slider(
+                    label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1
+                )
+                max_tokens = gr.Slider(
+                    label="Max Tokens",
+                    minimum=50,
+                    maximum=MAX_OUTPUT_TOKENS,
+                    value=100,
+                    step=50,
+                )
 
             with gr.Column(scale=2):
                 chat_history = gr.Chatbot(label="Chat", height=512)
                 user_prompt = gr.Textbox(
-                    show_label=False,
-                    placeholder="Enter your prompt",
-                    lines=2
+                    show_label=False, placeholder="Enter your prompt", lines=2
                 )
-
+
                 with gr.Row():
                     generate_button = gr.Button("Generate")
                     clear_button = gr.Button("Clear")
 
         # Event handlers
         enable_lora.change(
-            fn=lambda x: gr.update(visible=x),
-            inputs=[enable_lora],
-            outputs=[lora_path]
+            fn=lambda x: gr.update(visible=x), inputs=[enable_lora], outputs=[lora_path]
         )
-
+
         load_button.click(
             fn=load_or_reload_model,
             inputs=[enable_lora, lora_path],
-            outputs=[load_status]
+            outputs=[load_status],
         )
 
         generate_button.click(
             fn=describe_image,
             inputs=[
-                image_input, user_prompt, temperature,
-                top_k, top_p, max_tokens, chat_history
+                image_input,
+                user_prompt,
+                temperature,
+                top_k,
+                top_p,
+                max_tokens,
+                chat_history,
             ],
-            outputs=[chat_history]
+            outputs=[chat_history],
         )
-
+
         clear_button.click(fn=clear_chat, outputs=[chat_history])
 
         # Initial model load
         load_or_reload_model(False)
 
     return demo
 
+
 def main(args):
     """Main execution flow"""
     if args.gradio_ui:
@@ -180,27 +213,30 @@ def main(args):
         demo.launch()
     else:
         model, processor = load_model_and_processor(
-            args.model_name,
-            args.finetuning_path
+            args.model_name, args.finetuning_path
         )
         image = process_image(image_path=args.image_path)
         result = generate_text_from_image(
-            model, processor, image,
-            args.prompt_text,
-            args.temperature,
-            args.top_p
+            model, processor, image, args.prompt_text, args.temperature, args.top_p
         )
         print("Generated Text:", result)
 
+
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Multi-modal inference with optional Gradio UI and LoRA support")
+    parser = argparse.ArgumentParser(
+        description="Multi-modal inference with optional Gradio UI and LoRA support"
+    )
     parser.add_argument("--image_path", type=str, help="Path to the input image")
     parser.add_argument("--prompt_text", type=str, help="Prompt text for the image")
-    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
+    parser.add_argument(
+        "--temperature", type=float, default=0.7, help="Sampling temperature"
+    )
     parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling")
-    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help="Model name")
+    parser.add_argument(
+        "--model_name", type=str, default=DEFAULT_MODEL, help="Model name"
+    )
     parser.add_argument("--finetuning_path", type=str, help="Path to LoRA weights")
     parser.add_argument("--gradio_ui", action="store_true", help="Launch Gradio UI")
-
+
     args = parser.parse_args()
     main(args)
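
For reference, a minimal, self-contained sketch (not part of the diff above) of how the label-masking helpers in ocrvqa_dataset.py behave on a toy token list. The special-token IDs 128006/882/78191/128007/128009 are the Llama 3 header and end-of-turn IDs cited in the diff's comments; the ordinary token IDs (101, 102, 201) are made up for illustration.

# Toy illustration of the masking helpers from ocrvqa_dataset.py.
# [128006, 882, 128007]   -> "<|start_header_id|>user<|end_header_id|>"
# [128006, 78191, 128007] -> "<|start_header_id|>assistant<|end_header_id|>"
# 128009                  -> "<|eot_id|>"


def check_header(targets, seq):
    # True if any 3-token header sequence in `targets` occurs in `seq`.
    for i in range(len(seq) - 3):
        if seq[i : i + 3] in targets:
            return True
    return False


def replace_target(target, seq):
    # Overwrite each occurrence of the 3-token `target` with -100 (ignored by the loss).
    for i in range(len(seq) - 3):
        if seq[i : i + 3] == target:
            seq[i], seq[i + 1], seq[i + 2] = -100, -100, -100
    return seq


user_header = [128006, 882, 128007]
assistant_header = [128006, 78191, 128007]

# Hypothetical label sequence: user turn, <|eot_id|>, then assistant turn.
labels = [128006, 882, 128007, 101, 102, 128009, 128006, 78191, 128007, 201, 128009]

print(check_header([user_header], labels))       # True -> this span would be masked entirely
print(replace_target(assistant_header, labels))  # assistant header tokens become -100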