diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index 957d88f93a6d3a3ce6314a3f540d52caedcea8fc..538d9651d3cde80d256c28288c51ff6e29fab664 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -527,6 +527,8 @@ class BaseTransformer(torch.nn.Module):
 
             l, num_layers = 0, len(self.layers)
             chunk_length = self.checkpoint_num_layers
+            if self.training:
+                hidden_states.requires_grad_(True)
             while l < num_layers:
                 if branch_input is not None:
                     args = [hidden_states, attention_mask, branch_input]