Commit d1195a6f authored by JimChienTW

Fix model parameter mismatch by printing parameters before FSDP

parent f228cb4d
@@ -237,7 +237,8 @@ def main(**kwargs):
     if not train_config.use_peft and train_config.freeze_LLM_only and config.model_type == "mllama":
         freeze_LLM_only(model)
 
+    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
 
     mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
     # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
@@ -306,8 +307,6 @@ def main(**kwargs):
         dataset_processer = processor
     else:
         dataset_processer = tokenizer
 
-    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
-
     # Load and preprocess the dataset for training and validation
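Why the move matters: once FSDP wraps the model, it flattens and shards the parameters across ranks, so iterating model.parameters() on any single rank no longer sees the full model. Counting before wrapping reports the true size. Below is a minimal sketch of the effect; count_params is a hypothetical helper for illustration, not the repository's print_model_size, and the FSDP lines are left commented out because they require an initialized process group.

    import torch.nn as nn

    def count_params(model: nn.Module) -> int:
        # Sum of element counts across all parameters visible to this process.
        return sum(p.numel() for p in model.parameters())

    model = nn.Linear(4096, 4096)
    print(count_params(model))  # 16781312 (4096*4096 + 4096): full model, counted pre-wrap

    # After FSDP wrapping, each rank holds only its shard of the flattened
    # parameters, so the same loop would report roughly total/world_size
    # (plus any padding) instead of the true model size:
    #
    # from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    # fsdp_model = FSDP(model)         # needs torch.distributed initialized
    # print(count_params(fsdp_model))  # local shard count, not 16781312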