From d9ca099613068fa04c77ea435d42e7c666961f18 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Tue, 15 Oct 2024 13:46:13 -0700 Subject: [PATCH] Fix fixture in test_train_utils --- src/tests/test_train_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/tests/test_train_utils.py b/src/tests/test_train_utils.py index e8a40ffe..66e3e9f0 100644 --- a/src/tests/test_train_utils.py +++ b/src/tests/test_train_utils.py @@ -36,6 +36,7 @@ def test_gradient_accumulation( model = mocker.MagicMock(name="model") model().loss.__truediv__().detach.return_value = torch.tensor(1) + model().loss.detach.return_value = torch.tensor(1) mock_tensor = mocker.MagicMock(name="tensor") batch = {"input": mock_tensor} train_dataloader = [batch, batch, batch, batch, batch] @@ -94,6 +95,7 @@ def test_gradient_accumulation( def test_save_to_json(temp_output_dir, mocker): model = mocker.MagicMock(name="model") model().loss.__truediv__().detach.return_value = torch.tensor(1) + model().loss.detach.return_value = torch.tensor(1) mock_tensor = mocker.MagicMock(name="tensor") batch = {"input": mock_tensor} train_dataloader = [batch, batch, batch, batch, batch] -- GitLab