From 1de6ac3177b948ba3b48c579fc762e01c5b782f1 Mon Sep 17 00:00:00 2001
From: Matthias Reso <13337103+mreso@users.noreply.github.com>
Date: Mon, 14 Oct 2024 15:17:07 -0700
Subject: [PATCH] mock samsum in test_batching

---
 src/tests/test_batching.py | 27 ++++++++++++++++++++-------
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/src/tests/test_batching.py b/src/tests/test_batching.py
index b51e8691..5aed0a4c 100644
--- a/src/tests/test_batching.py
+++ b/src/tests/test_batching.py
@@ -2,8 +2,9 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import pytest
-from dataclasses import dataclass
 from contextlib import nullcontext
+from dataclasses import dataclass
+from datasets import Dataset
 from unittest.mock import patch
 
 @dataclass
@@ -12,19 +13,23 @@ class Config:
 EXPECTED_SAMPLE_NUMBER ={
     "meta-llama/Llama-2-7b-hf": {
-        "train": 96,
-        "eval": 42,
+        "train": 4,
+        "eval": 37,
     },
     "meta-llama/Meta-Llama-3.1-8B-Instruct": {
-        "train": 79,
-        "eval": 34,
+        "train": 3,
+        "eval": 30,
     },
     "fake_llama": {
-        "train": 50,
-        "eval": 21,
+        "train": 2,
+        "eval": 17,
     }
 }
 
 
+fake_samsum_dataset = 2048*[{'id': '420',
+            'dialogue': "Mario: It's a me, Mario!\nLuigi: It's a me, your brother!\nMario: I'm going to save the princess.\nLuigi: I'm going to help Mario.",
+            'summary': 'Mario and Luigi are going to save the princess.'}]
+
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
@@ -34,7 +39,9 @@ EXPECTED_SAMPLE_NUMBER ={
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
+@patch('llama_recipes.datasets.samsum_dataset.datasets')
 def test_packing(
+    datasets,
     step_lr,
     optimizer,
     get_model,
@@ -55,6 +62,8 @@ def test_packing(
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
     get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
     get_config.return_value = Config(model_type=model_type)
+
+    datasets.load_dataset.return_value = Dataset.from_list(fake_samsum_dataset)
 
     kwargs = {
         "model_name": llama_version,
@@ -106,7 +115,9 @@
 @patch('llama_recipes.finetuning.FSDP')
 @patch('llama_recipes.finetuning.torch.distributed.is_initialized')
 @patch('llama_recipes.utils.config_utils.dist')
+@patch('llama_recipes.datasets.samsum_dataset.datasets')
 def test_distributed_packing(
+    datasets,
     dist,
     is_initialized,
     fsdp,
@@ -137,6 +148,8 @@ def test_distributed_packing(
     cuda_is_available.return_value = False
     cuda_is_bf16_supported.return_value = False
 
+    datasets.load_dataset.return_value = Dataset.from_list(fake_samsum_dataset)
+
     rank = 1
     os.environ['LOCAL_RANK'] = f'{rank}'
     os.environ['RANK'] = f'{rank}'
-- 
GitLab
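
Note on the mocking pattern in this patch: unittest.mock.patch replaces a name where it is looked up, not where it is defined, so patching 'llama_recipes.datasets.samsum_dataset.datasets' only intercepts the download if samsum_dataset.py binds the library at module level (import datasets) and then calls datasets.load_dataset(...). The smaller EXPECTED_SAMPLE_NUMBER values follow from the same change: packing now runs over 2048 copies of one short fake dialogue rather than the real SAMSum corpus. Below is a minimal, self-contained sketch of the pattern; load_samsum, fake_rows, and the test function are hypothetical stand-ins for illustration, not part of this patch:

    from unittest.mock import patch

    import datasets                  # module-level binding, as samsum_dataset.py is assumed to have
    from datasets import Dataset

    # Same row shape as fake_samsum_dataset in the patch, just shorter.
    fake_rows = 4 * [{"id": "420",
                      "dialogue": "Mario: It's a me, Mario!\nLuigi: It's a me, your brother!",
                      "summary": "Mario and Luigi talk."}]

    def load_samsum(split):
        # Hypothetical stand-in for the dataset-loading code under test.
        return datasets.load_dataset("samsum", split=split)

    def test_load_samsum_offline():
        # Patch the name in *this* module, the place where it is looked up.
        with patch(f"{__name__}.datasets") as mock_datasets:
            mock_datasets.load_dataset.return_value = Dataset.from_list(fake_rows)
            ds = load_samsum("train")
            assert len(ds) == 4      # served from fake_rows, no network access
            mock_datasets.load_dataset.assert_called_once_with("samsum", split="train")

The same reasoning explains why the tests receive the mock as a new first parameter (datasets): stacked @patch decorators inject their mocks bottom-up, and this decorator is the last one applied, so its mock is passed first.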