Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Llama Recipes
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
mirrored_repos
MachineLearning
meta-llama
Llama Recipes
Commits
21e8368c
Commit
21e8368c
authored
5 months ago
by
JimChienTW
Browse files
Options
Downloads
Patches
Plain Diff
add freeze_LLM_only option for mllama finetuning
parent
b9fc1069
Branches
Branches containing commit
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
src/llama_recipes/configs/training.py
+1
-0
1 addition, 0 deletions
src/llama_recipes/configs/training.py
src/llama_recipes/finetuning.py
+13
-2
13 additions, 2 deletions
src/llama_recipes/finetuning.py
src/llama_recipes/utils/train_utils.py
+11
-1
11 additions, 1 deletion
src/llama_recipes/utils/train_utils.py
with
25 additions
and
3 deletions
src/llama_recipes/configs/training.py
+
1
−
0
View file @
21e8368c
...
@@ -35,6 +35,7 @@ class train_config:
...
@@ -35,6 +35,7 @@ class train_config:
output_dir
:
str
=
"
PATH/to/save/PEFT/model
"
output_dir
:
str
=
"
PATH/to/save/PEFT/model
"
freeze_layers
:
bool
=
False
freeze_layers
:
bool
=
False
num_freeze_layers
:
int
=
1
num_freeze_layers
:
int
=
1
freeze_LLM_only
:
bool
=
False
# Freeze self-attention layers in the language_model. Vision model, multi_modal_projector, cross-attention will be fine-tuned
quantization
:
str
=
None
quantization
:
str
=
None
one_gpu
:
bool
=
False
one_gpu
:
bool
=
False
save_model
:
bool
=
True
save_model
:
bool
=
True
...
...
This diff is collapsed.
Click to expand it.
src/llama_recipes/finetuning.py
+
13
−
2
View file @
21e8368c
...
@@ -38,6 +38,7 @@ from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
...
@@ -38,6 +38,7 @@ from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
from
llama_recipes.utils.train_utils
import
(
from
llama_recipes.utils.train_utils
import
(
clear_gpu_cache
,
clear_gpu_cache
,
freeze_transformer_layers
,
freeze_transformer_layers
,
freeze_LLM_only
,
get_policies
,
get_policies
,
print_model_size
,
print_model_size
,
setup
,
setup
,
...
@@ -193,8 +194,6 @@ def main(**kwargs):
...
@@ -193,8 +194,6 @@ def main(**kwargs):
)
)
model
.
resize_token_embeddings
(
len
(
tokenizer
))
model
.
resize_token_embeddings
(
len
(
tokenizer
))
print_model_size
(
model
,
train_config
,
rank
if
train_config
.
enable_fsdp
else
0
)
# Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
# Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
if
(
if
(
train_config
.
enable_fsdp
train_config
.
enable_fsdp
...
@@ -235,6 +234,10 @@ def main(**kwargs):
...
@@ -235,6 +234,10 @@ def main(**kwargs):
if
not
train_config
.
use_peft
and
train_config
.
freeze_layers
:
if
not
train_config
.
use_peft
and
train_config
.
freeze_layers
:
freeze_transformer_layers
(
model
,
train_config
.
num_freeze_layers
)
freeze_transformer_layers
(
model
,
train_config
.
num_freeze_layers
)
if
not
train_config
.
use_peft
and
train_config
.
freeze_LLM_only
and
config
.
model_type
==
"
mllama
"
:
freeze_LLM_only
(
model
)
mixed_precision_policy
,
wrapping_policy
=
get_policies
(
fsdp_config
,
rank
)
mixed_precision_policy
,
wrapping_policy
=
get_policies
(
fsdp_config
,
rank
)
# Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
# Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
...
@@ -255,6 +258,11 @@ def main(**kwargs):
...
@@ -255,6 +258,11 @@ def main(**kwargs):
device_id
=
torch
.
xpu
.
current_device
()
device_id
=
torch
.
xpu
.
current_device
()
elif
torch
.
cuda
.
is_available
():
elif
torch
.
cuda
.
is_available
():
device_id
=
torch
.
cuda
.
current_device
()
device_id
=
torch
.
cuda
.
current_device
()
if
train_config
.
freeze_LLM_only
:
use_orig_params
=
True
else
:
use_orig_params
=
False
model
=
FSDP
(
model
=
FSDP
(
model
,
model
,
auto_wrap_policy
=
(
auto_wrap_policy
=
(
...
@@ -282,6 +290,7 @@ def main(**kwargs):
...
@@ -282,6 +290,7 @@ def main(**kwargs):
if
train_config
.
low_cpu_fsdp
and
rank
!=
0
if
train_config
.
low_cpu_fsdp
and
rank
!=
0
else
None
else
None
),
),
use_orig_params
=
use_orig_params
,
)
)
if
fsdp_config
.
fsdp_activation_checkpointing
:
if
fsdp_config
.
fsdp_activation_checkpointing
:
model
.
enable_input_require_grads
()
model
.
enable_input_require_grads
()
...
@@ -298,6 +307,8 @@ def main(**kwargs):
...
@@ -298,6 +307,8 @@ def main(**kwargs):
else
:
else
:
dataset_processer
=
tokenizer
dataset_processer
=
tokenizer
print_model_size
(
model
,
train_config
,
rank
if
train_config
.
enable_fsdp
else
0
)
# Load and preprocess the dataset for training and validation
# Load and preprocess the dataset for training and validation
dataset_train
=
get_preprocessed_dataset
(
dataset_train
=
get_preprocessed_dataset
(
...
...
This diff is collapsed.
Click to expand it.
src/llama_recipes/utils/train_utils.py
+
11
−
1
View file @
21e8368c
...
@@ -409,7 +409,17 @@ def freeze_transformer_layers(model, num_layer):
...
@@ -409,7 +409,17 @@ def freeze_transformer_layers(model, num_layer):
if
i
<
num_layer
:
if
i
<
num_layer
:
for
param
in
layer
.
parameters
():
for
param
in
layer
.
parameters
():
param
.
requires_grad
=
False
param
.
requires_grad
=
False
def
freeze_LLM_only
(
model
):
"""
Freeze self-attention layers in the language_model. vision_model, multi_modal_projector, and cross-attention layers will be fine-tuned
"""
for
name
,
param
in
model
.
language_model
.
named_parameters
():
param
.
requires_grad
=
False
for
i
,
layer
in
enumerate
(
model
.
language_model
.
model
.
layers
):
if
i
in
model
.
language_model
.
model
.
cross_attention_layers
:
for
param
in
layer
.
parameters
():
param
.
requires_grad
=
True
def
check_frozen_layers_peft_model
(
model
):
def
check_frozen_layers_peft_model
(
model
):
for
i
,
layer
in
enumerate
(
model
.
base_model
.
model
.
model
.
layers
):
for
i
,
layer
in
enumerate
(
model
.
base_model
.
model
.
model
.
layers
):
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment