diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index dbd8fe5fbcdc9871a962ff43920a2797e8d73bf0..cdff797d197da83d71bd968e3bb02b71f7fa85b5 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -39,7 +39,7 @@ class LayerNorm(FusedLayerNorm):
         if not self.pb_relax:
             return super().forward(x)
         return super().forward(x / (x.abs().max().detach()/8))
-    
+
 def standard_attention(query_layer, key_layer, value_layer, attention_mask,
                        attention_dropout=None, log_attention_weights=None):
     # We disable the PB-relax-Attention and only changes the order of computation, because it is enough for most of training.
@@ -51,7 +51,7 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask,
     )
     if log_attention_weights is not None:
         attention_scores += log_attention_weights
-    
+
     # if attention_mask.shape[-2] > 1: # if auto-regressive, skip
     attention_scores = torch.mul(attention_scores, attention_mask) - \
                        10000.0 * (1.0 - attention_mask)
@@ -84,7 +84,7 @@ class SelfAttention(torch.nn.Module):
 
         # Strided linear layer.
         self.query_key_value = ColumnParallelLinear(
-            hidden_size, 
+            hidden_size,
             3*hidden_size,
             stride=3,
             gather_output=False,
@@ -125,16 +125,16 @@ class SelfAttention(torch.nn.Module):
             query_layer = self._transpose_for_scores(mixed_query_layer)
             key_layer = self._transpose_for_scores(mixed_key_layer)
             value_layer = self._transpose_for_scores(mixed_value_layer)
-            
+
             context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn)
 
             context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
             new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
             context_layer = context_layer.view(*new_context_layer_shape)
             output = self.dense(context_layer)
-            
+
             if self.training:
                 output = self.output_dropout(output)
-            
+
             return output, None
 
@@ -169,7 +169,7 @@ class MLP(torch.nn.Module):
             intermediate_parallel = self.dense_h_to_4h(hidden_states)
             intermediate_parallel = gelu(intermediate_parallel)
             output = self.dense_4h_to_h(intermediate_parallel)
-        
+
         if self.training:
             output = self.dropout(output)
         return output
@@ -226,7 +226,7 @@ class BaseTransformerLayer(torch.nn.Module):
             output_layer_init_method=output_layer_init_method,
             hooks=hooks
         )
-        
+
     def forward(self, hidden_states, mask, **kw_args):
         '''
             hidden_states: [batch, seq_len, hidden_size]
@@ -277,20 +277,20 @@ class BaseTransformer(torch.nn.Module):
                  hooks={}
                  ):
         super(BaseTransformer, self).__init__()
-        
+
         # recording parameters
         self.parallel_output = parallel_output
         self.checkpoint_activations = checkpoint_activations
         self.checkpoint_num_layers = checkpoint_num_layers
         self.max_sequence_length = max_sequence_length
         self.hooks = copy.copy(hooks)  # hooks will be updated each forward
-        
+
         # create embedding parameters
         self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
-        
+
         self.word_embeddings = VocabParallelEmbedding(
             vocab_size, hidden_size, init_method=unscaled_init_method(0.02))
-        
+
         self.position_embeddings = torch.nn.Embedding(max_sequence_length, hidden_size)
         torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
 
@@ -316,9 +316,10 @@ class BaseTransformer(torch.nn.Module):
         # Final layer norm before output.
         self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
 
-    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, **kw_args):
+    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, output_hidden_states=False,
+                **kw_args):
         # sanity check
-        assert len(input_ids.shape) == 2 
+        assert len(input_ids.shape) == 2
         batch_size, query_length = input_ids.shape
         assert len(attention_mask.shape) == 2 or \
             len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
@@ -332,16 +333,17 @@ class BaseTransformer(torch.nn.Module):
             hidden_states = self.hooks['word_embedding_forward'](input_ids, **kw_args)
         else: # default
             hidden_states = self.word_embeddings(input_ids)
-        
+
         if 'position_embedding_forward' in self.hooks:
             position_embeddings = self.hooks['position_embedding_forward'](position_ids, **kw_args)
         else:
             assert len(position_ids.shape) <= 2
             assert position_ids.shape[-1] == query_length
-            position_embeddings = self.position_embeddings(position_ids) 
+            position_embeddings = self.position_embeddings(position_ids)
         hidden_states = hidden_states + position_embeddings
         hidden_states = self.embedding_dropout(hidden_states)
-        
+
+        hidden_states_outputs = [hidden_states] if output_hidden_states else []
         # branch related embedding
         if branch_input is None and 'branch_embedding_forward' in self.hooks:
             branch_input = self.hooks['branch_embedding_forward'](branch_input, **kw_args)
@@ -352,10 +354,10 @@
             def custom(start, end):
                 def custom_forward(*inputs):
                     layers_ = self.layers[start:end]
-                    x_, mask = inputs[0], inputs[1] 
+                    x_, mask = inputs[0], inputs[1]
                     if len(inputs) > 2: # have branch_input
-                        branch_ = inputs[2] 
-                    output_per_layers_part = [] 
+                        branch_ = inputs[2]
+                    output_per_layers_part = []
                     for i, layer in enumerate(layers_):
                         if len(inputs) > 2:
                             x_, branch_, output_this_layer = self.hooks['layer_forward'](
@@ -370,7 +372,7 @@ class BaseTransformer(torch.nn.Module):
                         output_per_layers_part.append(output_this_layer)
                     return x_, output_per_layers_part
                 return custom_forward
-            
+
             l, num_layers = 0, len(self.layers)
             chunk_length = self.checkpoint_num_layers
             while l < num_layers:
@@ -379,6 +381,8 @@ class BaseTransformer(torch.nn.Module):
                     hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args, branch_input)
                 else:
                     hidden_states, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args)
+                if output_hidden_states:
+                    hidden_states_outputs.append(hidden_states)
                 output_per_layers.extend(output_per_layers_part)
                 l += chunk_length
         else:
@@ -390,26 +394,32 @@ class BaseTransformer(torch.nn.Module):
                     hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), **kw_args)
                 else:
                     hidden_states, output_this_layer = layer(*args, **kw_args)
-                output_per_layers.append(output_this_layer)
+                if output_hidden_states:
+                    hidden_states_outputs.append(hidden_states)
+                output_per_layers.append(output_this_layer)
 
         # Final layer norm.
         logits = self.final_layernorm(hidden_states)
-        
+
         if 'final_forward' in self.hooks:
             logits_parallel = self.hooks['final_forward'](logits, **kw_args)
         else:
             logits_parallel = copy_to_model_parallel_region(logits)
             logits_parallel = F.linear(logits_parallel, self.word_embeddings.weight)
-        
+
         # branch related embedding
         if branch_input is None and 'branch_final_forward' in self.hooks:
             branch_input = self.hooks['branch_final_forward'](branch_input, **kw_args)
 
         if not self.parallel_output:
             logits_parallel = gather_from_model_parallel_region(logits_parallel)
-        
+
+        outputs = [logits_parallel]
         if branch_input is not None:
-            return (logits_parallel, branch_input, *output_per_layers)
-        
-        return (logits_parallel, *output_per_layers)
-        
+            outputs.append(branch_input)
+        if output_hidden_states:
+            outputs.append(hidden_states_outputs)
+        outputs.extend(output_per_layers)
+
+        return outputs
+
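
A note for callers, not part of the patch itself: after this change `forward` returns a plain Python list rather than a tuple, laid out as `logits_parallel`, then `branch_input` (only when a branch is present), then the collected hidden states as a single list element (only when `output_hidden_states=True`), followed by the per-layer outputs. The hidden-state list always starts with the post-embedding-dropout activations; with `checkpoint_activations` it then grows by one entry per checkpointed chunk of `checkpoint_num_layers` layers, otherwise by one entry per layer. The sketch below is a hypothetical helper (the name `unpack_transformer_outputs` and its keyword flags are ours, not part of SwissArmyTransformer) showing how a caller might unpack that layout.

```python
def unpack_transformer_outputs(outputs, has_branch=False, output_hidden_states=False):
    # Split the flat list returned by BaseTransformer.forward after this patch.
    # The caller must know whether a branch was used and whether hidden states
    # were requested, because the list itself carries no markers.
    outputs = list(outputs)
    logits_parallel = outputs.pop(0)
    branch_output = outputs.pop(0) if has_branch else None
    # A single element: [embedding output, then one tensor per layer
    # (or per checkpointed chunk when checkpoint_activations is enabled)].
    hidden_states = outputs.pop(0) if output_hidden_states else None
    output_per_layers = outputs  # whatever each layer (or its hooks) returned
    return logits_parallel, branch_output, hidden_states, output_per_layers


# Example (assuming `model` is an initialized BaseTransformer):
# raw = model(input_ids, position_ids, attention_mask, output_hidden_states=True)
# logits, _, hiddens, per_layer = unpack_transformer_outputs(raw, output_hidden_states=True)
```

Existing callers that unpacked the old tuple with `logits, *rest = model(...)` keep working, since list unpacking behaves the same; only code that relied on the result being a `tuple` instance would notice the change.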