From 6a2fa4083ccafc260943cd8338da6bc6fe40b129 Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Sun, 5 Dec 2021 07:51:24 +0000
Subject: [PATCH] tidy up enc-dec model

---
 SwissArmyTransformer/model/base_model.py      |  3 +-
 .../model/encoder_decoder_model.py            | 93 ++++---------------
 SwissArmyTransformer/model/t5_model.py        |  5 +-
 SwissArmyTransformer/mpu/transformer.py       | 70 +++++++-------
 4 files changed, 59 insertions(+), 112 deletions(-)
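
Usage sketch of the new encode/decode interface (illustrative only; args, the toy tensors, sizes and masks below are placeholders, and the usual SwissArmyTransformer argument parsing and model-parallel initialization is assumed to have run already):

import torch
from SwissArmyTransformer.model.encoder_decoder_model import EncoderDecoderModel

model = EncoderDecoderModel(args, tie_word_embeddings=True)  # args: fully populated argparse namespace

batch, enc_len, dec_len = 2, 16, 8
enc_input_ids = torch.randint(0, 100, (batch, enc_len))              # toy token ids
enc_position_ids = torch.arange(enc_len).expand(batch, enc_len)
dec_input_ids = torch.randint(0, 100, (batch, dec_len))
dec_position_ids = torch.arange(dec_len).expand(batch, dec_len)
dec_attention_mask = torch.ones(batch, 1, dec_len, dec_len).tril()   # causal mask
cross_attention_mask = torch.ones(1, 1)                              # broadcastable all-ones mask = full cross attention

# Teacher-forced pass: encoder and decoder in one call.
encoder_outputs, decoder_outputs, *mems = model(
    enc_input_ids, enc_position_ids,
    dec_input_ids, dec_position_ids, dec_attention_mask,
    enc_attention_mask=None,                   # None now means full attention
    cross_attention_mask=cross_attention_mask,
)

# For auto-regressive generation, run the encoder once, then call
# model.decode (or self.decoder) step by step with the cached encoder_outputs.
encoder_outputs = model.encode(enc_input_ids, enc_position_ids, attention_mask=None)
decoder_outputs, *mems = model.decode(
    dec_input_ids, dec_position_ids, dec_attention_mask,
    encoder_outputs=encoder_outputs,
    cross_attention_mask=cross_attention_mask,
)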

diff --git a/SwissArmyTransformer/model/base_model.py b/SwissArmyTransformer/model/base_model.py
index 8dfb989..0b593bb 100644
--- a/SwissArmyTransformer/model/base_model.py
+++ b/SwissArmyTransformer/model/base_model.py
@@ -13,8 +13,7 @@ import math
 import random
 import torch
 
-from SwissArmyTransformer.mpu import BaseTransformer, LayerNorm
-
+from SwissArmyTransformer.mpu import BaseTransformer
 
 class BaseMixin(torch.nn.Module):
     def __init__(self):
diff --git a/SwissArmyTransformer/model/encoder_decoder_model.py b/SwissArmyTransformer/model/encoder_decoder_model.py
index e7395b7..0c6336c 100644
--- a/SwissArmyTransformer/model/encoder_decoder_model.py
+++ b/SwissArmyTransformer/model/encoder_decoder_model.py
@@ -17,60 +17,6 @@ from .base_model import BaseModel, BaseMixin
 from SwissArmyTransformer.mpu.mappings import copy_to_model_parallel_region
 
 
-def get_extended_attention_mask(attention_mask, input_shape, device, dtype=torch.float32, is_decoder=False):
-    """
-    Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
-
-    Arguments:
-        attention_mask (:obj:`torch.Tensor`):
-            Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
-        input_shape (:obj:`Tuple[int]`):
-            The shape of the input to the model.
-        device: (:obj:`torch.device`):
-            The device of the input to the model.
-        dtype:
-        is_decoder:
-
-    Returns:
-        :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
-    """
-    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
-    # ourselves in which case we just need to make it broadcastable to all heads.
-    if attention_mask is None or attention_mask.dim() == 2:
-        batch_size, seq_length = input_shape
-        # Provided a padding mask of dimensions [batch_size, seq_length]
-        # - if the model is a decoder, apply a causal mask in addition to the padding mask
-        # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
-        if is_decoder:
-            seq_ids = torch.arange(seq_length, device=device)
-            causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
-            # in case past_key_values are used we need to add a prefix ones mask to the causal mask
-            # causal and attention masks must have same type with pytorch version < 1.3
-            causal_mask = causal_mask.to(dtype)
-            if attention_mask is not None:
-                if causal_mask.shape[1] < attention_mask.shape[1]:
-                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
-                    causal_mask = torch.cat(
-                        [torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
-                         causal_mask], axis=-1)
-
-                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
-            else:
-                extended_attention_mask = causal_mask[:, None, :, :]
-        else:
-            if attention_mask is None:
-                extended_attention_mask = torch.ones(1, 1, 1, seq_length, device=device, dtype=dtype)
-            else:
-                extended_attention_mask = attention_mask[:, None, None, :]
-    elif attention_mask.dim() == 3:
-        extended_attention_mask = attention_mask[:, None, :, :]
-    else:
-        raise ValueError(
-            f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
-        )
-    return extended_attention_mask
-
-
 class EncoderFinalMixin(BaseMixin):
     def final_forward(self, logits, **kwargs):
         logits = copy_to_model_parallel_region(logits)
@@ -78,7 +24,7 @@ class EncoderFinalMixin(BaseMixin):
 
 
 class EncoderDecoderModel(torch.nn.Module):
-    def __init__(self, args, encoder=None, decoder=None, parallel_output=False, **kwargs):
+    def __init__(self, args, encoder=None, decoder=None, tie_word_embeddings=True, parallel_output=False, **kwargs):
         super(EncoderDecoderModel, self).__init__()
         if encoder is not None:
             assert isinstance(encoder, BaseModel)
@@ -86,6 +32,7 @@ class EncoderDecoderModel(torch.nn.Module):
         else:
             self.encoder = BaseModel(args, **kwargs)
         self.encoder.add_mixin("final", EncoderFinalMixin())
+        
         if decoder is not None:
             assert isinstance(decoder, BaseModel)
             self.decoder = decoder
@@ -100,6 +47,10 @@ class EncoderDecoderModel(torch.nn.Module):
                     setattr(dec_args, name, dec_attr)
             self.decoder = BaseModel(args, is_decoder=True, parallel_output=parallel_output, **kwargs)
 
+        self.tie_word_embeddings = tie_word_embeddings
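+        # share the input word embedding matrix between encoder and decoder (as in T5)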
+        if tie_word_embeddings:
+            self.decoder.transformer.word_embeddings = self.encoder.transformer.word_embeddings
+
     def reinit(self):
         self.encoder.reinit()
         self.decoder.reinit()
@@ -108,25 +59,19 @@ class EncoderDecoderModel(torch.nn.Module):
         self.encoder.disable_untrainable_params()
         self.decoder.disable_untrainable_params()
 
-    def forward(self, input_ids=None, input_position_ids=None, attention_mask=None, decoder_input_ids=None,
-                decoder_position_ids=None, decoder_attention_mask=None, encoder_outputs=None,
-                **kw_args):
-        dtype = self.encoder.transformer.word_embeddings.weight.dtype
-        if encoder_outputs is None:
-            batch_size, encoder_seq_length = input_ids.size()[:2]
-        else:
-            batch_size, encoder_seq_length = encoder_outputs.size()[:2]
-        encoder_attention_mask = get_extended_attention_mask(attention_mask, (batch_size, encoder_seq_length),
-                                                             device=input_ids.device, dtype=dtype)
-        decoder_seq_length = decoder_input_ids.size(1)
-        if encoder_outputs is None:
-            encoder_outputs, *_dumps = self.encoder(input_ids, input_position_ids, encoder_attention_mask, **kw_args)
-        decoder_attention_mask = get_extended_attention_mask(decoder_attention_mask, (batch_size, decoder_seq_length),
-                                                             device=input_ids.device, dtype=dtype, is_decoder=True)
-        decoder_outputs, *decoder_mems = self.decoder(decoder_input_ids, decoder_position_ids, decoder_attention_mask,
-                                                      encoder_outputs=encoder_outputs,
-                                                      cross_attention_mask=encoder_attention_mask, **kw_args)
-        return encoder_outputs, decoder_outputs, *decoder_mems
+    def encode(self, input_ids, position_ids, attention_mask=None, **kw_args):
+        encoder_outputs, *_dumps = self.encoder(input_ids, position_ids, attention_mask, **kw_args)
+        return encoder_outputs
+    
+    def decode(self, input_ids, position_ids, attention_mask, encoder_outputs, cross_attention_mask=None, **kw_args):
+        # If there is no encoder context, explicitly pass encoder_outputs=None.
+        return self.decoder(input_ids, position_ids, attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args)
+    
+    def forward(self, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids, dec_attention_mask, *, enc_attention_mask=None, cross_attention_mask=None, **kw_args):
+        # Please use self.decoder for auto-regressive generation.
+        encoder_outputs = self.encode(enc_input_ids, enc_position_ids, enc_attention_mask, **kw_args)
+        decoder_outputs, *mems = self.decode(dec_input_ids, dec_position_ids, dec_attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args)
+        return encoder_outputs, decoder_outputs, *mems
 
     @classmethod
     def add_model_specific_args(cls, parser):
diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index bf4f2c9..944a133 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -180,8 +180,8 @@ class T5DecoderFinalMixin(BaseMixin):
 
 class T5Model(EncoderDecoderModel):
     def __init__(self, args, **kwargs):
-        super().__init__(args, **kwargs, use_bias=False, layernorm=T5LayerNorm,
-                         activation_func=torch.nn.functional.relu)
+        super().__init__(args, tie_word_embeddings=True, **kwargs, use_bias=False,
+                         layernorm=T5LayerNorm, activation_func=torch.nn.functional.relu)
         self.encoder.add_mixin(
             "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, args.num_attention_heads)
         )
@@ -200,7 +200,6 @@ class T5Model(EncoderDecoderModel):
             "t5-final", T5DecoderFinalMixin(args.hidden_size)
         )
         del self.decoder.transformer.position_embeddings
-        self.decoder.transformer.word_embeddings = self.encoder.transformer.word_embeddings
 
     @classmethod
     def add_model_specific_args(cls, parser):
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index 538d965..e9f7ab1 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -198,7 +198,7 @@ class CrossAttention(torch.nn.Module):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)
 
-    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, *args, **kw_args):
+    def forward(self, hidden_states, cross_attention_mask, encoder_outputs, **kw_args):
         # hidden_states: [b, s, h]
         if 'cross_attention_forward' in self.hooks:
             return self.hooks['cross_attention_forward'](hidden_states, cross_attention_mask, encoder_outputs,
@@ -353,7 +353,7 @@ class BaseTransformerLayer(torch.nn.Module):
             hooks=hooks
         )
 
-    def forward(self, hidden_states, mask, encoder_outputs=None, cross_attention_mask=None, **kw_args):
+    def forward(self, hidden_states, mask, **kw_args):
         '''
             hidden_states: [batch, seq_len, hidden_size]
             mask: [(1, 1), seq_len, seq_len]
@@ -373,13 +373,16 @@ class BaseTransformerLayer(torch.nn.Module):
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)
 
-        if self.is_decoder and encoder_outputs is not None:
-            # Cross attention
-            attention_output = self.cross_attention(layernorm_output, cross_attention_mask, encoder_outputs, **kw_args)
-            # Residual connection.
-            layernorm_input = layernorm_input + attention_output
-            # Layer norm post the cross attention
-            layernorm_output = self.post_cross_attention_layernorm(layernorm_input)
+        if self.is_decoder:
+            encoder_outputs = kw_args['encoder_outputs']
+            if encoder_outputs is not None:
+                # Cross attention; cross_attention_mask and encoder_outputs are
+                # forwarded to CrossAttention through kw_args.
+                attention_output = self.cross_attention(layernorm_output, **kw_args)
+                # Residual connection.
+                layernorm_input = layernorm_input + attention_output
+                # Layer norm post the cross attention
+                layernorm_output = self.post_cross_attention_layernorm(layernorm_input)
 
         # MLP.
         mlp_output = self.mlp(layernorm_output, **kw_args)
@@ -467,12 +470,15 @@ class BaseTransformer(torch.nn.Module):
         # Final layer norm before output.
         self.final_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
 
-    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, encoder_outputs=None,
+    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, 
                 output_hidden_states=False, **kw_args):
-        breakpoint()
         # sanity check
         assert len(input_ids.shape) == 2
         batch_size, query_length = input_ids.shape
+        if attention_mask is None:
+            attention_mask = torch.ones(1, 1, device=input_ids.device).type_as(
+                next(self.parameters())
+            ) # None means full attention
         assert len(attention_mask.shape) == 2 or \
                len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
         assert branch_input is None or 'layer_forward' in self.hooks and isinstance(branch_input, torch.Tensor)
@@ -507,34 +513,37 @@ class BaseTransformer(torch.nn.Module):
             def custom(start, end):
                 def custom_forward(*inputs):
                     layers_ = self.layers[start:end]
-                    x_, mask, encoder_outputs_ = inputs[0], inputs[1], inputs[2]
+                    x_, mask = inputs[0], inputs[1]
+                    if len(inputs) > 2: # branch_input was passed
+                        branch_ = inputs[2]
                     output_per_layers_part = []
                     for i, layer in enumerate(layers_):
-                        if branch_input is not None:
-                            x_, encoder_outputs_, output_this_layer = self.hooks['layer_forward'](
-                                x_, mask, layer_id=layer.layer_id, branch_input=encoder_outputs_, **kw_args
+                        if len(inputs) > 2:
+                            x_, branch_, output_this_layer = self.hooks['layer_forward'](
+                                x_, mask, layer_id=layer.layer_id, branch_input=branch_, **kw_args
                             )
                         elif 'layer_forward' in self.hooks:
                             x_, output_this_layer = self.hooks['layer_forward'](
-                                x_, mask, encoder_outputs_, layer_id=layer.layer_id, **kw_args
+                                x_, mask, layer_id=layer.layer_id, **kw_args
                             )
                         else:
-                            x_, output_this_layer = layer(x_, mask, encoder_outputs_, **kw_args)
+                            x_, output_this_layer = layer(x_, mask, **kw_args)
                         output_per_layers_part.append(output_this_layer)
                     return x_, output_per_layers_part
-
                 return custom_forward
 
-            l, num_layers = 0, len(self.layers)
-            chunk_length = self.checkpoint_num_layers
+            # prevent losing requires_grad during checkpointing.
+            # To save memory when only finetuning the final layers, do not use checkpointing.
             if self.training:
                 hidden_states.requires_grad_(True)
+
+            l, num_layers = 0, len(self.layers)
+            chunk_length = self.checkpoint_num_layers
             while l < num_layers:
+                args = [hidden_states, attention_mask]
                 if branch_input is not None:
-                    args = [hidden_states, attention_mask, branch_input]
-                    hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args)
+                    hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args, branch_input)
                 else:
-                    args = [hidden_states, attention_mask, encoder_outputs]
                     hidden_states, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args)
                 if output_hidden_states:
                     hidden_states_outputs.append(hidden_states)
@@ -542,16 +551,11 @@ class BaseTransformer(torch.nn.Module):
                 l += chunk_length
         else:
             for i, layer in enumerate(self.layers):
-                args = [hidden_states, attention_mask, encoder_outputs]
-                if branch_input is not None:  # customized layer_forward with branch_input
-                    hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args,
-                                                                                                 layer_id=torch.tensor(
-                                                                                                     i),
-                                                                                                 branch_input=branch_input,
-                                                                                                 **kw_args)
-                elif 'layer_forward' in self.hooks:  # customized layer_forward
-                    hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i),
-                                                                                   **kw_args)
+                args = [hidden_states, attention_mask]
+                if branch_input is not None: # customized layer_forward with branch_input
+                    hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), branch_input=branch_input, **kw_args)
+                elif 'layer_forward' in self.hooks: # customized layer_forward
+                    hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), **kw_args)
                 else:
                     hidden_states, output_this_layer = layer(*args, **kw_args)
                 if output_hidden_states:
-- 
GitLab