From 0ac851414817c691390aff96dcd62d989bde00c4 Mon Sep 17 00:00:00 2001
From: duzx16 <zx-du20@mails.tsinghua.edu.cn>
Date: Sat, 4 Dec 2021 16:30:33 +0800
Subject: [PATCH] Fix scaling_attention_score in standard_attention

---
 SwissArmyTransformer/mpu/transformer.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index d581455..957d88f 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -49,7 +49,7 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask,
     # The implementation in the paper can be done very easily, if you really need it to train very deep transformers.
 
     if scaling_attention_score:
-        query_layer / math.sqrt(query_layer.shape[-1])
+        query_layer = query_layer / math.sqrt(query_layer.shape[-1])
     attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
     if log_attention_weights is not None:
         attention_scores += log_attention_weights
@@ -469,6 +469,7 @@ class BaseTransformer(torch.nn.Module):
 
     def forward(self, input_ids, position_ids, attention_mask, *,
                 branch_input=None, encoder_outputs=None, output_hidden_states=False, **kw_args):
+        breakpoint()
         # sanity check
         assert len(input_ids.shape) == 2
         batch_size, query_length = input_ids.shape
--
GitLab
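
Editor's note (not part of the patch): the removed line only evaluates the division and discards the result, so the attention scores were computed on unscaled queries; the fix reassigns query_layer. Below is a minimal standalone sketch of that behavior, assuming a hypothetical helper scaled_scores that mirrors the patched branch of standard_attention; it is an illustration, not the repository's code.

    # Sketch: shows that reassigning the scaled queries changes the scores
    # by a factor of 1/sqrt(head_dim), whereas a bare division would not.
    import math
    import torch

    def scaled_scores(query_layer, key_layer, scaling_attention_score=True):
        # Hypothetical helper following the patched branch: reassign the scaled tensor.
        if scaling_attention_score:
            query_layer = query_layer / math.sqrt(query_layer.shape[-1])
        return torch.matmul(query_layer, key_layer.transpose(-1, -2))

    q = torch.randn(1, 2, 4, 8)   # (batch, heads, seq_len, head_dim)
    k = torch.randn(1, 2, 4, 8)
    unscaled = torch.matmul(q, k.transpose(-1, -2))   # what the buggy code effectively computed
    scaled = scaled_scores(q, k)
    print(torch.allclose(scaled * math.sqrt(8), unscaled))  # True: scores now differ by 1/sqrt(d)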