From 51e1a51395330f2e84402b7ae789f40e260d9ae8 Mon Sep 17 00:00:00 2001
From: duzx16 <zx-du20@mails.tsinghua.edu.cn>
Date: Mon, 13 Dec 2021 16:36:22 +0800
Subject: [PATCH] Add non_conflict for T5 attention

fix position bias for memory in T5
---
 SwissArmyTransformer/model/t5_model.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index 31f38e0..d1a5095 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -3,6 +3,7 @@ import torch
 import torch.nn.functional as F
 from .mixins import BaseMixin
 from .encoder_decoder_model import EncoderDecoderModel
+from .base_model import non_conflict
 from SwissArmyTransformer.mpu import get_model_parallel_world_size
 from SwissArmyTransformer.mpu.transformer import standard_attention, SelfAttention, CrossAttention, MLP
 from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
@@ -111,6 +112,7 @@ class T5AttentionMixin(BaseMixin):
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         return values
 
+    @non_conflict
     def attention_fn(self, q, k, v, mask, dropout_fn, position_bias=None, old_impl=standard_attention,
                      cross_attention=False, **kw_args):
         log_attention_weights = None
@@ -118,7 +120,8 @@
         if position_bias is None:
             seq_length = q.size(2)
             key_length = k.size(2)
-            position_bias = self.compute_bias(seq_length, key_length)
+            position_bias = self.compute_bias(key_length, key_length)
+            position_bias = position_bias[:, :, -seq_length:, :]
             kw_args['output_cross_layer']['position_bias'] = position_bias
             log_attention_weights = position_bias
         return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, position_bias=position_bias,
-- 
GitLab
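
Note on the last hunk (illustrative, not part of the patch): when attention_fn runs with cached memory, q covers only the newly generated positions while k spans memory plus the new positions, so the queries occupy the last seq_length of key_length absolute positions. Computing the bias over (key_length, key_length) and keeping the last seq_length rows reproduces the bias rows that uncached decoding would use, whereas the old compute_bias(seq_length, key_length) placed the query at position 0. The sketch below is a minimal, self-contained illustration; toy_bias is a hypothetical, simplified stand-in for T5AttentionMixin.compute_bias, not the repository's method.

# Minimal sketch, assuming a simplified relative-position bias.
# `toy_bias` is a hypothetical stand-in for compute_bias: bias depends only on
# key_pos - query_pos, with query positions starting at 0.
import torch

torch.manual_seed(0)
num_heads, max_distance = 2, 16
# one learned bias per head per clipped relative distance in [-16, 16]
rel_emb = torch.nn.Embedding(2 * max_distance + 1, num_heads)

def toy_bias(query_length, key_length):
    context_position = torch.arange(query_length)[:, None]
    memory_position = torch.arange(key_length)[None, :]
    relative_position = (memory_position - context_position).clamp(-max_distance, max_distance)
    values = rel_emb(relative_position + max_distance)   # (q_len, k_len, num_heads)
    return values.permute(2, 0, 1).unsqueeze(0)          # (1, num_heads, q_len, k_len)

# Decoding 1 new token with 11 cached tokens: key_length = 12, seq_length = 1.
seq_length, key_length = 1, 12
no_cache = toy_bias(key_length, key_length)                        # reference: full sequence, no memory
fixed = toy_bias(key_length, key_length)[:, :, -seq_length:, :]    # patched path: slice last rows
old = toy_bias(seq_length, key_length)                             # pre-patch path: query treated as position 0

assert torch.equal(fixed, no_cache[:, :, -seq_length:, :])     # matches the uncached bias row
assert not torch.equal(old, no_cache[:, :, -seq_length:, :])   # old call used the wrong relative distances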