Commit 0ac85141 authored by duzx16

Fix scaling_attention_score in standard_attention

parent 9bedc566
@@ -49,7 +49,7 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask,
     # The implementation in the paper can be done very easily, if you really need it to train very deep transformers.
     if scaling_attention_score:
-        query_layer / math.sqrt(query_layer.shape[-1])
+        query_layer = query_layer / math.sqrt(query_layer.shape[-1])
     attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
     if log_attention_weights is not None:
         attention_scores += log_attention_weights
@@ -469,6 +469,7 @@ class BaseTransformer(torch.nn.Module):
     def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, encoder_outputs=None,
                 output_hidden_states=False, **kw_args):
+        breakpoint()
         # sanity check
         assert len(input_ids.shape) == 2
         batch_size, query_length = input_ids.shape
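Context for the one-line fix: query_layer / math.sqrt(query_layer.shape[-1]) returns a new tensor rather than modifying query_layer in place, so without the assignment the attention scores were computed from the unscaled query. Below is a minimal, self-contained sketch of the scaled dot-product computation this function performs; the helper name scaled_dot_product_sketch and the simplified mask handling are illustrative only, not the repository's actual standard_attention (which also handles dropout and log_attention_weights).

import math
import torch

def scaled_dot_product_sketch(query_layer, key_layer, value_layer, attention_mask=None):
    # Scale queries by 1/sqrt(d_head). The division creates a new tensor,
    # so it must be assigned back -- the missing assignment is the bug
    # this commit fixes.
    query_layer = query_layer / math.sqrt(query_layer.shape[-1])
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    if attention_mask is not None:
        # Zero entries in the mask mark positions that must not be attended to.
        attention_scores = attention_scores.masked_fill(attention_mask == 0, float("-inf"))
    attention_probs = torch.softmax(attention_scores, dim=-1)
    return torch.matmul(attention_probs, value_layer)

# Example shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)
out = scaled_dot_product_sketch(q, k, v)  # -> shape (2, 8, 16, 64)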