From 5da84b29139ece3a8b0c5c6187f81e7feaa3c005 Mon Sep 17 00:00:00 2001
From: Matthias Reso <13337103+mreso@users.noreply.github.com>
Date: Mon, 2 Oct 2023 13:14:50 -0700
Subject: [PATCH] Fix usage of dataclass for train_config and fsdp_config

---
 src/llama_recipes/finetuning.py | 4 +++-
 tests/test_finetuning.py        | 9 +++------
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index ec035d75..b899ec29 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -21,7 +21,8 @@ from transformers import (
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
-from llama_recipes.configs import fsdp_config, train_config
+from llama_recipes.configs import fsdp_config as FSDP_CONFIG
+from llama_recipes.configs import train_config as TRAIN_CONFIG
 from llama_recipes.data.concatenator import ConcatDataset
 from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
 
@@ -47,6 +48,7 @@ from llama_recipes.utils.train_utils import (
 
 def main(**kwargs):
     # Update the configuration for the training and sharding process
+    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
 
     # Set the seeds for reproducibility
diff --git a/tests/test_finetuning.py b/tests/test_finetuning.py
index e0d6b324..45308bb9 100644
--- a/tests/test_finetuning.py
+++ b/tests/test_finetuning.py
@@ -101,8 +101,8 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
 
     get_dataset.return_value = get_fake_dataset()
 
-    get_peft_model.return_value = Linear(1,1)
-    get_peft_model.return_value.print_trainable_parameters=lambda:None
+    get_model.return_value = Linear(1,1)
+
     main(**kwargs)
 
     assert train.call_count == 1
@@ -123,10 +123,7 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
-    kwargs = {
-        "batching_strategy": "packing",
-        "use_peft": False,
-    }
+    kwargs = {"batching_strategy": "packing"}
 
     get_dataset.return_value = get_fake_dataset()
 
-- 
GitLab
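
Illustrative note (not part of the patch): the sketch below shows the pattern the fix introduces, instantiating the config dataclasses per call instead of mutating the imported classes themselves. The field names and the update_config helper here are simplified assumptions, not the real llama_recipes implementations; only the instantiation pattern in main() mirrors the patched code.

    # Minimal, self-contained sketch of the "instantiate, then update" pattern.
    # TRAIN_CONFIG / FSDP_CONFIG stand in for llama_recipes.configs.train_config
    # and llama_recipes.configs.fsdp_config; their fields are hypothetical.
    from dataclasses import dataclass, fields

    @dataclass
    class TRAIN_CONFIG:
        batching_strategy: str = "packing"
        use_peft: bool = False

    @dataclass
    class FSDP_CONFIG:
        pure_bf16: bool = False

    def update_config(configs, **kwargs):
        # Assumed behavior: copy matching keyword arguments onto each config object.
        for config in configs:
            for field in fields(config):
                if field.name in kwargs:
                    setattr(config, field.name, kwargs[field.name])

    def main(**kwargs):
        # The fix: build fresh instances so each call (e.g. each test) starts
        # from the defaults rather than from class-level state left by a
        # previous call.
        train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
        update_config((train_config, fsdp_config), **kwargs)
        return train_config, fsdp_config

    if __name__ == "__main__":
        print(main(batching_strategy="padding", pure_bf16=True))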