From dd8ca3c211b67dbfd12bbf7d5fc3fbf608349921 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Wed, 25 Sep 2024 16:46:57 -0700 Subject: [PATCH] Fix test_custom_dataset.py --- src/tests/datasets/test_custom_dataset.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/tests/datasets/test_custom_dataset.py b/src/tests/datasets/test_custom_dataset.py index 7cf8abe3..af424335 100644 --- a/src/tests/datasets/test_custom_dataset.py +++ b/src/tests/datasets/test_custom_dataset.py @@ -37,7 +37,7 @@ def check_padded_entry(batch, tokenizer): @pytest.mark.skip_missing_tokenizer @patch('llama_recipes.finetuning.train') @patch('llama_recipes.finetuning.AutoTokenizer') -@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained') +@patch('llama_recipes.finetuning.AutoModel.from_pretrained') @patch('llama_recipes.finetuning.optim.AdamW') @patch('llama_recipes.finetuning.StepLR') def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer, llama_version): @@ -96,15 +96,17 @@ def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, @patch('llama_recipes.finetuning.train') -@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained') +@patch('llama_recipes.finetuning.AutoConfig.from_pretrained') +@patch('llama_recipes.finetuning.AutoModel.from_pretrained') @patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained') @patch('llama_recipes.finetuning.optim.AdamW') @patch('llama_recipes.finetuning.StepLR') -def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train, mocker, llama_version): +def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, get_config, train, mocker, llama_version): from llama_recipes.finetuning import main tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]}) get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256] + get_config.return_value.model_type = "llama" kwargs = { "dataset": "custom_dataset", -- GitLab