From 51e1a51395330f2e84402b7ae789f40e260d9ae8 Mon Sep 17 00:00:00 2001
From: duzx16 <zx-du20@mails.tsinghua.edu.cn>
Date: Mon, 13 Dec 2021 16:36:22 +0800
Subject: [PATCH] Add non_conflict for T5 attention; fix position bias for
 cached memory in T5

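Decorate T5AttentionMixin.attention_fn with non_conflict so it chains
through old_impl instead of clashing with attention_fn hooks defined by
other mixins.

Also fix the relative position bias when the keys include cached memory
(incremental decoding): q then covers only the last seq_length positions
of a key_length-long sequence, so compute_bias(seq_length, key_length)
would treat the current queries as if they started at position 0. Compute
the full key_length x key_length bias instead and keep only its last
seq_length rows.

A minimal sketch of the idea; relative_positions below is a hypothetical
stand-in for the bucketed bias in compute_bias, not code from this repo:

    import torch

    def relative_positions(query_length, key_length):
        # absolute positions of the current queries: the last query_length slots
        ctx = torch.arange(key_length)[-query_length:, None]
        # absolute positions of all keys, including cached memory
        mem = torch.arange(key_length)[None, :]
        return mem - ctx  # shape (query_length, key_length)

    full = relative_positions(4, 4)         # no cache: 4 queries over 4 keys
    step = relative_positions(1, 4)         # one decoding step, 3 cached keys
    assert torch.equal(step, full[-1:, :])  # slicing the last rows is equivalent
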
---
 SwissArmyTransformer/model/t5_model.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index 31f38e0..d1a5095 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -3,6 +3,7 @@ import torch
 import torch.nn.functional as F
 from .mixins import BaseMixin
 from .encoder_decoder_model import EncoderDecoderModel
+from .base_model import non_conflict
 from SwissArmyTransformer.mpu import get_model_parallel_world_size
 from SwissArmyTransformer.mpu.transformer import standard_attention, SelfAttention, CrossAttention, MLP
 from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
@@ -111,6 +112,7 @@ class T5AttentionMixin(BaseMixin):
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         return values
 
+    @non_conflict
     def attention_fn(self, q, k, v, mask, dropout_fn, position_bias=None, old_impl=standard_attention,
                      cross_attention=False, **kw_args):
         log_attention_weights = None
@@ -118,7 +120,8 @@ class T5AttentionMixin(BaseMixin):
             if position_bias is None:
                 seq_length = q.size(2)
                 key_length = k.size(2)
-                position_bias = self.compute_bias(seq_length, key_length)
+                position_bias = self.compute_bias(key_length, key_length)
+                position_bias = position_bias[:, :, -seq_length:, :]
             kw_args['output_cross_layer']['position_bias'] = position_bias
             log_attention_weights = position_bias
         return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, position_bias=position_bias,
-- 
GitLab