From 4c225c65eb9e24fb93de733cf5418d46f2653bcf Mon Sep 17 00:00:00 2001
From: Matthias Reso <13337103+mreso@users.noreply.github.com>
Date: Tue, 3 Oct 2023 02:56:44 -0700
Subject: [PATCH] Fix order of concat vs sampler

---
 src/llama_recipes/finetuning.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index b899ec29..2ec5c234 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -179,12 +179,11 @@ def main(**kwargs):
         if not train_config.enable_fsdp or rank == 0:
             print(f"--> Validation Set Length = {len(dataset_val)}")
 
-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
-    val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
-
     if train_config.batching_strategy == "packing":
         dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
 
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
+
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
@@ -197,6 +196,9 @@ def main(**kwargs):
     if train_config.run_validation:
         if train_config.batching_strategy == "packing":
             dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+
        eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
             num_workers=train_config.num_workers_dataloader,
-- 
GitLab
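
Why the order matters, as a minimal standalone sketch (not the llama_recipes implementation): get_dataloader_kwargs builds the sampler from the dataset it is handed, so calling it before the packing step sizes the sampler for the unpacked dataset. Once ConcatDataset collapses the samples into fixed-length chunks, a sampler built earlier refers to indices and a length that no longer match the packed dataset. The ToyDataset and ToyConcatDataset classes below are hypothetical stand-ins used only to demonstrate the effect with plain torch.utils.data.

# Toy illustration of "pack first, then build the sampler".
import torch
from torch.utils.data import Dataset, RandomSampler, DataLoader


class ToyDataset(Dataset):
    """Stands in for dataset_train: many short samples."""
    def __init__(self, n=100):
        self.data = [torch.arange(10) for _ in range(n)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class ToyConcatDataset(Dataset):
    """Stands in for a packing ConcatDataset: concatenates samples into
    fixed-size chunks, so the packed dataset is much shorter than the original."""
    def __init__(self, dataset, chunk_size=50):
        flat = torch.cat([dataset[i] for i in range(len(dataset))])
        n_chunks = len(flat) // chunk_size
        self.chunks = [flat[i * chunk_size:(i + 1) * chunk_size] for i in range(n_chunks)]

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        return self.chunks[idx]


dataset = ToyDataset(n=100)               # len(dataset) == 100

# Old order (removed by the patch): sampler sized from the unpacked dataset.
stale_sampler = RandomSampler(dataset)    # draws indices in [0, 100)
packed = ToyConcatDataset(dataset)        # len(packed) == 20 after packing
# DataLoader(packed, sampler=stale_sampler) would request indices >= len(packed).

# New order (what the patch does): pack first, then build the sampler.
fresh_sampler = RandomSampler(packed)     # draws indices in [0, 20)
loader = DataLoader(packed, sampler=fresh_sampler, batch_size=4)
print(len(packed), len(list(loader)))     # 20 chunks -> 5 batches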