From 00442c04d5f1b96bbbd29388b3b171ea29727115 Mon Sep 17 00:00:00 2001
From: duzx16 <zx-du20@mails.tsinghua.edu.cn>
Date: Sat, 4 Dec 2021 19:48:55 +0800
Subject: [PATCH] Add requires_grad at the beginning

---
 SwissArmyTransformer/mpu/transformer.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index 957d88f..538d965 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -527,6 +527,8 @@ class BaseTransformer(torch.nn.Module):
 
             l, num_layers = 0, len(self.layers)
             chunk_length = self.checkpoint_num_layers
+            if self.training:
+                hidden_states.requires_grad_(True)
             while l < num_layers:
                 if branch_input is not None:
                     args = [hidden_states, attention_mask, branch_input]
-- 
GitLab