diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py index 31f38e0d1bc20229e97d4efdec33bb91f1d4c14d..d1a50951ea352babcd763a0178ed2de2013695dc 100644 --- a/SwissArmyTransformer/model/t5_model.py +++ b/SwissArmyTransformer/model/t5_model.py @@ -3,6 +3,7 @@ import torch import torch.nn.functional as F from .mixins import BaseMixin from .encoder_decoder_model import EncoderDecoderModel +from .base_model import non_conflict from SwissArmyTransformer.mpu import get_model_parallel_world_size from SwissArmyTransformer.mpu.transformer import standard_attention, SelfAttention, CrossAttention, MLP from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region @@ -111,6 +112,7 @@ class T5AttentionMixin(BaseMixin): values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) return values + @non_conflict def attention_fn(self, q, k, v, mask, dropout_fn, position_bias=None, old_impl=standard_attention, cross_attention=False, **kw_args): log_attention_weights = None @@ -118,7 +120,8 @@ class T5AttentionMixin(BaseMixin): if position_bias is None: seq_length = q.size(2) key_length = k.size(2) - position_bias = self.compute_bias(seq_length, key_length) + position_bias = self.compute_bias(key_length, key_length) + position_bias = position_bias[:, :, -seq_length:, :] kw_args['output_cross_layer']['position_bias'] = position_bias log_attention_weights = position_bias return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, position_bias=position_bias,