    # coding=utf-8
    # rewritten, Copyright (c) 2021, Ming Ding.  All rights reserved.
    # Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    """Transformer."""
    
    import math
    import copy
    import torch
    import torch.nn.functional as F
    from apex.normalization.fused_layer_norm import FusedLayerNorm
    
    from .initialize import get_model_parallel_world_size
    from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding
    from .mappings import gather_from_model_parallel_region, copy_to_model_parallel_region
    
    import deepspeed
    
    from .random import checkpoint
    from .random import get_cuda_rng_tracker
    
    from .utils import divide, sqrt, scaled_init_method, unscaled_init_method, gelu
    from .utils import split_tensor_along_last_dim
    
    class LayerNorm(FusedLayerNorm):
        def __init__(self, *args, pb_relax=False, **kwargs):
            super().__init__(*args, **kwargs)
            self.pb_relax = pb_relax
        def forward(self, x):
            if not self.pb_relax:
                return super().forward(x)
            return super().forward(x / (x.abs().max().detach()/8))
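
    # PB-relax note: LayerNorm is, up to the effect of `eps`, invariant to rescaling its
    # input, so dividing by `x.abs().max().detach() / 8` leaves the result essentially
    # unchanged while keeping the intermediate statistics in an fp16-safe range.
    # A minimal illustrative sketch (hypothetical values, assumes a CUDA device for
    # FusedLayerNorm):
    #
    #     ln = LayerNorm(1024, pb_relax=True).cuda().half()
    #     x = 1e3 * torch.randn(2, 16, 1024, device='cuda').half()   # large activations
    #     y = ln(x)                                                   # stays finite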
            
    def standard_attention(query_layer, key_layer, value_layer, attention_mask,
                        attention_dropout=None, log_attention_weights=None):
        # We disable PB-relax-Attention here and only change the order of computation, because that is enough for most training.
        # The implementation in the paper can be added very easily if you really need it to train very deep transformers.
    
        attention_scores = torch.matmul(
            query_layer / math.sqrt(query_layer.shape[-1]),
            key_layer.transpose(-1, -2)
        )
        
        if attention_mask.shape[-2] > 1: # skip masking when decoding auto-regressively (single-row mask)
            attention_scores = torch.mul(attention_scores, attention_mask) - \
                        10000.0 * (1.0 - attention_mask)
        if log_attention_weights is not None:
            attention_scores += log_attention_weights
        
        attention_probs = F.softmax(attention_scores, dim=-1)
    
        if attention_dropout is not None:
            with get_cuda_rng_tracker().fork():
                attention_probs = attention_dropout(attention_probs)
    
        context_layer = torch.matmul(attention_probs, value_layer)
        return context_layer
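
    # Shape/usage sketch for `standard_attention` (illustrative only; the tensor names
    # below are hypothetical): query/key/value are [b, num_heads, seq_len, head_dim] and
    # the mask is [1 (or b), 1, seq_len, seq_len] with 1 = keep and 0 = masked out.
    #
    #     b, nh, s, hd = 2, 8, 16, 64
    #     q, k, v = (torch.randn(b, nh, s, hd) for _ in range(3))
    #     causal_mask = torch.tril(torch.ones(1, 1, s, s))
    #     out = standard_attention(q, k, v, causal_mask)   # -> [b, nh, s, hd]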
    
    class SelfAttention(torch.nn.Module):
        def __init__(self, hidden_size, num_attention_heads,
                    attention_dropout_prob, output_dropout_prob,
                    init_method, layer_id, output_layer_init_method=None,
                    hooks={}):
            super(SelfAttention, self).__init__()
            # Set output layer initialization if not provided.
            if output_layer_init_method is None:
                output_layer_init_method = init_method
            self.hooks = hooks
            self.layer_id = layer_id
            # Per attention head and per partition values.
            world_size = get_model_parallel_world_size()
            self.hidden_size_per_partition = divide(hidden_size, world_size)
            self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
            self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
    
            # Strided linear layer.
            self.query_key_value = ColumnParallelLinear(
                hidden_size, 
                3*hidden_size,
                stride=3,
                gather_output=False,
                init_method=init_method
            )
            self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
    
            self.dense = RowParallelLinear(
                hidden_size,
                hidden_size,
                input_is_parallel=True,
                init_method=output_layer_init_method
            )
            self.output_dropout = torch.nn.Dropout(output_dropout_prob)
    
    
        def _transpose_for_scores(self, tensor):
            """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
            size [b, np, s, hn].
            """
            new_tensor_shape = tensor.size()[:-1] + \
                                (self.num_attention_heads_per_partition,
                                self.hidden_size_per_attention_head)
            tensor = tensor.view(*new_tensor_shape)
            return tensor.permute(0, 2, 1, 3)
    
        def forward(self, hidden_states, mask, *other_tensors):
            if 'attention_forward' in self.hooks:
                return self.hooks['attention_forward'](hidden_states, mask, *other_tensors, layer_id=self.layer_id)
            else:
                mixed_raw_layer = self.query_key_value(hidden_states)
                (mixed_query_layer,
                    mixed_key_layer,
                    mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
    
                dropout_fn = self.attention_dropout if self.training else None
    
                query_layer = self._transpose_for_scores(mixed_query_layer)
                key_layer = self._transpose_for_scores(mixed_key_layer)
                value_layer = self._transpose_for_scores(mixed_value_layer)
                
                context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn)
                context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
                new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
                context_layer = context_layer.view(*new_context_layer_shape)
                output = self.dense(context_layer)
                
                if self.training:
                    output = self.output_dropout(output)
                
                return output, None
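
    # Hook sketch: an 'attention_forward' entry in `hooks` replaces the default attention
    # branch above entirely. It receives the same positional tensors as the default path
    # plus `layer_id` as a keyword argument, and must return the same
    # `(output, output_this_layer)` pair. A minimal, hypothetical example:
    #
    #     def my_attention_forward(hidden_states, mask, *other_tensors, layer_id=None):
    #         ...  # e.g. cached or sparse attention
    #         return output, None
    #
    #     hooks = {'attention_forward': my_attention_forward}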
    
    
    class MLP(torch.nn.Module):
        def __init__(self, hidden_size, output_dropout_prob, init_method,
                    layer_id=None, output_layer_init_method=None, hooks={}):
            super(MLP, self).__init__()
            # Set output layer initialization if not provided.
            if output_layer_init_method is None:
                output_layer_init_method = init_method
            self.hooks = hooks
            self.layer_id = layer_id
            # Project to 4h.
            self.dense_h_to_4h = ColumnParallelLinear(
                hidden_size,
                4*hidden_size,
                gather_output=False,
                init_method=init_method
            )
            # Project back to h.
            self.dense_4h_to_h = RowParallelLinear(
                4*hidden_size,
                hidden_size,
                input_is_parallel=True,
                init_method=output_layer_init_method
            )
            self.dropout = torch.nn.Dropout(output_dropout_prob)
    
        def forward(self, hidden_states, *other_tensors):
            if 'mlp_forward' in self.hooks:
                output = self.hooks['mlp_forward'](hidden_states, *other_tensors, layer_id=self.layer_id)
            else:
                intermediate_parallel = self.dense_h_to_4h(hidden_states)
                intermediate_parallel = gelu(intermediate_parallel)
                output = self.dense_4h_to_h(intermediate_parallel)
                
            if self.training:
                output = self.dropout(output)
            return output
    
    
    class BaseTransformerLayer(torch.nn.Module):
        """A single layer transformer for GPT2.
    
        We use the following notation:
            h: hidden size
            n: number of attention heads
            b: batch size
            s: sequence length
        Transformore layer takes input with size [b, s, h] and returns an
        output of the same size.
    
        Arguments:
            hidden_size: The hidden size of the self attention.
            num_attention_heads: number of attention head in the self
                                 attention.
            attention_dropout_prob: dropout probability of the attention
                                    score in self attention.
            output_dropout_prob: dropout probability for the outputs
                                 after self attention and final output.
            layernorm_epsilon: epsilon used in layernorm to avoid
                               division by zero.
            init_method: initialization method used for the weights. Note
                         that all biases are initialized to zero and
                         layernorm weight are initialized to one.
            output_layer_init_method: output layers (attention output and
                                      mlp output) initialization. If None,
                                      use `init_method`.
        """
        def __init__(
            self,
            hidden_size,
            num_attention_heads,
            attention_dropout_prob,
            output_dropout_prob,
            layernorm_epsilon,
            init_method,
            layer_id,
            output_layer_init_method=None,
            sandwich_ln=True,
            hooks={}
        ):
            super(BaseTransformerLayer, self).__init__()
            # Set output layer initialization if not provided.
            if output_layer_init_method is None:
                output_layer_init_method = init_method
            self.layer_id = layer_id
            self.hooks = hooks
    
            # Layernorm on the input data.
            self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
    
            # Self attention.
            self.attention = SelfAttention(
                hidden_size,
                num_attention_heads,
                attention_dropout_prob,
                output_dropout_prob,
                init_method,
                layer_id,
                output_layer_init_method=output_layer_init_method,
                hooks=hooks
            )
    
            # Layernorm after the self attention.
            self.post_attention_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
            self.sandwich_ln = sandwich_ln
            if sandwich_ln:
                self.third_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
                self.fourth_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
    
            # MLP
            self.mlp = MLP(
                hidden_size,
                output_dropout_prob,
                init_method,
                layer_id=layer_id,
                output_layer_init_method=output_layer_init_method,
                hooks=hooks
            )
        
        def forward(self, hidden_states, mask, *other_tensors):
            '''
                hidden_states: [batch, seq_len, hidden_size]
                mask: [(1, 1), seq_len, seq_len]
            '''
    
            # Layer norm at the begining of the transformer layer.
            layernorm_output1 = self.input_layernorm(hidden_states)
            # Self attention.
            attention_output, output_this_layer = self.attention(layernorm_output1, mask, *other_tensors)
    
            # Third LayerNorm
            if self.sandwich_ln:
                attention_output = self.third_layernorm(attention_output)
    
            # Residual connection.
            layernorm_input = hidden_states + attention_output
            # Layer norm post the self attention.
            layernorm_output = self.post_attention_layernorm(layernorm_input)
            # MLP.
            mlp_output = self.mlp(layernorm_output, *other_tensors)
    
            # Fourth LayerNorm
            if self.sandwich_ln:
                mlp_output = self.fourth_layernorm(mlp_output)
    
            # Second residual connection.
            output = layernorm_input + mlp_output
    
            return output, output_this_layer # for now, output_this_layer only comes from the attention hook
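
    # Construction sketch for a single layer (illustrative; the Column/RowParallelLinear
    # sublayers assume the model-parallel groups in `.initialize` have been set up first):
    #
    #     layer = BaseTransformerLayer(
    #         hidden_size=1024,
    #         num_attention_heads=16,
    #         attention_dropout_prob=0.1,
    #         output_dropout_prob=0.1,
    #         layernorm_epsilon=1e-5,
    #         init_method=unscaled_init_method(0.02),
    #         layer_id=0,
    #         sandwich_ln=True,
    #     )
    #     out, _ = layer(hidden_states, mask)   # hidden_states: [b, s, h], mask: [1, 1, s, s]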
    
    class BaseTransformer(torch.nn.Module):
        def __init__(self,
                     num_layers,
                     vocab_size,
                     hidden_size,
                     num_attention_heads,
                     max_sequence_length,
                     embedding_dropout_prob,
                     attention_dropout_prob,
                     output_dropout_prob,
                     checkpoint_activations,
                     checkpoint_num_layers=1,
                     layernorm_epsilon=1.0e-5,
                     init_method_std=0.02,
                     sandwich_ln=True,
                     parallel_output=True,
                     hooks={}
                     ):
            super(BaseTransformer, self).__init__()
            if deepspeed.checkpointing.is_configured():
                global get_cuda_rng_tracker, checkpoint
                get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
                checkpoint = deepspeed.checkpointing.checkpoint
            
            # recording parameters
            self.parallel_output = parallel_output
            self.checkpoint_activations = checkpoint_activations
            self.checkpoint_num_layers = checkpoint_num_layers
            self.max_sequence_length = max_sequence_length
            self.hooks = copy.copy(hooks) # hooks will be updated each forward
            
            # create embedding parameters
            self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
            
            self.word_embeddings = VocabParallelEmbedding(
                vocab_size, hidden_size, init_method=unscaled_init_method(init_method_std))
            
            self.position_embeddings = torch.nn.Embedding(max_sequence_length, hidden_size)
            torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
    
            # create all layers
            self.output_layer_init_method = scaled_init_method(init_method_std, num_layers)
            self.init_method = unscaled_init_method(init_method_std)
            def get_layer(layer_id):
                return BaseTransformerLayer(
                    hidden_size,
                    num_attention_heads,
                    attention_dropout_prob,
                    output_dropout_prob,
                    layernorm_epsilon,
                    self.init_method,
                    layer_id,
                    output_layer_init_method=self.output_layer_init_method,
                    sandwich_ln=sandwich_ln,
                    hooks=self.hooks
                    )
            self.layers = torch.nn.ModuleList(
                [get_layer(layer_id) for layer_id in range(num_layers)])
    
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
    
        def forward(self, input_ids, position_ids, attention_mask, *other_tensors):
            # sanity check 
            assert len(input_ids.shape) == 2 
            batch_size, query_length = input_ids.shape
            assert len(position_ids.shape) <= 2
            assert position_ids.shape[-1] == query_length
            assert len(attention_mask.shape) == 2 or \
                (len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1)
    
            # embedding part
            if 'word_embedding_forward' in self.hooks:
                hidden_states = self.hooks['word_embedding_forward'](input_ids, *other_tensors)
            else: # default
                hidden_states = self.word_embeddings(input_ids)
                
            if 'position_embedding_forward' in self.hooks:
                position_embeddings = self.hooks['position_embedding_forward'](position_ids, *other_tensors)
            else:
                position_embeddings = self.position_embeddings(position_ids)    
            hidden_states = hidden_states + position_embeddings
            hidden_states = self.embedding_dropout(hidden_states)
    
            # define custom_forward for checkpointing
            output_per_layers = []
            if self.checkpoint_activations:
                def custom(start, end):
                    def custom_forward(*inputs):
                        layers_ = self.layers[start:end]
                        x_, mask, *other_tensors = inputs[0], inputs[1], inputs[2:]
                        for i, layer in enumerate(layers_):
                            x_, output_this_layer = layer(x_, mask, *other_tensors)
                            output_per_layers.append(output_this_layer)
                        return x_
                    return custom_forward
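                # Chunked activation checkpointing: each `checkpoint` call below covers
                # `checkpoint_num_layers` consecutive layers, so only segment boundaries are
                # stored and the activations inside a segment are recomputed in backward.
                # For example, 24 layers with checkpoint_num_layers=4 yields 6 segments:
                # layers [0:4], [4:8], ..., [20:24].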
            
                l, num_layers = 0, len(self.layers)
                chunk_length = self.checkpoint_num_layers
                while l < num_layers:
                    args = [hidden_states, attention_mask, *other_tensors]
                    hidden_states = checkpoint(custom(l, l + chunk_length), *args)
                    l += chunk_length
            else:
                for i, layer in enumerate(self.layers):
                    args = [hidden_states, attention_mask, *other_tensors]
                    hidden_states, output_this_layer = layer(*args)
                    output_per_layers.append(output_this_layer) 
    
            # Final layer norm.
            logits = self.final_layernorm(hidden_states)
            
            if 'final_forward' in self.hooks:
                logits_parallel = self.hooks['final_forward'](logits, *other_tensors)
            else:
                logits_parallel = copy_to_model_parallel_region(logits)
                logits_parallel = F.linear(logits_parallel, self.word_embeddings.weight)
    
            if self.parallel_output:
                return (logits_parallel, *output_per_layers)
            return (gather_from_model_parallel_region(logits_parallel), *output_per_layers)
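
    # End-to-end usage sketch (illustrative only; the parallel linear / embedding layers
    # require torch.distributed and the model-parallel groups in `.initialize` to be set
    # up before construction, e.g. by the training launcher, and the sizes below are
    # hypothetical):
    #
    #     model = BaseTransformer(
    #         num_layers=12,
    #         vocab_size=50304,
    #         hidden_size=768,
    #         num_attention_heads=12,
    #         max_sequence_length=1024,
    #         embedding_dropout_prob=0.1,
    #         attention_dropout_prob=0.1,
    #         output_dropout_prob=0.1,
    #         checkpoint_activations=False,
    #     ).cuda()
    #
    #     input_ids = torch.randint(0, 50304, (2, 128), device='cuda')
    #     position_ids = torch.arange(128, device='cuda').unsqueeze(0).expand(2, -1)
    #     attention_mask = torch.tril(torch.ones(1, 1, 128, 128, device='cuda'))
    #     logits_parallel, *output_per_layers = model(input_ids, position_ids, attention_mask)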