From e9ae3132ff9f753cc3b27c8bf5d255667c1c5e2d Mon Sep 17 00:00:00 2001
From: Zhengxiao Du <zx-du20@mails.tsinghua.edu.cn>
Date: Fri, 22 Oct 2021 17:23:31 +0800
Subject: [PATCH] Fix attention mask

Re-enable the attention-mask application in standard_attention: the masking
code had been commented out (behind an "if auto-regressive, skip" condition),
so the mask was never applied to the attention scores. Apply it
unconditionally before the softmax.
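The masking follows the usual additive-mask pattern: scores at blocked
positions are pushed to a large negative value so the softmax gives them
near-zero weight. A minimal standalone sketch of that pattern (the tensor
shapes and the causal mask below are illustrative assumptions, not taken
from this repo):

    import torch
    import torch.nn.functional as F

    # [batch, heads, query, key] attention scores; shapes chosen for illustration
    scores = torch.randn(1, 1, 4, 4)
    # 1 = may attend, 0 = blocked; an assumed causal (lower-triangular) mask
    mask = torch.tril(torch.ones(1, 1, 4, 4))
    # Same pattern as the patched line: blocked positions are overwritten
    # with -10000.0, so softmax assigns them ~zero probability
    scores = scores * mask - 10000.0 * (1.0 - mask)
    probs = F.softmax(scores, dim=-1)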
---
 mpu/transformer.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mpu/transformer.py b/mpu/transformer.py
index bad1a9c..04c0736 100755
--- a/mpu/transformer.py
+++ b/mpu/transformer.py
@@ -52,9 +52,8 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask,
     if log_attention_weights is not None:
         attention_scores += log_attention_weights
     
-    # if attention_mask.shape[-2] > 1: # if auto-regressive, skip
-    #     attention_scores = torch.mul(attention_scores, attention_mask) - \
-    #                 10000.0 * (1.0 - attention_mask)
+    attention_scores = torch.mul(attention_scores, attention_mask) - \
+                10000.0 * (1.0 - attention_mask)
 
     attention_probs = F.softmax(attention_scores, dim=-1)
 
-- 
GitLab