diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index ec035d7564e4f358ab6b5279eb5d8ffd0b40098e..b899ec2905b8f5934fd92242391accd495aa6339 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -21,7 +21,8 @@ from transformers import (
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
-from llama_recipes.configs import fsdp_config, train_config
+from llama_recipes.configs import fsdp_config as FSDP_CONFIG
+from llama_recipes.configs import train_config as TRAIN_CONFIG
 from llama_recipes.data.concatenator import ConcatDataset
 from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
 
@@ -47,6 +48,7 @@ from llama_recipes.utils.train_utils import (
 
 def main(**kwargs):
     # Update the configuration for the training and sharding process
+    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
 
     # Set the seeds for reproducibility
diff --git a/tests/test_finetuning.py b/tests/test_finetuning.py
index e0d6b324aee89478cb0fa67f7af4cea8f138ab76..45308bb9e10675d82c0501ba61bf4c4ef07691d0 100644
--- a/tests/test_finetuning.py
+++ b/tests/test_finetuning.py
@@ -101,8 +101,8 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
 
     get_dataset.return_value = get_fake_dataset()
 
-    get_peft_model.return_value = Linear(1,1)
-    get_peft_model.return_value.print_trainable_parameters=lambda:None
+    get_model.return_value = Linear(1,1)
+
     main(**kwargs)
 
     assert train.call_count == 1
@@ -123,10 +123,7 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
-    kwargs = {
-        "batching_strategy": "packing",
-        "use_peft": False,
-    }
+    kwargs = {"batching_strategy": "packing"}
 
     get_dataset.return_value = get_fake_dataset()
 
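
For context, the finetuning.py hunks replace reliance on the module-level train_config/fsdp_config objects with fresh instances created inside main() and then overridden from **kwargs via update_config. The following is a minimal sketch of that pattern, not the library's actual code: the dataclass fields and the body of update_config shown here are illustrative assumptions; only the names TRAIN_CONFIG, FSDP_CONFIG, update_config, and main mirror the diff.

from dataclasses import dataclass, fields

# Illustrative stand-ins for llama_recipes.configs.train_config / fsdp_config;
# the real dataclasses define many more fields.
@dataclass
class TRAIN_CONFIG:
    batching_strategy: str = "packing"
    weight_decay: float = 0.0

@dataclass
class FSDP_CONFIG:
    pure_bf16: bool = False  # assumed field, for illustration only

def update_config(configs, **kwargs):
    # Assumed behavior: copy any kwarg whose name matches a dataclass field
    # onto the corresponding config instance.
    for cfg in configs:
        for f in fields(cfg):
            if f.name in kwargs:
                setattr(cfg, f.name, kwargs[f.name])

def main(**kwargs):
    # Fresh per-call instances, as in the patched finetuning.main, so repeated
    # calls (e.g. from tests) do not leak overrides through shared module state.
    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
    update_config((train_config, fsdp_config), **kwargs)
    return train_config, fsdp_config

# Usage sketch: main(weight_decay=0.01) yields configs with only that field overridden.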