From 57afa0b51e3b6290badefd2a2e96a355844a62c3 Mon Sep 17 00:00:00 2001
From: Kai Wu <kaiwu@meta.com>
Date: Tue, 24 Sep 2024 16:02:32 -0700
Subject: [PATCH] use AutoModel

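Load text-only Llama checkpoints through the generic Auto class instead of the
hard-coded LlamaForCausalLM, so the same branch handles any checkpoint whose
AutoConfig reports model_type "llama". A rough sketch of the resulting
selection logic in finetuning.py (simplified; the mllama branch is abridged
from the surrounding code and quantization/dtype arguments are omitted):

    config = AutoConfig.from_pretrained(train_config.model_name)
    if config.model_type == "mllama":
        # Vision checkpoints keep the dedicated multimodal class and processor.
        is_vision = True
        model = MllamaForConditionalGeneration.from_pretrained(train_config.model_name)
        processor = AutoProcessor.from_pretrained(train_config.model_name)
        processor.tokenizer.padding_side = "right"
    elif config.model_type == "llama":
        # Text-only checkpoints go through the generic causal-LM auto class.
        is_vision = False
        model = AutoModelForCausalLM.from_pretrained(train_config.model_name)
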
---
 recipes/quickstart/finetuning/finetune_vision_model.md | 2 +-
 src/llama_recipes/finetuning.py                        | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/recipes/quickstart/finetuning/finetune_vision_model.md b/recipes/quickstart/finetuning/finetune_vision_model.md
index 097d683e..a4f6849c 100644
--- a/recipes/quickstart/finetuning/finetune_vision_model.md
+++ b/recipes/quickstart/finetuning/finetune_vision_model.md
@@ -1,7 +1,7 @@
 ## Fine-Tuning Meta Llama Multi Modal Models recipe
 This recipe steps you through how to finetune a Llama 3.2 vision model on the OCR VQA task using the [OCRVQA](https://huggingface.co/datasets/HuggingFaceM4/the_cauldron/viewer/ocrvqa?row=0) dataset.
 
-**Disclaimer**: As our vision models already have a very good OCR ability, here we just use the OCRVQA dataset to demonstrate the steps needed for fine-tuning our vision models.
+**Disclaimer**: As our vision models already have very good OCR ability, we use the OCRVQA dataset only to demonstrate the steps required to fine-tune our vision models with llama-recipes.
 
 ### Fine-tuning steps
 
diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 5c164694..029b13d5 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -20,9 +20,9 @@ from transformers import (
     AutoConfig,
     AutoTokenizer,
     BitsAndBytesConfig,
-    LlamaForCausalLM,
     AutoProcessor, 
-    MllamaForConditionalGeneration
+    MllamaForConditionalGeneration,
+    AutoModelForCausalLM,
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 from transformers.models.mllama.modeling_mllama import  MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer
@@ -134,7 +134,7 @@ def main(**kwargs):
         processor.tokenizer.padding_side='right'
     elif config.model_type == "llama":
         is_vision = False
-        model = LlamaForCausalLM.from_pretrained(
+        model = AutoModelForCausalLM.from_pretrained(
             train_config.model_name,
             quantization_config=bnb_config,
             use_cache=use_cache,
-- 
GitLab