From aa5dee241a4473323a47f0c46379827932f85e51 Mon Sep 17 00:00:00 2001
From: Matthias Reso <13337103+mreso@users.noreply.github.com>
Date: Mon, 2 Oct 2023 13:06:54 -0700
Subject: [PATCH] Fix unit tests to reflect batch packing

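The finetuning tests previously stubbed get_dataset with a bare [[1]],
which does not match what the packing-based batching path expects.
Replace the stub with a get_fake_dataset() helper that returns a single
sample dict carrying input_ids, attention_mask and labels, and have
test_batching_strategy pass use_peft=False explicitly alongside the
"packing" batching strategy.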
---
 tests/test_finetuning.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/tests/test_finetuning.py b/tests/test_finetuning.py
index 6651a22a..e0d6b324 100644
--- a/tests/test_finetuning.py
+++ b/tests/test_finetuning.py
@@ -13,6 +13,15 @@ from torch.utils.data.sampler import BatchSampler
 from llama_recipes.finetuning import main
 from llama_recipes.data.sampler import LengthBasedBatchSampler
 
+
+def get_fake_dataset():
+    return [{
+        "input_ids":[1],
+        "attention_mask":[1],
+        "labels":[1],
+        }]
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -22,7 +31,7 @@ from llama_recipes.data.sampler import LengthBasedBatchSampler
 def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": False}
 
-    get_dataset.return_value = [[1]]
+    get_dataset.return_value = get_fake_dataset()
 
     main(**kwargs)
 
@@ -46,7 +55,8 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": True}
-    get_dataset.return_value = [[1]]
+
+    get_dataset.return_value = get_fake_dataset()
 
     main(**kwargs)
 
@@ -72,7 +82,7 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
 def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
     kwargs = {"use_peft": True}
 
-    get_dataset.return_value = [[1]]
+    get_dataset.return_value = get_fake_dataset()
 
     main(**kwargs)
 
@@ -89,7 +99,7 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
 def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train):
     kwargs = {"weight_decay": 0.01}
 
-    get_dataset.return_value = [[1]]
+    get_dataset.return_value = get_fake_dataset()
 
     get_peft_model.return_value = Linear(1,1)
     get_peft_model.return_value.print_trainable_parameters=lambda:None
@@ -113,9 +123,12 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
-    kwargs = {"batching_strategy": "packing"}
+    kwargs = {
+        "batching_strategy": "packing",
+        "use_peft": False,
+        }
 
-    get_dataset.return_value = [[1]]
+    get_dataset.return_value = get_fake_dataset()
 
     main(**kwargs)
 
-- 
GitLab