diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index b899ec2905b8f5934fd92242391accd495aa6339..2ec5c2340fb067136b41128271d55bebf6577be7 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -179,12 +179,11 @@ def main(**kwargs):
     if not train_config.enable_fsdp or rank == 0:
         print(f"--> Validation Set Length = {len(dataset_val)}")
 
-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
-    val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
-
     if train_config.batching_strategy == "packing":
         dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
 
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
+
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
@@ -197,6 +196,9 @@ def main(**kwargs):
     if train_config.run_validation:
         if train_config.batching_strategy == "packing":
             dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
             num_workers=train_config.num_workers_dataloader,
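
A note on why the reordering matters: the dataloader kwargs are presumably built from the dataset they are handed (for example, a sampler sized by len(dataset)), so computing them before the ConcatDataset packing wrap ties them to the unpacked dataset rather than the one the DataLoader actually iterates. The sketch below is a minimal illustration under that assumption; ToyPacked and the constants are hypothetical stand-ins, not llama-recipes code.

    # Minimal sketch: a sampler generates indices over the dataset it was built
    # from, so it must be constructed from the dataset the DataLoader iterates.
    import torch
    from torch.utils.data import DataLoader, Dataset, DistributedSampler

    class ToyPacked(Dataset):
        """Hypothetical stand-in for ConcatDataset: packing shrinks the sample count."""
        def __init__(self, base_len: int, pack_factor: int = 4):
            self.length = base_len // pack_factor
        def __len__(self):
            return self.length
        def __getitem__(self, idx):
            return torch.tensor(idx)

    raw_len = 1000                 # samples before packing
    packed = ToyPacked(raw_len)    # 250 samples after packing

    # Wrong order: a sampler built from the raw dataset yields indices up to
    # raw_len - 1, which are out of range for the shorter packed dataset.
    # bad_sampler = DistributedSampler(range(raw_len), num_replicas=1, rank=0)

    # Right order (what the diff enforces): build the sampler from the packed
    # dataset that the DataLoader will iterate.
    sampler = DistributedSampler(packed, num_replicas=1, rank=0, shuffle=False)
    loader = DataLoader(packed, batch_size=8, sampler=sampler)
    print(len(sampler), len(loader))   # 250 indices, 32 batches

The same reasoning applies to the validation loader, which is why val_dl_kwargs now moves inside the run_validation branch, after dataset_val has (optionally) been wrapped for packing.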