From 3e304e54f326e609575d71cd8641011e9edaf08e Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Tue, 19 Oct 2021 09:05:12 +0000
Subject: [PATCH] fix in-place loss_checker update & make reinit_slice default a slice

---
 model/mixins.py                | 2 +-
 training/deepspeed_training.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/model/mixins.py b/model/mixins.py
index 78befb8..125f8b0 100644
--- a/model/mixins.py
+++ b/model/mixins.py
@@ -26,7 +26,7 @@ class BaseMixin(torch.nn.Module):
 
 class PositionEmbeddingMixin(BaseMixin):
     def __init__(self, additional_sequence_length, hidden_size, 
-                init_method_std=0.02, reinit_slice=(-1024, None)
+                init_method_std=0.02, reinit_slice=slice(-1024, None)
         ):
         super(PositionEmbeddingMixin, self).__init__()
         self.reinit_slice = reinit_slice
diff --git a/training/deepspeed_training.py b/training/deepspeed_training.py
index 90190a6..094d49b 100644
--- a/training/deepspeed_training.py
+++ b/training/deepspeed_training.py
@@ -337,10 +337,10 @@ def train_step(data_iterator, model, optimizer, lr_scheduler,
         # and all reduce metrics by the way
         loss_checker = lm_loss.detach()
         for name in metrics:
-            metrics[name] = metrics[name].detach()
+            metrics[name] = metrics[name].detach().clone()
             torch.distributed.all_reduce(metrics[name].data)
             metrics[name].data /= args.world_size
-            loss_checker += metrics[name]
+            loss_checker = loss_checker + metrics[name]
         if loss_checker.isnan().any() or loss_checker.isinf().any():
             print('Skipping backward and optimizer step for nan or inf in forwarding metrics/loss!')
             return lm_loss.detach(), 1, metrics
-- 
GitLab