diff --git a/mpu/transformer.py b/mpu/transformer.py
index bad1a9ce47aa4dfdb167ff090cbdb3055ec44703..04c0736c4e8a63594872991897f367d9b1305c9b 100755
--- a/mpu/transformer.py
+++ b/mpu/transformer.py
@@ -52,9 +52,8 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask,
 
     if log_attention_weights is not None:
         attention_scores += log_attention_weights
 
-    # if attention_mask.shape[-2] > 1: # if auto-regressive, skip
-    #     attention_scores = torch.mul(attention_scores, attention_mask) - \
-    #                        10000.0 * (1.0 - attention_mask)
+    attention_scores = torch.mul(attention_scores, attention_mask) - \
+                       10000.0 * (1.0 - attention_mask)
 
     attention_probs = F.softmax(attention_scores, dim=-1)
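
The restored lines apply the standard mask-and-offset trick: scores at allowed positions (mask == 1) pass through unchanged, while scores at masked positions (mask == 0) are driven down to -10000, so the softmax assigns them near-zero probability. A minimal, self-contained sketch of that arithmetic follows; the helper name and toy tensors are illustrative only, not part of mpu/transformer.py:

    import torch
    import torch.nn.functional as F

    def apply_attention_mask(attention_scores, attention_mask):
        # Keep scores where mask == 1; push scores toward -10000 where
        # mask == 0, so softmax gives those positions ~zero probability.
        return torch.mul(attention_scores, attention_mask) - \
               10000.0 * (1.0 - attention_mask)

    # Toy example: one batch, one head, a single query over 4 key
    # positions, with the last two positions masked out.
    scores = torch.randn(1, 1, 1, 4)
    mask = torch.tensor([[[[1.0, 1.0, 0.0, 0.0]]]])
    probs = F.softmax(apply_attention_mask(scores, mask), dim=-1)
    print(probs)  # probabilities at the two masked positions are ~0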