diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 7715396cd19a07b0f36ba2f11f8ee738952cd5bf..c391d75ca42f6e7302fb37f3a7b4d89d8db08cdd 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -203,6 +203,7 @@ def main(**kwargs):
         collate_fn=default_data_collator,
     )
 
+    eval_dataloader = None
     if train_config.run_validation:
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index fc65553f0887f7791df1a9c23a5ff977d489b730..a69e947a19f129184a9f40c26fc3c1f60108657e 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -179,14 +179,13 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 print(f"best eval loss on epoch {epoch} is {best_val_loss}")
             val_loss.append(best_val_loss)
             val_prep.append(eval_ppl)
-
         if train_config.enable_fsdp:
             if rank==0:
                 print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
         else:
             print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
-    avg_epoch_time = sum(epoch_times)/ len(epoch_times)
-    avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times)
+    avg_epoch_time = sum(epoch_times)/ len(epoch_times)
+    avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_loss = sum(train_loss)/len(train_loss)
     if train_config.run_validation:
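
The diff addresses two failure modes when `train_config.run_validation` is disabled: `eval_dataloader` would otherwise be undefined when passed to `train()`, and `checkpoint_times` can be empty, making the average a division by zero. Below is a minimal standalone sketch (not code from the repo; the helper names `summarize_times` and `build_dataloaders` are hypothetical) illustrating the same two guards:

```python
def summarize_times(epoch_times, checkpoint_times):
    # epoch_times is assumed non-empty once at least one epoch has run.
    avg_epoch_time = sum(epoch_times) / len(epoch_times)
    # checkpoint_times may be empty if no checkpoint was ever saved,
    # so guard the division exactly as the diff does.
    avg_checkpoint_time = (
        sum(checkpoint_times) / len(checkpoint_times) if checkpoint_times else 0
    )
    return avg_epoch_time, avg_checkpoint_time


def build_dataloaders(run_validation, dataset_train, dataset_val=None):
    train_loader = list(dataset_train)  # stand-in for a torch DataLoader
    # Initialize to None so the variable always exists, even when
    # validation is disabled (mirrors `eval_dataloader = None` above).
    eval_loader = None
    if run_validation:
        eval_loader = list(dataset_val)
    return train_loader, eval_loader


print(summarize_times([12.3, 11.9], []))        # -> (12.1, 0)
print(build_dataloaders(False, [1, 2, 3]))      # -> ([1, 2, 3], None)
```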