From 5da84b29139ece3a8b0c5c6187f81e7feaa3c005 Mon Sep 17 00:00:00 2001
From: Matthias Reso <13337103+mreso@users.noreply.github.com>
Date: Mon, 2 Oct 2023 13:14:50 -0700
Subject: [PATCH] Fix usage of dataclass for train_config and fsdp_config

---
 src/llama_recipes/finetuning.py | 4 +++-
 tests/test_finetuning.py        | 9 +++------
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index ec035d75..b899ec29 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -21,7 +21,8 @@ from transformers import (
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
-from llama_recipes.configs import fsdp_config, train_config
+from llama_recipes.configs import fsdp_config as FSDP_CONFIG
+from llama_recipes.configs import train_config as TRAIN_CONFIG
 from llama_recipes.data.concatenator import ConcatDataset
 from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
 
@@ -47,6 +48,7 @@ from llama_recipes.utils.train_utils import (
 
 def main(**kwargs):
     # Update the configuration for the training and sharding process
+    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
 
     # Set the seeds for reproducibility
diff --git a/tests/test_finetuning.py b/tests/test_finetuning.py
index e0d6b324..45308bb9 100644
--- a/tests/test_finetuning.py
+++ b/tests/test_finetuning.py
@@ -101,8 +101,8 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
 
     get_dataset.return_value = get_fake_dataset()
 
-    get_peft_model.return_value = Linear(1,1)
-    get_peft_model.return_value.print_trainable_parameters=lambda:None
+    get_model.return_value = Linear(1,1)
+
     main(**kwargs)
 
     assert train.call_count == 1
@@ -123,10 +123,7 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
-    kwargs = {
-        "batching_strategy": "packing",
-        "use_peft": False,
-    }
+    kwargs = {"batching_strategy": "packing"}
 
     get_dataset.return_value = get_fake_dataset()
 
-- 
GitLab
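
Illustrative note (not part of the patch): the sketch below shows the pattern the fix introduces, instantiating the config dataclasses per call instead of mutating the imported classes themselves. The field names and the update_config helper here are simplified assumptions, not the real llama_recipes implementations; only the instantiation pattern in main() mirrors the patched code.

    # Minimal, self-contained sketch of the "instantiate, then update" pattern.
    # TRAIN_CONFIG / FSDP_CONFIG stand in for llama_recipes.configs.train_config
    # and llama_recipes.configs.fsdp_config; their fields are hypothetical.
    from dataclasses import dataclass, fields

    @dataclass
    class TRAIN_CONFIG:
        batching_strategy: str = "packing"
        use_peft: bool = False

    @dataclass
    class FSDP_CONFIG:
        pure_bf16: bool = False

    def update_config(configs, **kwargs):
        # Assumed behavior: copy matching keyword arguments onto each config object.
        for config in configs:
            for field in fields(config):
                if field.name in kwargs:
                    setattr(config, field.name, kwargs[field.name])

    def main(**kwargs):
        # The fix: build fresh instances so each call (e.g. each test) starts
        # from the defaults rather than from class-level state left by a
        # previous call.
        train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
        update_config((train_config, fsdp_config), **kwargs)
        return train_config, fsdp_config

    if __name__ == "__main__":
        print(main(batching_strategy="padding", pure_bf16=True))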