From 0ac851414817c691390aff96dcd62d989bde00c4 Mon Sep 17 00:00:00 2001
From: duzx16 <zx-du20@mails.tsinghua.edu.cn>
Date: Sat, 4 Dec 2021 16:30:33 +0800
Subject: [PATCH] Fix scaling_attention_score in standard_attention
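
The previous line `query_layer / math.sqrt(query_layer.shape[-1])` computed the
scaled queries but discarded the result, since tensor division is not in-place,
so attention scores were built from unscaled queries. Rebinding the result to
`query_layer` applies the 1/sqrt(d_k) factor of scaled dot-product attention
before the matmul with `key_layer`.

A minimal standalone sketch of the difference (shapes are illustrative, not
taken from the model):

    import math
    import torch

    q = torch.randn(1, 2, 4, 8)   # (batch, heads, seq_len, d_k)
    k = torch.randn(1, 2, 4, 8)

    q / math.sqrt(q.shape[-1])                      # old: result discarded
    unscaled = torch.matmul(q, k.transpose(-1, -2))

    q = q / math.sqrt(q.shape[-1])                  # fixed: rebind scaled queries
    scaled = torch.matmul(q, k.transpose(-1, -2))

    assert torch.allclose(unscaled / math.sqrt(q.shape[-1]), scaled)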

---
 SwissArmyTransformer/mpu/transformer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index d581455..957d88f 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -49,7 +49,7 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask,
     # The implementation in the paper can be done very easily, if you really need it to train very deep transformers. 
 
     if scaling_attention_score:
-        query_layer / math.sqrt(query_layer.shape[-1])
+        query_layer = query_layer / math.sqrt(query_layer.shape[-1])
     attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
     if log_attention_weights is not None:
         attention_scores += log_attention_weights
-- 
GitLab