From 4c225c65eb9e24fb93de733cf5418d46f2653bcf Mon Sep 17 00:00:00 2001
From: Matthias Reso <13337103+mreso@users.noreply.github.com>
Date: Tue, 3 Oct 2023 02:56:44 -0700
Subject: [PATCH] Fix order of ConcatDataset wrapping vs. sampler creation

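The dataloader kwargs returned by get_dataloader_kwargs include the
sampler, which is sized from the dataset it is given. They were being
built before the datasets were wrapped in ConcatDataset for the
"packing" strategy, so the sampler was created for the unpacked
(longer) dataset while the DataLoader received the packed one. Building
the kwargs after the ConcatDataset wrapping keeps the sampler
consistent with the dataset that is actually loaded.

A minimal sketch of the failure mode (not part of this change), using a
toy dataset and SequentialSampler as a stand-in for the sampler that
get_dataloader_kwargs configures, with made-up lengths:

    import torch
    from torch.utils.data import DataLoader, Dataset, SequentialSampler

    class Toy(Dataset):
        def __init__(self, n):
            self.data = torch.arange(n)
        def __len__(self):
            return len(self.data)
        def __getitem__(self, i):
            return self.data[i]

    raw = Toy(100)                    # dataset before packing: 100 samples
    sampler = SequentialSampler(raw)  # sampler sized from the unpacked dataset
    packed = Toy(25)                  # stand-in for ConcatDataset: packing shrinks it
    loader = DataLoader(packed, sampler=sampler)
    for batch in loader:              # IndexError once indices exceed len(packed)
        pass

Creating the sampler from the packed dataset, as this patch does, avoids
the length mismatch.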
---
 src/llama_recipes/finetuning.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index b899ec29..2ec5c234 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -179,12 +179,11 @@ def main(**kwargs):
     if not train_config.enable_fsdp or rank == 0:
             print(f"--> Validation Set Length = {len(dataset_val)}")
 
-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
-    val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
-
     if train_config.batching_strategy == "packing":
         dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
 
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
+
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
@@ -197,6 +196,9 @@ def main(**kwargs):
     if train_config.run_validation:
         if train_config.batching_strategy == "packing":
             dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
             num_workers=train_config.num_workers_dataloader,
-- 
GitLab