diff --git a/mpu/transformer.py b/mpu/transformer.py
index bad1a9ce47aa4dfdb167ff090cbdb3055ec44703..04c0736c4e8a63594872991897f367d9b1305c9b 100755
--- a/mpu/transformer.py
+++ b/mpu/transformer.py
@@ -52,9 +52,8 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask,
 
     if log_attention_weights is not None:
         attention_scores += log_attention_weights
 
-    # if attention_mask.shape[-2] > 1: # if auto-regressive, skip
-    #     attention_scores = torch.mul(attention_scores, attention_mask) - \
-    #                        10000.0 * (1.0 - attention_mask)
+    attention_scores = torch.mul(attention_scores, attention_mask) - \
+                       10000.0 * (1.0 - attention_mask)
 
     attention_probs = F.softmax(attention_scores, dim=-1)
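
The restored lines apply the standard mask-and-offset trick: scores at allowed positions (mask == 1) pass through unchanged, while scores at masked positions (mask == 0) are driven down to -10000, so the softmax assigns them near-zero probability. A minimal, self-contained sketch of that arithmetic follows; the helper name and toy tensors are illustrative only, not part of mpu/transformer.py:

    import torch
    import torch.nn.functional as F

    def apply_attention_mask(attention_scores, attention_mask):
        # Keep scores where mask == 1; push scores toward -10000 where
        # mask == 0, so softmax gives those positions ~zero probability.
        return torch.mul(attention_scores, attention_mask) - \
               10000.0 * (1.0 - attention_mask)

    # Toy example: one batch, one head, a single query over 4 key
    # positions, with the last two positions masked out.
    scores = torch.randn(1, 1, 1, 4)
    mask = torch.tensor([[[[1.0, 1.0, 0.0, 0.0]]]])
    probs = F.softmax(apply_attention_mask(scores, mask), dim=-1)
    print(probs)  # probabilities at the two masked positions are ~0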