Skip to content
Snippets Groups Projects
Commit 5b58afc7 authored by Matthias Reso's avatar Matthias Reso
Browse files

Fix div by zero if run_validation=False

parent 27170481
No related branches found
No related tags found
No related merge requests found
...@@ -203,6 +203,7 @@ def main(**kwargs): ...@@ -203,6 +203,7 @@ def main(**kwargs):
collate_fn=default_data_collator, collate_fn=default_data_collator,
) )
eval_dataloader = None
if train_config.run_validation: if train_config.run_validation:
eval_dataloader = torch.utils.data.DataLoader( eval_dataloader = torch.utils.data.DataLoader(
dataset_val, dataset_val,
......
...@@ -179,14 +179,13 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche ...@@ -179,14 +179,13 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
print(f"best eval loss on epoch {epoch} is {best_val_loss}") print(f"best eval loss on epoch {epoch} is {best_val_loss}")
val_loss.append(best_val_loss) val_loss.append(best_val_loss)
val_prep.append(eval_ppl) val_prep.append(eval_ppl)
if train_config.enable_fsdp: if train_config.enable_fsdp:
if rank==0: if rank==0:
print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s") print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
else: else:
print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s") print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
avg_epoch_time = sum(epoch_times)/ len(epoch_times) avg_epoch_time = sum(epoch_times)/ len(epoch_times)
avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
avg_train_prep = sum(train_prep)/len(train_prep) avg_train_prep = sum(train_prep)/len(train_prep)
avg_train_loss = sum(train_loss)/len(train_loss) avg_train_loss = sum(train_loss)/len(train_loss)
if train_config.run_validation: if train_config.run_validation:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment