From e9ae3132ff9f753cc3b27c8bf5d255667c1c5e2d Mon Sep 17 00:00:00 2001
From: Zhengxiao Du <zx-du20@mails.tsinghua.edu.cn>
Date: Fri, 22 Oct 2021 17:23:31 +0800
Subject: [PATCH] Fix attention mask

---
 mpu/transformer.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mpu/transformer.py b/mpu/transformer.py
index bad1a9c..04c0736 100755
--- a/mpu/transformer.py
+++ b/mpu/transformer.py
@@ -52,9 +52,8 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask,
     if log_attention_weights is not None:
         attention_scores += log_attention_weights
 
-    # if attention_mask.shape[-2] > 1: # if auto-regressive, skip
-    #     attention_scores = torch.mul(attention_scores, attention_mask) - \
-    #         10000.0 * (1.0 - attention_mask)
+    attention_scores = torch.mul(attention_scores, attention_mask) - \
+        10000.0 * (1.0 - attention_mask)
 
     attention_probs = F.softmax(attention_scores, dim=-1)
--
GitLab
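
For context, a minimal self-contained sketch of what the restored masking line does, written against plain PyTorch rather than the repo's own code: scores are kept where attention_mask is 1 and pushed to roughly -10000 where it is 0, so softmax assigns those positions near-zero probability. The function name, tensor shapes, and the causal-mask construction below are illustrative assumptions, not part of the patch.

import math

import torch
import torch.nn.functional as F


def masked_attention_demo(query_layer, key_layer, value_layer, attention_mask):
    # Scaled dot-product scores: [batch, heads, seq_q, seq_k].
    query_layer = query_layer / math.sqrt(query_layer.shape[-1])
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

    # The masking restored by the patch: keep scores where the mask is 1,
    # drive them to about -10000 where it is 0, so softmax gives those
    # positions effectively zero weight.
    attention_scores = torch.mul(attention_scores, attention_mask) - \
        10000.0 * (1.0 - attention_mask)

    attention_probs = F.softmax(attention_scores, dim=-1)
    return torch.matmul(attention_probs, value_layer)


if __name__ == "__main__":
    batch, heads, seq, head_dim = 1, 2, 4, 8
    q = torch.randn(batch, heads, seq, head_dim)
    k = torch.randn(batch, heads, seq, head_dim)
    v = torch.randn(batch, heads, seq, head_dim)
    # Causal mask: 1 on and below the diagonal, 0 above (broadcast over batch/heads).
    causal_mask = torch.tril(torch.ones(1, 1, seq, seq))
    out = masked_attention_demo(q, k, v, causal_mask)
    print(out.shape)  # torch.Size([1, 2, 4, 8])

A large negative offset rather than -inf is a common choice because it stays finite in fp16 arithmetic, which is presumably the reason for the 10000.0 constant here.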