From f398bc54c93ab007908ee82bccc9fb209dda4f54 Mon Sep 17 00:00:00 2001
From: Matthias Reso <13337103+mreso@users.noreply.github.com>
Date: Wed, 30 Aug 2023 22:14:12 +0000
Subject: [PATCH] Added basic unit test for train method

---
 tests/test_train_utils.py | 53 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 53 insertions(+)
 create mode 100644 tests/test_train_utils.py

diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py
new file mode 100644
index 00000000..d1653dfa
--- /dev/null
+++ b/tests/test_train_utils.py
@@ -0,0 +1,53 @@
+import torch
+
+from llama_recipes.utils.train_utils import train
+
+
+def test_gradient_accumulation(mocker):
+    # train() divides the loss by gradient_accumulation_steps and keeps a
+    # running total via detach(), so the mocked quotient must return a
+    # real tensor from detach().
+    model = mocker.MagicMock(name="model")
+    model().loss.__truediv__().detach.return_value = torch.tensor(1)
+
+    batch = {"input": torch.zeros(1)}
+    train_dataloader = [batch, batch, batch, batch, batch]
+    eval_dataloader = None
+    tokenizer = mocker.MagicMock()
+    optimizer = mocker.MagicMock()
+    lr_scheduler = mocker.MagicMock()
+    gradient_accumulation_steps = 1
+    train_config = mocker.MagicMock()
+    train_config.enable_fsdp = False
+    train_config.use_fp16 = False
+    train_config.run_validation = False
+
+    train(
+        model,
+        train_dataloader,
+        eval_dataloader,
+        tokenizer,
+        optimizer,
+        lr_scheduler,
+        gradient_accumulation_steps,
+        train_config,
+    )
+
+    # Without accumulation the optimizer is stepped once per batch.
+    assert optimizer.zero_grad.call_count == 5
+    optimizer.zero_grad.reset_mock()
+
+    gradient_accumulation_steps = 2
+    train(
+        model,
+        train_dataloader,
+        eval_dataloader,
+        tokenizer,
+        optimizer,
+        lr_scheduler,
+        gradient_accumulation_steps,
+        train_config,
+    )
+
+    # Accumulating over two batches yields ceil(5/2) = 3 optimizer steps.
+    assert optimizer.zero_grad.call_count == 3
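
Note on the expected counts: with N batches and an accumulation window of
k, the train loop steps the optimizer (and calls zero_grad) once per full
window, plus once for a trailing partial window, i.e. ceil(N / k) times.
A minimal sketch of that arithmetic; the helper name is illustrative and
not part of the patch:

    import math

    def expected_optimizer_steps(num_batches: int, accumulation_steps: int) -> int:
        # One optimizer step per full accumulation window, plus a final
        # step for any leftover batches at the end of the epoch.
        return math.ceil(num_batches / accumulation_steps)

    assert expected_optimizer_steps(5, 1) == 5  # first train() call above
    assert expected_optimizer_steps(5, 2) == 3  # second train() call above

The test relies on the mocker fixture from the pytest-mock plugin.
Assuming pytest and pytest-mock are installed, it can be invoked from
Python with:

    import pytest

    # Equivalent to running "pytest tests/test_train_utils.py -v".
    pytest.main(["tests/test_train_utils.py", "-v"])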