From 41a46d811d2b61922cf449550230e74958dd688f Mon Sep 17 00:00:00 2001
From: Kai Wu <kaiwu@meta.com>
Date: Thu, 23 May 2024 10:37:22 -0700
Subject: [PATCH] fix alpaca dataset by using 5% of the data as eval and make
 sure len(eval_loader) > 0

---
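Note (ignored by git am): a minimal standalone sketch of the new
partitioning logic, assuming self.ann is the list loaded from the
Alpaca JSON file; the helper name is hypothetical.

    # Reserve 5% of the examples (1/20) for evaluation instead of a
    # fixed slice of 200, so the split scales with the dataset size.
    def partition_annotations(ann, partition="train"):
        eval_length = len(ann) // 20
        if partition == "train":
            return ann[eval_length:]
        return ann[:eval_length]
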
 src/llama_recipes/datasets/alpaca_dataset.py | 6 ++++--
 src/llama_recipes/finetuning.py              | 4 ++++
 2 files changed, 8 insertions(+), 2 deletions(-)
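
Note (ignored by git am): why the len(eval_dataloader) guard below is
needed. Assuming the eval loader is built with drop_last=True (as the
batch sampler kwargs suggest; the sizes here are made up), an eval set
smaller than one batch yields a dataloader of length zero, and the
eval loop would silently never run:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    eval_set = TensorDataset(torch.zeros(3, 8))  # only 3 eval examples
    loader = DataLoader(eval_set, batch_size=8, drop_last=True)
    assert len(loader) == 0  # no full batch can be formed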

diff --git a/src/llama_recipes/datasets/alpaca_dataset.py b/src/llama_recipes/datasets/alpaca_dataset.py
index 21bd9643..396551d1 100644
--- a/src/llama_recipes/datasets/alpaca_dataset.py
+++ b/src/llama_recipes/datasets/alpaca_dataset.py
@@ -26,10 +26,12 @@ PROMPT_DICT = {
 class InstructionDataset(Dataset):
     def __init__(self, dataset_config, tokenizer, partition="train"):
         self.ann = json.load(open(dataset_config.data_path))
+        # Use 5% of the dataset for evaluation
+        eval_length = len(self.ann) // 20
         if partition == "train":
-            self.ann = self.ann[200:]
+            self.ann = self.ann[eval_length:]
         else:
-            self.ann = self.ann[:200]
+            self.ann = self.ann[:eval_length]
 
         self.tokenizer = tokenizer
 
diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 3fef7222..0bfea288 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -250,6 +250,10 @@ def main(**kwargs):
             pin_memory=True,
             **val_dl_kwargs,
         )
+        if len(eval_dataloader) == 0:
+            raise ValueError("The eval set is too small for the dataloader to load even one batch. Please increase the size of the eval set.")
+        else:
+            print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
 
     # Initialize the optimizer and learning rate scheduler
     if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
-- 
GitLab