diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index b899ec2905b8f5934fd92242391accd495aa6339..2ec5c2340fb067136b41128271d55bebf6577be7 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -179,12 +179,11 @@ def main(**kwargs):
     if not train_config.enable_fsdp or rank == 0:
             print(f"--> Validation Set Length = {len(dataset_val)}")
 
-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
-    val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
-
     if train_config.batching_strategy == "packing":
         dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
 
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
+
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
@@ -197,6 +196,9 @@ def main(**kwargs):
     if train_config.run_validation:
         if train_config.batching_strategy == "packing":
             dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
             num_workers=train_config.num_workers_dataloader,