import torch

from llama_recipes.utils.train_utils import train


def test_gradient_accumulation(mocker):
    """train() must call ``optimizer.zero_grad`` once per optimizer step.

    With a dataloader of 5 batches:
      * gradient_accumulation_steps=1 -> 5 optimizer steps -> 5 zero_grad calls
      * gradient_accumulation_steps=2 -> ceil(5/2) = 3 steps -> 3 zero_grad calls
        (the trailing partial accumulation window still triggers a step)
    """
    model = mocker.MagicMock(name="model")
    # train() divides the loss by the accumulation steps and detaches it for
    # its running-loss bookkeeping; stub that chain to yield a real tensor.
    model().loss.__truediv__().detach.return_value = torch.tensor(1)

    batch = {"input": torch.zeros(1)}
    train_dataloader = [batch] * 5
    eval_dataloader = None
    tokenizer = mocker.MagicMock()
    optimizer = mocker.MagicMock()
    lr_scheduler = mocker.MagicMock()

    train_config = mocker.MagicMock()
    train_config.enable_fsdp = False
    train_config.use_fp16 = False
    # eval_dataloader is None, so validation must stay disabled.
    train_config.run_validation = False

    def run_and_count_steps(gradient_accumulation_steps):
        # Run one full training pass and report how many optimizer steps
        # (zero_grad calls) it performed.
        optimizer.zero_grad.reset_mock()
        train(
            model,
            train_dataloader,
            eval_dataloader,
            tokenizer,
            optimizer,
            lr_scheduler,
            gradient_accumulation_steps,
            train_config,
        )
        return optimizer.zero_grad.call_count

    assert run_and_count_steps(1) == 5
    assert run_and_count_steps(2) == 3