From 5b58afc75443df438e623e51f640e439e64fc9f9 Mon Sep 17 00:00:00 2001
From: Matthias Reso <13337103+mreso@users.noreply.github.com>
Date: Wed, 30 Aug 2023 22:13:37 +0000
Subject: [PATCH] Fix div by zero if run_validation=False

---
 src/llama_recipes/finetuning.py        | 1 +
 src/llama_recipes/utils/train_utils.py | 5 ++---
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 7715396c..c391d75c 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -203,6 +203,7 @@ def main(**kwargs):
         collate_fn=default_data_collator,
     )
 
+    eval_dataloader = None
     if train_config.run_validation:
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index fc65553f..a69e947a 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -179,14 +179,13 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     print(f"best eval loss on epoch {epoch} is {best_val_loss}")
             val_loss.append(best_val_loss)
             val_prep.append(eval_ppl)
-        
         if train_config.enable_fsdp:
             if rank==0:
                 print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
         else:
             print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
-    avg_epoch_time = sum(epoch_times)/ len(epoch_times) 
-    avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times)   
+    avg_epoch_time = sum(epoch_times)/ len(epoch_times)
+    avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_loss = sum(train_loss)/len(train_loss)
     if train_config.run_validation:
-- 
GitLab