From 1de6ac3177b948ba3b48c579fc762e01c5b782f1 Mon Sep 17 00:00:00 2001
From: Matthias Reso <13337103+mreso@users.noreply.github.com>
Date: Mon, 14 Oct 2024 15:17:07 -0700
Subject: [PATCH] mock samsum in test_batching

---
 src/tests/test_batching.py | 27 ++++++++++++++++++++-------
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/src/tests/test_batching.py b/src/tests/test_batching.py
index b51e8691..5aed0a4c 100644
--- a/src/tests/test_batching.py
+++ b/src/tests/test_batching.py
@@ -2,8 +2,9 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import pytest
-from dataclasses import dataclass
 from contextlib import nullcontext
+from dataclasses import dataclass
+from datasets import Dataset
 from unittest.mock import patch
 
 @dataclass
@@ -12,19 +13,23 @@ class Config:
 
 EXPECTED_SAMPLE_NUMBER ={
     "meta-llama/Llama-2-7b-hf": {
-        "train": 96,
-        "eval": 42,
+        "train": 4,
+        "eval": 37,
     },
     "meta-llama/Meta-Llama-3.1-8B-Instruct": {
-        "train": 79,
-        "eval": 34,
+        "train": 3,
+        "eval": 30,
     },
     "fake_llama": {
-        "train": 50,
-        "eval": 21,
+        "train": 2,
+        "eval": 17,
     }
 }
 
+fake_samsum_dataset = 2048*[{'id': '420',
+ 'dialogue': "Mario: It's a me, Mario!\nLuigi: It's a me, your brother!\nMario: I'm going to save the princess.\nLuigi: I'm going to help Mario.",
+ 'summary': 'Mario and Luigi are going to save the princess.'}]
+
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
@@ -34,7 +39,9 @@ EXPECTED_SAMPLE_NUMBER ={
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
+@patch('llama_recipes.datasets.samsum_dataset.datasets')
 def test_packing(
+    datasets,
     step_lr,
     optimizer,
     get_model,
@@ -55,6 +62,8 @@ def test_packing(
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
     get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
     get_config.return_value = Config(model_type=model_type)
+
+    datasets.load_dataset.return_value = Dataset.from_list(fake_samsum_dataset)
     
     kwargs = {
         "model_name": llama_version,
@@ -106,7 +115,9 @@ def test_packing(
 @patch('llama_recipes.finetuning.FSDP')
 @patch('llama_recipes.finetuning.torch.distributed.is_initialized')
 @patch('llama_recipes.utils.config_utils.dist')
+@patch('llama_recipes.datasets.samsum_dataset.datasets')
 def test_distributed_packing(
+    datasets,
     dist,
     is_initialized,
     fsdp,
@@ -137,6 +148,8 @@ def test_distributed_packing(
     cuda_is_available.return_value = False
     cuda_is_bf16_supported.return_value = False
 
+    datasets.load_dataset.return_value = Dataset.from_list(fake_samsum_dataset)
+
     rank = 1
     os.environ['LOCAL_RANK'] = f'{rank}'
     os.environ['RANK'] = f'{rank}'
-- 
GitLab