diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 0e140a7971115b77000ced3cd7d04cdecd5c0e64..867ac4738e73c816dc2e5f6611975ee2ce292cf1 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -287,7 +287,7 @@ def main(**kwargs):
             )
         print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
         if len(eval_dataloader) == 0:
-            raise ValueError("The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set.")
+            raise ValueError(f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})")
         else:
             print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
diff --git a/src/tests/conftest.py b/src/tests/conftest.py
index 710ed74042ffdddc95f5ac328af980976c4d7c55..33e118fdefe0be3b436b7d2f5dc9579bf0fa4e08 100644
--- a/src/tests/conftest.py
+++ b/src/tests/conftest.py
@@ -13,6 +13,11 @@ def llama_version(request):
     return request.param
 
 
+@pytest.fixture(params=["mllama", "llama"])
+def model_type(request):
+    return request.param
+
+
 @pytest.fixture(scope="module")
 def llama_tokenizer(request):
     return {k: AutoTokenizer.from_pretrained(k) for k in LLAMA_VERSIONS}
diff --git a/src/tests/test_finetuning.py b/src/tests/test_finetuning.py
index 749f8614f706ac1cd13492c34d2aa68089490a35..b695577523bc75884950dd515cbffde316aed2ed 100644
--- a/src/tests/test_finetuning.py
+++ b/src/tests/test_finetuning.py
@@ -2,6 +2,8 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import os
+from contextlib import nullcontext
+from dataclasses import dataclass
 from unittest.mock import patch
 
 import pytest
@@ -16,8 +18,12 @@
 from torch.utils.data.dataloader import DataLoader
 from torch.utils.data.sampler import BatchSampler
 
+@dataclass
+class Config:
+    model_type: str = "llama"
+
 def get_fake_dataset():
-    return [
+    return 8192*[
         {
             "input_ids": [1],
             "attention_mask": [1],
@@ -28,28 +34,49 @@
 @patch("llama_recipes.finetuning.torch.cuda.is_available")
 @patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor.from_pretrained")
 @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
 @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
 @patch("llama_recipes.finetuning.get_preprocessed_dataset")
+@patch("llama_recipes.finetuning.generate_peft_config")
+@patch("llama_recipes.finetuning.get_peft_model")
 @patch("llama_recipes.finetuning.optim.AdamW")
 @patch("llama_recipes.finetuning.StepLR")
 @pytest.mark.parametrize("cuda_is_available", [True, False])
-def test_finetuning_no_validation(
+@pytest.mark.parametrize("run_validation", [True, False])
+@pytest.mark.parametrize("use_peft", [True, False])
+def test_finetuning(
     step_lr,
     optimizer,
+    get_peft_model,
+    gen_peft_config,
     get_dataset,
     tokenizer,
+    get_config,
     get_model,
+    get_processor,
+    get_mmodel,
     train,
     cuda,
     cuda_is_available,
+    run_validation,
+    use_peft,
+    model_type,
 ):
-    kwargs = {"run_validation": False}
+    kwargs = {
+        "run_validation": run_validation,
+        "use_peft": use_peft,
+        "batching_strategy": "packing" if model_type == "llama" else "padding",
+    }
 
     get_dataset.return_value = get_fake_dataset()
     cuda.return_value = cuda_is_available
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    get_config.return_value = Config(model_type=model_type)
 
     main(**kwargs)
@@ -60,115 +87,99 @@ def test_finetuning_no_validation(
     eval_dataloader = args[2]
     assert isinstance(train_dataloader, DataLoader)
-    assert eval_dataloader is None
-
-    if cuda_is_available:
-        assert get_model.return_value.to.call_count == 1
-        assert get_model.return_value.to.call_args.args[0] == "cuda"
+    if run_validation:
+        assert isinstance(eval_dataloader, DataLoader)
     else:
-        assert get_model.return_value.to.call_count == 0
-
-
-@patch("llama_recipes.finetuning.torch.cuda.is_available")
-@patch("llama_recipes.finetuning.train")
-@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
-@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
-@patch("llama_recipes.finetuning.get_preprocessed_dataset")
-@patch("llama_recipes.finetuning.optim.AdamW")
-@patch("llama_recipes.finetuning.StepLR")
-@pytest.mark.parametrize("cuda_is_available", [True, False])
-def test_finetuning_with_validation(
-    step_lr,
-    optimizer,
-    get_dataset,
-    tokenizer,
-    get_model,
-    train,
-    cuda,
-    cuda_is_available,
-):
-    kwargs = {"run_validation": True}
-
-    get_dataset.return_value = get_fake_dataset()
-    cuda.return_value = cuda_is_available
-
-    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
-
-    main(**kwargs)
-
-    assert train.call_count == 1
-
-    args, kwargs = train.call_args
-    train_dataloader = args[1]
-    eval_dataloader = args[2]
-    assert isinstance(train_dataloader, DataLoader)
-    assert isinstance(eval_dataloader, DataLoader)
+        assert eval_dataloader is None
 
-    if cuda_is_available:
-        assert get_model.return_value.to.call_count == 1
-        assert get_model.return_value.to.call_args.args[0] == "cuda"
+    if use_peft:
+        assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
+        model = get_peft_model
+    elif model_type == "llama":
+        model = get_model
     else:
-        assert get_model.return_value.to.call_count == 0
-
-
-@patch("llama_recipes.finetuning.torch.cuda.is_available")
-@patch("llama_recipes.finetuning.train")
-@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
-@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
-@patch("llama_recipes.finetuning.get_preprocessed_dataset")
-@patch("llama_recipes.finetuning.generate_peft_config")
-@patch("llama_recipes.finetuning.get_peft_model")
-@patch("llama_recipes.finetuning.optim.AdamW")
-@patch("llama_recipes.finetuning.StepLR")
-@pytest.mark.parametrize("cuda_is_available", [True, False])
-def test_finetuning_peft_lora(
-    step_lr,
-    optimizer,
-    get_peft_model,
-    gen_peft_config,
-    get_dataset,
-    tokenizer,
-    get_model,
-    train,
-    cuda,
-    cuda_is_available,
-):
-    kwargs = {"use_peft": True}
-
-    get_dataset.return_value = get_fake_dataset()
-    cuda.return_value = cuda_is_available
-
-    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
-
-    main(**kwargs)
+        model = get_mmodel
 
     if cuda_is_available:
-        assert get_peft_model.return_value.to.call_count == 1
-        assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
+        assert model.return_value.to.call_count == 1
+        assert model.return_value.to.call_args.args[0] == "cuda"
     else:
-        assert get_peft_model.return_value.to.call_count == 0
-
-    assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
+        assert model.return_value.to.call_count == 0
+
+
+# @patch("llama_recipes.finetuning.torch.cuda.is_available")
+# @patch("llama_recipes.finetuning.train")
+# @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+# @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
+# @patch("llama_recipes.finetuning.get_preprocessed_dataset")
+# @patch("llama_recipes.finetuning.generate_peft_config")
+# @patch("llama_recipes.finetuning.get_peft_model")
+# @patch("llama_recipes.finetuning.optim.AdamW")
+# @patch("llama_recipes.finetuning.StepLR")
+# @pytest.mark.parametrize("cuda_is_available", [True, False])
+# def test_finetuning_peft_lora(
+#     step_lr,
+#     optimizer,
+#     get_peft_model,
+#     gen_peft_config,
+#     get_dataset,
+#     tokenizer,
+#     get_model,
+#     train,
+#     cuda,
+#     cuda_is_available,
+# ):
+#     kwargs = {"use_peft": True}
+
+#     get_dataset.return_value = get_fake_dataset()
+#     cuda.return_value = cuda_is_available
+
+#     get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+
+#     main(**kwargs)
+
+#     if cuda_is_available:
+#         assert get_peft_model.return_value.to.call_count == 1
+#         assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
+#     else:
+#         assert get_peft_model.return_value.to.call_count == 0
+
+
 @patch("llama_recipes.finetuning.get_peft_model")
 @patch("llama_recipes.finetuning.setup")
 @patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor.from_pretrained")
 @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
 @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
 @patch("llama_recipes.finetuning.get_preprocessed_dataset")
 def test_finetuning_peft_llama_adapter(
-    get_dataset, tokenizer, get_model, train, setup, get_peft_model
+    get_dataset,
+    tokenizer,
+    get_config,
+    get_model,
+    get_processor,
+    get_mmodel,
+    train,
+    setup,
+    get_peft_model,
+    model_type,
 ):
     kwargs = {
         "use_peft": True,
         "peft_method": "llama_adapter",
         "enable_fsdp": True,
+        "batching_strategy": "packing" if model_type == "llama" else "padding",
     }
 
     get_dataset.return_value = get_fake_dataset()
 
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    get_config.return_value = Config(model_type=model_type)
 
     os.environ["RANK"] = "0"
     os.environ["LOCAL_RANK"] = "0"
@@ -195,20 +206,38 @@ def test_finetuning_peft_llama_adapter(
 @patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor.from_pretrained")
 @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
 @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
 @patch("llama_recipes.finetuning.get_preprocessed_dataset")
 @patch("llama_recipes.finetuning.get_peft_model")
 @patch("llama_recipes.finetuning.StepLR")
 def test_finetuning_weight_decay(
-    step_lr, get_peft_model, get_dataset, tokenizer, get_model, train
+    step_lr,
+    get_peft_model,
+    get_dataset,
+    tokenizer,
+    get_config,
+    get_model,
+    get_processor,
+    get_mmodel,
+    train,
+    model_type,
 ):
-    kwargs = {"weight_decay": 0.01}
+    kwargs = {
+        "weight_decay": 0.01,
+        "batching_strategy": "packing" if model_type == "llama" else "padding",
+    }
 
     get_dataset.return_value = get_fake_dataset()
 
-    get_model.return_value.parameters.return_value = [torch.ones(1, 1)]
-    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    model = get_model if model_type == "llama" else get_mmodel
+    model.return_value.parameters.return_value = [torch.ones(1, 1)]
+    model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+
+    get_config.return_value = Config(model_type=model_type)
 
     main(**kwargs)
@@ -224,28 +253,49 @@
 @patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
+@patch("llama_recipes.finetuning.AutoProcessor.from_pretrained")
 @patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
 @patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
 @patch("llama_recipes.finetuning.get_preprocessed_dataset")
 @patch("llama_recipes.finetuning.optim.AdamW")
 @patch("llama_recipes.finetuning.StepLR")
 def test_batching_strategy(
-    step_lr, optimizer, get_dataset, tokenizer, get_model, train
+    step_lr,
+    optimizer,
+    get_dataset,
+    tokenizer,
+    get_config,
+    get_model,
+    get_processor,
+    get_mmodel,
+    train,
+    model_type,
 ):
-    kwargs = {"batching_strategy": "packing"}
+    kwargs = {
+        "batching_strategy": "packing",
+    }
 
     get_dataset.return_value = get_fake_dataset()
 
-    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+    model = get_model if model_type == "llama" else get_mmodel
+    model.return_value.get_input_embeddings.return_value.weight.shape = [0]
 
-    main(**kwargs)
+    get_config.return_value = Config(model_type=model_type)
 
-    assert train.call_count == 1
+    c = nullcontext() if model_type == "llama" else pytest.raises(ValueError)
+
+    with c:
+        main(**kwargs)
 
-    args, kwargs = train.call_args
-    train_dataloader, eval_dataloader = args[1:3]
-    assert isinstance(train_dataloader.batch_sampler, BatchSampler)
-    assert isinstance(eval_dataloader.batch_sampler, BatchSampler)
+    assert train.call_count == (1 if model_type == "llama" else 0)
+
+    if model_type == "llama":
+        args, kwargs = train.call_args
+        train_dataloader, eval_dataloader = args[1:3]
+        assert isinstance(train_dataloader.batch_sampler, BatchSampler)
+        assert isinstance(eval_dataloader.batch_sampler, BatchSampler)
 
     kwargs["batching_strategy"] = "padding"
     train.reset_mock()
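For readers unfamiliar with the mocking pattern the updated tests lean on, the standalone sketch below shows the core idea in isolation: a config factory is patched to return a lightweight dataclass, and the code under test branches on its `model_type`. This is a minimal sketch, not llama_recipes code; `config_factory`, `build_model`, and `FakeConfig` are hypothetical names chosen for illustration, with `config_factory` standing in for `AutoConfig.from_pretrained`.

```python
# Minimal, self-contained sketch (assumed names, not part of llama_recipes) of
# patching a config factory with a dataclass so the caller can branch on
# model_type, mirroring how the tests above steer main() via Config(model_type=...).
from dataclasses import dataclass
from unittest.mock import patch

import pytest


@dataclass
class FakeConfig:
    model_type: str = "llama"


def config_factory(name):
    # Stand-in for a pretrained-config loader; unit tests replace it with a mock.
    raise RuntimeError("no network access expected in unit tests")


def build_model(name):
    # Pick a code path based on the config's model_type, as the tests assume
    # finetuning.main does for text-only vs. vision models.
    config = config_factory(name)
    return "vision" if config.model_type == "mllama" else "text"


@pytest.mark.parametrize("model_type", ["mllama", "llama"])
def test_build_model_branches_on_model_type(model_type):
    # Patch the factory in this module so build_model sees the fake config.
    with patch(f"{__name__}.config_factory", return_value=FakeConfig(model_type)):
        expected = "vision" if model_type == "mllama" else "text"
        assert build_model("dummy-model") == expected
```

In the same spirit, the parametrized `model_type` fixture in conftest.py multiplies with the `run_validation`, `use_peft`, and `cuda_is_available` parameters in `test_finetuning`, so the single merged test covers both model families across all of the previously separate cases.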