From 57afa0b51e3b6290badefd2a2e96a355844a62c3 Mon Sep 17 00:00:00 2001
From: Kai Wu <kaiwu@meta.com>
Date: Tue, 24 Sep 2024 16:02:32 -0700
Subject: [PATCH] use AutoModel

---
 recipes/quickstart/finetuning/finetune_vision_model.md | 2 +-
 src/llama_recipes/finetuning.py                        | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/recipes/quickstart/finetuning/finetune_vision_model.md b/recipes/quickstart/finetuning/finetune_vision_model.md
index 097d683e..a4f6849c 100644
--- a/recipes/quickstart/finetuning/finetune_vision_model.md
+++ b/recipes/quickstart/finetuning/finetune_vision_model.md
@@ -1,7 +1,7 @@
 ## Fine-Tuning Meta Llama Multi Modal Models recipe
 This recipe steps you through how to finetune a Llama 3.2 vision model on the OCR VQA task using the [OCRVQA](https://huggingface.co/datasets/HuggingFaceM4/the_cauldron/viewer/ocrvqa?row=0) dataset.
 
-**Disclaimer**: As our vision models already have a very good OCR ability, here we just use the OCRVQA dataset to demonstrate the steps needed for fine-tuning our vision models.
+**Disclaimer**: As our vision models already have very good OCR ability, we use the OCRVQA dataset only to demonstrate the steps required to fine-tune our vision models with llama-recipes.
 
 ### Fine-tuning steps
 
diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 5c164694..029b13d5 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -20,9 +20,9 @@ from transformers import (
     AutoConfig,
     AutoTokenizer,
     BitsAndBytesConfig,
-    LlamaForCausalLM,
     AutoProcessor,
-    MllamaForConditionalGeneration
+    MllamaForConditionalGeneration,
+    AutoModelForCausalLM,
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 from transformers.models.mllama.modeling_mllama import MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer
@@ -134,7 +134,7 @@ def main(**kwargs):
         processor.tokenizer.padding_side='right'
     elif config.model_type == "llama":
         is_vision = False
-        model = LlamaForCausalLM.from_pretrained(
+        model = AutoModelForCausalLM.from_pretrained(
             train_config.model_name,
             quantization_config=bnb_config,
             use_cache=use_cache,
-- 
GitLab
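
For context, here is a minimal sketch (not part of the patch) of the dispatch pattern the change relies on: the `Auto*` classes read `model_type` from the checkpoint's config and resolve the matching architecture, so the text-only branch no longer hard-codes `LlamaForCausalLM`. The model id below is illustrative, not taken from the patch.

```python
# Sketch of the config-driven model loading used in finetuning.py.
# Assumes transformers >= 4.45 (which provides the Mllama classes).
from transformers import AutoConfig, AutoModelForCausalLM, MllamaForConditionalGeneration

model_name = "meta-llama/Llama-3.1-8B"  # illustrative checkpoint id

config = AutoConfig.from_pretrained(model_name)
if config.model_type == "mllama":
    # Llama 3.2 vision checkpoints need the conditional-generation class.
    model = MllamaForConditionalGeneration.from_pretrained(model_name)
elif config.model_type == "llama":
    # Text-only checkpoints: AutoModelForCausalLM resolves to LlamaForCausalLM
    # from the config, keeping the LM head needed to compute the training loss.
    model = AutoModelForCausalLM.from_pretrained(model_name)
```

Note the choice of `AutoModelForCausalLM` rather than the bare `AutoModel`: the base `AutoModel` class returns the backbone without a language-modeling head, so it accepts no `labels` argument and produces no `loss`, which the fine-tuning loop requires.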