From d1195a6fd88724ff6492099e7dec915736613f7c Mon Sep 17 00:00:00 2001
From: JimChienTW <jim87112729@gmail.com>
Date: Tue, 19 Nov 2024 14:31:37 +0800
Subject: [PATCH] Fix misreported model parameter count by calling
 print_model_size before FSDP wrapping

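print_model_size was previously called only after the model had been
wrapped with FSDP. With FSDP enabled, parameters are sharded across
ranks at wrap time, so the printed count reflected only the local
shard on each rank rather than the full model. Moving the call to
before the FSDP wrapper is created makes the full parameter count get
reported.

A minimal sketch of the effect (illustration only, not part of this
patch; it assumes torch.distributed is already initialized and that
`model` is a plain nn.Module):

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Full model: every parameter is still local to this process.
    full_count = sum(p.numel() for p in model.parameters())

    # After wrapping, parameters() yields the sharded per-rank storage,
    # so the same sum is roughly full_count / world_size.
    model = FSDP(model)
    local_count = sum(p.numel() for p in model.parameters())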
---
 src/llama_recipes/finetuning.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index e1d702e2..2a86234c 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -237,7 +237,8 @@ def main(**kwargs):
             
         if not train_config.use_peft and train_config.freeze_LLM_only and config.model_type == "mllama":
             freeze_LLM_only(model)
-            
+        
+        print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
@@ -306,8 +307,6 @@ def main(**kwargs):
         dataset_processer = processor
     else:
         dataset_processer = tokenizer
-
-    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
     
     # Load and preprocess the dataset for training and validation
 
-- 
GitLab