diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py index 957d88f93a6d3a3ce6314a3f540d52caedcea8fc..538d9651d3cde80d256c28288c51ff6e29fab664 100755 --- a/SwissArmyTransformer/mpu/transformer.py +++ b/SwissArmyTransformer/mpu/transformer.py @@ -527,6 +527,8 @@ class BaseTransformer(torch.nn.Module): l, num_layers = 0, len(self.layers) chunk_length = self.checkpoint_num_layers + if self.training: + hidden_states.requires_grad_(True) while l < num_layers: if branch_input is not None: args = [hidden_states, attention_mask, branch_input]