Skip to content
Snippets Groups Projects
Commit 1de6ac31 authored by Matthias Reso's avatar Matthias Reso
Browse files

mock samsum in test_batching

parent 26dff882
No related branches found
No related tags found
No related merge requests found
...@@ -2,8 +2,9 @@ ...@@ -2,8 +2,9 @@
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import pytest import pytest
from dataclasses import dataclass
from contextlib import nullcontext from contextlib import nullcontext
from dataclasses import dataclass
from datasets import Dataset
from unittest.mock import patch from unittest.mock import patch
@dataclass @dataclass
...@@ -12,19 +13,23 @@ class Config: ...@@ -12,19 +13,23 @@ class Config:
EXPECTED_SAMPLE_NUMBER ={ EXPECTED_SAMPLE_NUMBER ={
"meta-llama/Llama-2-7b-hf": { "meta-llama/Llama-2-7b-hf": {
"train": 96, "train": 4,
"eval": 42, "eval": 37,
}, },
"meta-llama/Meta-Llama-3.1-8B-Instruct": { "meta-llama/Meta-Llama-3.1-8B-Instruct": {
"train": 79, "train": 3,
"eval": 34, "eval": 30,
}, },
"fake_llama": { "fake_llama": {
"train": 50, "train": 2,
"eval": 21, "eval": 17,
} }
} }
fake_samsum_dataset = 2048*[{'id': '420',
'dialogue': "Mario: It's a me, Mario!\nLuigi: It's a me, your brother!\nMario: I'm going to save the princess.\nLuigi: I'm going to help Mario.",
'summary': 'Mario and Luigi are going to save the princess.'}]
@pytest.mark.skip_missing_tokenizer @pytest.mark.skip_missing_tokenizer
@patch('llama_recipes.finetuning.train') @patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.AutoTokenizer') @patch('llama_recipes.finetuning.AutoTokenizer')
...@@ -34,7 +39,9 @@ EXPECTED_SAMPLE_NUMBER ={ ...@@ -34,7 +39,9 @@ EXPECTED_SAMPLE_NUMBER ={
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained') @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.optim.AdamW') @patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR') @patch('llama_recipes.finetuning.StepLR')
@patch('llama_recipes.datasets.samsum_dataset.datasets')
def test_packing( def test_packing(
datasets,
step_lr, step_lr,
optimizer, optimizer,
get_model, get_model,
...@@ -55,6 +62,8 @@ def test_packing( ...@@ -55,6 +62,8 @@ def test_packing(
get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256] get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0] get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
get_config.return_value = Config(model_type=model_type) get_config.return_value = Config(model_type=model_type)
datasets.load_dataset.return_value = Dataset.from_list(fake_samsum_dataset)
kwargs = { kwargs = {
"model_name": llama_version, "model_name": llama_version,
...@@ -106,7 +115,9 @@ def test_packing( ...@@ -106,7 +115,9 @@ def test_packing(
@patch('llama_recipes.finetuning.FSDP') @patch('llama_recipes.finetuning.FSDP')
@patch('llama_recipes.finetuning.torch.distributed.is_initialized') @patch('llama_recipes.finetuning.torch.distributed.is_initialized')
@patch('llama_recipes.utils.config_utils.dist') @patch('llama_recipes.utils.config_utils.dist')
@patch('llama_recipes.datasets.samsum_dataset.datasets')
def test_distributed_packing( def test_distributed_packing(
datasets,
dist, dist,
is_initialized, is_initialized,
fsdp, fsdp,
...@@ -137,6 +148,8 @@ def test_distributed_packing( ...@@ -137,6 +148,8 @@ def test_distributed_packing(
cuda_is_available.return_value = False cuda_is_available.return_value = False
cuda_is_bf16_supported.return_value = False cuda_is_bf16_supported.return_value = False
datasets.load_dataset.return_value = Dataset.from_list(fake_samsum_dataset)
rank = 1 rank = 1
os.environ['LOCAL_RANK'] = f'{rank}' os.environ['LOCAL_RANK'] = f'{rank}'
os.environ['RANK'] = f'{rank}' os.environ['RANK'] = f'{rank}'
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment