Skip to content
Snippets Groups Projects
Commit d9ca0996 authored by Matthias Reso's avatar Matthias Reso
Browse files

Fix fixture in test_train_utils

parent d58dea23
No related branches found
No related tags found
No related merge requests found
......@@ -36,6 +36,7 @@ def test_gradient_accumulation(
model = mocker.MagicMock(name="model")
model().loss.__truediv__().detach.return_value = torch.tensor(1)
model().loss.detach.return_value = torch.tensor(1)
mock_tensor = mocker.MagicMock(name="tensor")
batch = {"input": mock_tensor}
train_dataloader = [batch, batch, batch, batch, batch]
......@@ -94,6 +95,7 @@ def test_gradient_accumulation(
def test_save_to_json(temp_output_dir, mocker):
model = mocker.MagicMock(name="model")
model().loss.__truediv__().detach.return_value = torch.tensor(1)
model().loss.detach.return_value = torch.tensor(1)
mock_tensor = mocker.MagicMock(name="tensor")
batch = {"input": mock_tensor}
train_dataloader = [batch, batch, batch, batch, batch]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment