diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index 31f38e0d1bc20229e97d4efdec33bb91f1d4c14d..d1a50951ea352babcd763a0178ed2de2013695dc 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -3,6 +3,7 @@ import torch
 import torch.nn.functional as F
 from .mixins import BaseMixin
 from .encoder_decoder_model import EncoderDecoderModel
+from .base_model import non_conflict
 from SwissArmyTransformer.mpu import get_model_parallel_world_size
 from SwissArmyTransformer.mpu.transformer import standard_attention, SelfAttention, CrossAttention, MLP
 from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
@@ -111,6 +112,7 @@ class T5AttentionMixin(BaseMixin):
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         return values
 
+    @non_conflict
     def attention_fn(self, q, k, v, mask, dropout_fn, position_bias=None, old_impl=standard_attention,
                      cross_attention=False, **kw_args):
         log_attention_weights = None
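
Aside (not part of the patch): a toy sketch of the hook-chaining pattern that @non_conflict appears to enable, judging from the old_impl=standard_attention parameter in the signature above. The HookRegistry class and bias_hook function below are invented for illustration and are not SwissArmyTransformer's actual registration code; the point is only that a decorated hook receives the previously registered implementation and defers to it instead of overwriting it.

    def standard_attention(q, k, v, **kw):
        return "standard_attention({}, {}, {})".format(q, k, v)

    class HookRegistry:
        """Toy stand-in for a model collecting attention_fn hooks from mixins."""
        def __init__(self):
            self.attention_fn = standard_attention

        def register(self, hook):
            old = self.attention_fn
            # Each new hook gets the previously registered implementation as
            # old_impl, so hooks stack instead of silently replacing each other.
            self.attention_fn = lambda *args, **kw: hook(*args, old_impl=old, **kw)

    def bias_hook(q, k, v, old_impl, **kw):
        # e.g. compute a position bias here, then defer to the previous implementation
        return "bias+" + old_impl(q, k, v, **kw)

    hooks = HookRegistry()
    hooks.register(bias_hook)
    print(hooks.attention_fn("q", "k", "v"))  # -> bias+standard_attention(q, k, v)
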
@@ -118,7 +120,8 @@ class T5AttentionMixin(BaseMixin):
             if position_bias is None:
                 seq_length = q.size(2)
                 key_length = k.size(2)
-                position_bias = self.compute_bias(seq_length, key_length)
+                position_bias = self.compute_bias(key_length, key_length)
+                position_bias = position_bias[:, :, -seq_length:, :]
             kw_args['output_cross_layer']['position_bias'] = position_bias
             log_attention_weights = position_bias
         return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, position_bias=position_bias,
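
Below is a minimal, self-contained sketch of why the bias is now computed over (key_length, key_length) and then sliced to the last seq_length rows. During incremental decoding the query covers only the newest positions while k includes all cached positions, so the relative distances must be measured from the query's true position, not from position 0. The toy_bias function is a simplified stand-in for T5AttentionMixin.compute_bias (raw relative distance, no bucketing) and is not the library's code.

    import torch

    def toy_bias(query_length, key_length, num_heads=2):
        # Simplified stand-in for compute_bias: bias depends only on the
        # relative distance key_pos - query_pos.
        context = torch.arange(query_length)[:, None]   # query positions
        memory = torch.arange(key_length)[None, :]      # key positions
        relative = (memory - context).float()           # (query_length, key_length)
        # Shape (1, num_heads, query_length, key_length), as in the mixin.
        return relative.expand(num_heads, -1, -1).unsqueeze(0)

    key_length, seq_length = 6, 1                  # decoding step 6, five cached keys
    full = toy_bias(key_length, key_length)        # bias over the full key/key grid
    sliced = full[:, :, -seq_length:, :]           # rows for the current query only
    direct = toy_bias(seq_length, key_length)      # what the old call produced

    print(sliced[0, 0])  # tensor([[-5., -4., -3., -2., -1., 0.]]) -- query at its true position
    print(direct[0, 0])  # tensor([[ 0.,  1.,  2.,  3.,  4., 5.]]) -- query wrongly treated as position 0

The old call self.compute_bias(seq_length, key_length) therefore assigned the single decoded query the distances of position 0; computing the full grid and keeping the last seq_length rows gives it the distances it would have had without caching.
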