From 4e0e9674b16fc0c460d8b3b3fabf3d309d111484 Mon Sep 17 00:00:00 2001
From: duzx16 <zx-du20@mails.tsinghua.edu.cn>
Date: Mon, 13 Dec 2021 14:58:28 +0800
Subject: [PATCH] Implement T5 attention with the attention_fn hook

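Replace T5AttentionMixin's attention_forward and cross_attention_forward
hooks with a single attention_fn hook: for self-attention it computes the
relative position bias (publishing it to later layers through
output_cross_layer) and passes it as log_attention_weights, then delegates
to the wrapped implementation with scaling_attention_score=False.

CachedAutoregressiveMixin.attention_fn now takes a cross_attention flag and
only caches and concatenates key/value memories for self-attention, so it
composes cleanly with the T5 hook.

CrossAttention in mpu/transformer.py dispatches to a registered
attention_fn hook (with cross_attention=True) instead of always calling
standard_attention. The remaining transformer.py changes only strip
trailing whitespace.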
---
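Note for reviewers (this section is dropped by `git am`): below is a
minimal, self-contained sketch of the old_impl delegation pattern that the
two attention_fn hooks now share. Every name in it is a toy stand-in; the
real chaining is handled by BaseMixin hooks and the @non_conflict decorator
in SwissArmyTransformer, so read it as an illustration of the pattern, not
as the library's API.

    from functools import partial

    def standard_attention(q, k, v, mask=None, dropout_fn=None,
                           log_attention_weights=None, **kw):
        # Toy stand-in for mpu.transformer.standard_attention.
        return {"q": q, "k": k, "v": v, "bias": log_attention_weights}

    def t5_attention_fn(q, k, v, mask=None, dropout_fn=None,
                        old_impl=standard_attention, cross_attention=False, **kw):
        # Like T5AttentionMixin.attention_fn: add a relative position bias
        # for self-attention only, then defer to the previous implementation.
        bias = None if cross_attention else "position_bias"
        return old_impl(q, k, v, mask, dropout_fn,
                        log_attention_weights=bias,
                        cross_attention=cross_attention, **kw)

    def cached_attention_fn(q, k, v, mask=None, dropout_fn=None,
                            old_impl=standard_attention, cross_attention=False, **kw):
        # Like CachedAutoregressiveMixin.attention_fn: extend k/v with cached
        # memories for self-attention only, then defer.
        if not cross_attention:
            k, v = "mem|" + k, "mem|" + v
        return old_impl(q, k, v, mask, dropout_fn,
                        cross_attention=cross_attention, **kw)

    # One possible chaining in the spirit of @non_conflict: the caching hook
    # treats the T5 hook as its old_impl, so both effects apply to
    # self-attention and neither touches cross-attention.
    chained = partial(cached_attention_fn, old_impl=t5_attention_fn)
    print(chained("q", "k", "v"))                        # cached k/v + bias
    print(chained("q", "k", "v", cross_attention=True))  # plain attention
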
 .../model/cached_autoregressive_model.py      | 28 ++++----
 SwissArmyTransformer/model/t5_model.py        | 69 ++++---------------
 SwissArmyTransformer/mpu/transformer.py       | 37 +++++-----
 3 files changed, 48 insertions(+), 86 deletions(-)

diff --git a/SwissArmyTransformer/model/cached_autoregressive_model.py b/SwissArmyTransformer/model/cached_autoregressive_model.py
index ed25fe4..8caed66 100755
--- a/SwissArmyTransformer/model/cached_autoregressive_model.py
+++ b/SwissArmyTransformer/model/cached_autoregressive_model.py
@@ -21,20 +21,22 @@ class CachedAutoregressiveMixin(BaseMixin):
         super().__init__()     
            
     @non_conflict
-    def attention_fn(self, q, k, v, mask, dropout_fn, mems=None, old_impl=standard_attention, **kw_args):
-        mem = mems[kw_args['layer_id']] if mems is not None else None # 2, batch, head, seqlen, hidden_size
-        b, nh, seq_len, hidden_size = k.shape
+    def attention_fn(self, q, k, v, mask, dropout_fn, mems=None, cross_attention=False, old_impl=standard_attention,
+                     **kw_args):
+        if not cross_attention:
+            mem = mems[kw_args['layer_id']] if mems is not None else None # 2, batch, head, seqlen, hidden_size
+            b, nh, seq_len, hidden_size = k.shape
 
-        cache_kv = torch.stack((k, v)).permute(1, 3, 0, 2, 4).detach().contiguous().view(b, seq_len, nh * hidden_size * 2)
-        kw_args['output_this_layer']['mem_kv'] = cache_kv
-        
-        if mem is not None: # the first time, mem is None
-            # might change batch_size
-            mem = mem.expand(b, -1, -1).reshape(b, mem.shape[1], 2, nh, hidden_size).permute(2, 0, 3, 1, 4) 
-            memk, memv = mem[0], mem[1]
-            k = torch.cat((memk, k), dim=2)
-            v = torch.cat((memv, v), dim=2)
-        return old_impl(q, k, v, mask, dropout_fn, **kw_args)
+            cache_kv = torch.stack((k, v)).permute(1, 3, 0, 2, 4).detach().contiguous().view(b, seq_len, nh * hidden_size * 2)
+            kw_args['output_this_layer']['mem_kv'] = cache_kv
+
+            if mem is not None: # the first time, mem is None
+                # might change batch_size
+                mem = mem.expand(b, -1, -1).reshape(b, mem.shape[1], 2, nh, hidden_size).permute(2, 0, 3, 1, 4)
+                memk, memv = mem[0], mem[1]
+                k = torch.cat((memk, k), dim=2)
+                v = torch.cat((memv, v), dim=2)
+        return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, mems=mems, **kw_args)
 
 
 class CachedAutoregressiveModel(BaseModel):
diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index fb7c7cb..e0d2dc6 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -114,63 +114,18 @@ class T5AttentionMixin(BaseMixin):
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         return values
 
-    def attention_forward(self, hidden_states, mask, position_bias=None, *args, layer_id=None, mems=None, **kw_args):
-        attn_module = self.transformer.layers[layer_id].attention
-        seq_length = hidden_states.size(1)
-        memory_length = mems[layer_id].size(1) if mems else 0
-        mixed_raw_layer = attn_module.query_key_value(hidden_states)
-        (mixed_query_layer,
-         mixed_key_layer,
-         mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
-
-        dropout_fn = attn_module.attention_dropout if attn_module.training else None
-
-        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
-        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
-        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
-
-        if position_bias is None:
-            position_bias = self.compute_bias(seq_length, memory_length + seq_length)
-        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn,
-                                           log_attention_weights=position_bias, scaling_attention_score=False)
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
-        context_layer = context_layer.view(*new_context_layer_shape)
-        output = attn_module.dense(context_layer)
-
-        if attn_module.training:
-            output = attn_module.output_dropout(output)
-
-        kw_args['output_cross_layer']['position_bias'] = position_bias
-
-        return output
-
-    def cross_attention_forward(self, hidden_states, cross_attention_mask, encoder_outputs, layer_id=None, *args,
-                                **kw_args):
-        attn_module = self.transformer.layers[layer_id].cross_attention
-        mixed_query_layer = attn_module.query(hidden_states)
-        mixed_x_layer = attn_module.key_value(encoder_outputs)
-        (mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2)
-
-        dropout_fn = attn_module.attention_dropout if attn_module.training else None
-        # Reshape and transpose [b, np, s, hn]
-        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
-        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
-        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
-
-        context_layer = standard_attention(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn,
-                                           scaling_attention_score=False)
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
-        # [b, s, hp]
-        context_layer = context_layer.view(*new_context_layer_shape)
-
-        # Output. [b, s, h]
-        output = attn_module.dense(context_layer)
-        if attn_module.training:
-            output = attn_module.output_dropout(output)
-
-        return output
+    def attention_fn(self, q, k, v, mask, dropout_fn, position_bias=None, old_impl=standard_attention,
+                     cross_attention=False, **kw_args):
+        log_attention_weights = None
+        if not cross_attention:
+            if position_bias is None:
+                seq_length = q.size(2)
+                key_length = k.size(2)
+                position_bias = self.compute_bias(seq_length, key_length)
+            kw_args['output_cross_layer']['position_bias'] = position_bias
+            log_attention_weights = position_bias
+        return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, position_bias=position_bias,
+                        log_attention_weights=log_attention_weights, scaling_attention_score=False, **kw_args)
 
 
 class T5DecoderFinalMixin(BaseMixin):
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index 3b4ae73..2890569 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -218,6 +218,10 @@ class CrossAttention(torch.nn.Module):
         if 'cross_attention_forward' in self.hooks:
             return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs, **kw_args)
         else:
+            attention_fn = standard_attention
+            if 'attention_fn' in self.hooks:
+                attention_fn = self.hooks['attention_fn']
+
             mixed_query_layer = self.query(hidden_states)
             mixed_x_layer = self.key_value(encoder_outputs)
             (mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2)
@@ -228,7 +232,8 @@ class CrossAttention(torch.nn.Module):
             key_layer = self._transpose_for_scores(mixed_key_layer)
             value_layer = self._transpose_for_scores(mixed_value_layer)
 
-            context_layer = standard_attention(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn)
+            context_layer = attention_fn(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn,
+                                         cross_attention=True, **kw_args)
             context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
             new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
             # [b, s, hp]
@@ -504,7 +509,7 @@ class BaseTransformer(torch.nn.Module):
             )  # None means full attention
         assert len(attention_mask.shape) == 2 or \
                len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
-    
+
         # embedding part
         if 'word_embedding_forward' in self.hooks:
             hidden_states = self.hooks['word_embedding_forward'](input_ids, **kw_args)
@@ -526,7 +531,7 @@ class BaseTransformer(torch.nn.Module):
             output_cross_layer = self.hooks['cross_layer_embedding_forward'](hidden_states, **kw_args)
         else:
             output_cross_layer = {}
-            
+
         output_per_layers = []
         if self.checkpoint_activations:
             # define custom_forward for checkpointing
@@ -534,7 +539,7 @@ class BaseTransformer(torch.nn.Module):
                 def custom_forward(*inputs):
                     layers_ = self.layers[start:end]
                     x_, mask = inputs[0], inputs[1]
-                    
+
                     # recover kw_args and output_cross_layer
                     flat_inputs = inputs[2:]
                     kw_args, output_cross_layer = {}, {}
@@ -543,19 +548,19 @@ class BaseTransformer(torch.nn.Module):
                     for k, idx in cross_layer_index.items():
                         output_cross_layer[k] = flat_inputs[idx]
                     # -----------------
-                    
+
                     output_per_layers_part = []
                     for i, layer in enumerate(layers_):
                         if 'layer_forward' in self.hooks:
                             layer_ret = self.hooks['layer_forward'](
-                                x_, mask, layer_id=layer.layer_id, 
-                                **kw_args, **output_cross_layer, 
+                                x_, mask, layer_id=layer.layer_id,
+                                **kw_args, **output_cross_layer,
                                 output_this_layer={}, output_cross_layer={}
                             )
                         else:
                             layer_ret = layer(
-                                x_, mask, layer_id=layer.layer_id, 
-                                **kw_args, **output_cross_layer, 
+                                x_, mask, layer_id=layer.layer_id,
+                                **kw_args, **output_cross_layer,
                                 output_this_layer={}, output_cross_layer={}
                             )
                         if torch.is_tensor(layer_ret): # only hidden_states
@@ -563,7 +568,7 @@ class BaseTransformer(torch.nn.Module):
                         elif len(layer_ret) == 2: # hidden_states & output_this_layer
                             x_, output_this_layer = layer_ret
                             output_cross_layer = {}
-                        elif len(layer_ret) == 3: 
+                        elif len(layer_ret) == 3:
                             x_, output_this_layer, output_cross_layer = layer_ret
                         assert isinstance(output_this_layer, dict)
                         assert isinstance(output_cross_layer, dict)
@@ -590,7 +595,7 @@ class BaseTransformer(torch.nn.Module):
             # To save memory when only finetuning the final layers, don't use checkpointing.
             if self.training:
                 hidden_states.requires_grad_(True)
-            
+
             l, num_layers = 0, len(self.layers)
             chunk_length = self.checkpoint_num_layers
             output_this_layer = []
@@ -625,20 +630,20 @@ class BaseTransformer(torch.nn.Module):
                 args = [hidden_states, attention_mask]
 
                 if 'layer_forward' in self.hooks: # customized layer_forward
-                    layer_ret = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), 
-                        **kw_args, 
-                        **output_cross_layer, 
+                    layer_ret = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i),
+                        **kw_args,
+                        **output_cross_layer,
                         output_this_layer={}, output_cross_layer={}
                     )
                 else:
-                    layer_ret = layer(*args, layer_id=torch.tensor(i), **kw_args, **output_cross_layer, 
+                    layer_ret = layer(*args, layer_id=torch.tensor(i), **kw_args, **output_cross_layer,
                         output_this_layer={}, output_cross_layer={})
                 if torch.is_tensor(layer_ret): # only hidden_states
                     hidden_states, output_this_layer, output_cross_layer = layer_ret, {}, {}
                 elif len(layer_ret) == 2: # hidden_states & output_this_layer
                     hidden_states, output_this_layer = layer_ret
                     output_cross_layer = {}
-                elif len(layer_ret) == 3: 
+                elif len(layer_ret) == 3:
                     hidden_states, output_this_layer, output_cross_layer = layer_ret
 
                 if output_hidden_states:
-- 
GitLab