From d9ca099613068fa04c77ea435d42e7c666961f18 Mon Sep 17 00:00:00 2001
From: Matthias Reso <13337103+mreso@users.noreply.github.com>
Date: Tue, 15 Oct 2024 13:46:13 -0700
Subject: [PATCH] Fix fixture in test_train_utils

---
 src/tests/test_train_utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/tests/test_train_utils.py b/src/tests/test_train_utils.py
index e8a40ffe..66e3e9f0 100644
--- a/src/tests/test_train_utils.py
+++ b/src/tests/test_train_utils.py
@@ -36,6 +36,7 @@ def test_gradient_accumulation(
 
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
+    model().loss.detach.return_value = torch.tensor(1)
     mock_tensor = mocker.MagicMock(name="tensor")
     batch = {"input": mock_tensor}
     train_dataloader = [batch, batch, batch, batch, batch]
@@ -94,6 +95,7 @@ def test_gradient_accumulation(
 def test_save_to_json(temp_output_dir, mocker):
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
+    model().loss.detach.return_value = torch.tensor(1)
     mock_tensor = mocker.MagicMock(name="tensor")
     batch = {"input": mock_tensor}
     train_dataloader = [batch, batch, batch, batch, batch]
-- 
GitLab