Skip to content
Snippets Groups Projects
Commit d9ca0996 authored by Matthias Reso's avatar Matthias Reso
Browse files

Fix fixture in test_train_utils

parent d58dea23
No related branches found
No related tags found
No related merge requests found
...@@ -36,6 +36,7 @@ def test_gradient_accumulation( ...@@ -36,6 +36,7 @@ def test_gradient_accumulation(
model = mocker.MagicMock(name="model") model = mocker.MagicMock(name="model")
model().loss.__truediv__().detach.return_value = torch.tensor(1) model().loss.__truediv__().detach.return_value = torch.tensor(1)
model().loss.detach.return_value = torch.tensor(1)
mock_tensor = mocker.MagicMock(name="tensor") mock_tensor = mocker.MagicMock(name="tensor")
batch = {"input": mock_tensor} batch = {"input": mock_tensor}
train_dataloader = [batch, batch, batch, batch, batch] train_dataloader = [batch, batch, batch, batch, batch]
...@@ -94,6 +95,7 @@ def test_gradient_accumulation( ...@@ -94,6 +95,7 @@ def test_gradient_accumulation(
def test_save_to_json(temp_output_dir, mocker): def test_save_to_json(temp_output_dir, mocker):
model = mocker.MagicMock(name="model") model = mocker.MagicMock(name="model")
model().loss.__truediv__().detach.return_value = torch.tensor(1) model().loss.__truediv__().detach.return_value = torch.tensor(1)
model().loss.detach.return_value = torch.tensor(1)
mock_tensor = mocker.MagicMock(name="tensor") mock_tensor = mocker.MagicMock(name="tensor")
batch = {"input": mock_tensor} batch = {"input": mock_tensor}
train_dataloader = [batch, batch, batch, batch, batch] train_dataloader = [batch, batch, batch, batch, batch]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment