diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index ec035d7564e4f358ab6b5279eb5d8ffd0b40098e..b899ec2905b8f5934fd92242391accd495aa6339 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -21,7 +21,8 @@ from transformers import (
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
-from llama_recipes.configs import fsdp_config, train_config
+from llama_recipes.configs import fsdp_config as FSDP_CONFIG
+from llama_recipes.configs import train_config as TRAIN_CONFIG
 from llama_recipes.data.concatenator import ConcatDataset
 from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
 
@@ -47,6 +48,7 @@ from llama_recipes.utils.train_utils import (
 
 def main(**kwargs):
     # Update the configuration for the training and sharding process
+    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
 
     # Set the seeds for reproducibility
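The finetuning.py hunk above stops mutating module-level config objects and instead imports the config classes under aliases, constructing fresh train_config and fsdp_config instances inside main() before update_config applies the caller's kwargs. A minimal sketch of that pattern, assuming the configs are plain dataclasses and using a simplified stand-in for update_config (field names and the helper body here are illustrative, not the repository's exact code):

    from dataclasses import dataclass, fields

    @dataclass
    class TRAIN_CONFIG:  # illustrative stand-in for llama_recipes.configs.train_config
        batching_strategy: str = "packing"
        use_peft: bool = False
        weight_decay: float = 0.0

    @dataclass
    class FSDP_CONFIG:   # illustrative stand-in for llama_recipes.configs.fsdp_config
        pure_bf16: bool = False

    def update_config(configs, **kwargs):
        # Simplified stand-in for the real helper: copy matching kwargs onto each config.
        for config in configs:
            for f in fields(config):
                if f.name in kwargs:
                    setattr(config, f.name, kwargs[f.name])

    # What the patched main() now does before using the configs:
    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
    update_config((train_config, fsdp_config), weight_decay=0.1, pure_bf16=True)
    assert train_config.weight_decay == 0.1 and fsdp_config.pure_bf16 is True
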
diff --git a/tests/test_finetuning.py b/tests/test_finetuning.py
index e0d6b324aee89478cb0fa67f7af4cea8f138ab76..45308bb9e10675d82c0501ba61bf4c4ef07691d0 100644
--- a/tests/test_finetuning.py
+++ b/tests/test_finetuning.py
@@ -101,8 +101,8 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
 
     get_dataset.return_value = get_fake_dataset()
 
-    get_peft_model.return_value = Linear(1,1)
-    get_peft_model.return_value.print_trainable_parameters=lambda:None
+    get_model.return_value = Linear(1,1)
+
     main(**kwargs)
 
     assert train.call_count == 1
@@ -123,10 +123,7 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
-    kwargs = {
-        "batching_strategy": "packing",
-        "use_peft": False,
-        }
+    kwargs = {"batching_strategy": "packing"}
 
     get_dataset.return_value = get_fake_dataset()
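The test updates follow from that change: because main() now builds its configs from scratch on every call, state set by one test can no longer leak into the next, so test_batching_strategy can drop the explicit "use_peft": False, and with PEFT presumably left at its default the weight-decay test mocks the base model load (get_model) instead of get_peft_model and no longer needs the print_trainable_parameters stub. A self-contained illustration of the singleton-versus-instance difference the tests now rely on (all names here are hypothetical, not the repository's):

    from dataclasses import dataclass

    @dataclass
    class TrainCfg:  # hypothetical stand-in for the real train_config dataclass
        use_peft: bool = False

    shared_cfg = TrainCfg()  # module-level singleton, as before the patch

    def main_with_singleton(**kwargs):
        for k, v in kwargs.items():
            setattr(shared_cfg, k, v)
        return shared_cfg.use_peft

    def main_with_instance(**kwargs):
        cfg = TrainCfg()  # per-call instance, as after the patch
        for k, v in kwargs.items():
            setattr(cfg, k, v)
        return cfg.use_peft

    main_with_singleton(use_peft=True)
    assert main_with_singleton() is True        # stale state leaks across calls
    assert main_with_instance(use_peft=True) is True
    assert main_with_instance() is False        # each call starts from defaults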