diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index dbd8fe5fbcdc9871a962ff43920a2797e8d73bf0..cdff797d197da83d71bd968e3bb02b71f7fa85b5 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -39,7 +39,7 @@ class LayerNorm(FusedLayerNorm):
         if not self.pb_relax:
             return super().forward(x)
         return super().forward(x / (x.abs().max().detach()/8))
-        
+
 def standard_attention(query_layer, key_layer, value_layer, attention_mask,
                     attention_dropout=None, log_attention_weights=None):
     # We disable PB-relax-Attention here and only change the order of computation, because it is enough for most training.
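
Note on the hunk above: `pb_relax` rescales the LayerNorm input by its detached absolute maximum over 8, so fp16 activations cannot overflow inside the fused kernel; since LayerNorm is invariant to input scaling (up to the epsilon term), the output is essentially unchanged. A minimal sketch of the same trick on top of plain torch.nn.LayerNorm (the real code subclasses apex's FusedLayerNorm):

import torch

class PBRelaxLayerNorm(torch.nn.LayerNorm):
    # Divide by the detached abs-max (over 8) so fp16 intermediates stay
    # in range; scale-invariance of LayerNorm keeps the result equivalent.
    def forward(self, x):
        return super().forward(x / (x.abs().max().detach() / 8))
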
@@ -51,7 +51,7 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask,
     )
     if log_attention_weights is not None:
         attention_scores += log_attention_weights
-    
+
     # if attention_mask.shape[-2] > 1: # if auto-regressive, skip
     attention_scores = torch.mul(attention_scores, attention_mask) - \
                 10000.0 * (1.0 - attention_mask)
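
Note: the multiply-then-subtract above is the standard additive attention mask: allowed positions keep their scores, blocked positions are pushed to about -10000 so softmax assigns them (near-)zero probability. A toy illustration with hypothetical shapes:

import torch

scores = torch.randn(1, 1, 4, 4)                 # [batch, heads, query, key]
mask = torch.tril(torch.ones(1, 1, 4, 4))        # causal mask: 1 = attend, 0 = block
masked = scores * mask - 10000.0 * (1.0 - mask)  # blocked positions -> ~ -10000
probs = torch.nn.functional.softmax(masked, dim=-1)
assert probs[mask == 0].max() < 1e-4             # blocked positions get ~0 weight
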
@@ -84,7 +84,7 @@ class SelfAttention(torch.nn.Module):
 
         # Strided linear layer.
         self.query_key_value = ColumnParallelLinear(
-            hidden_size, 
+            hidden_size,
             3*hidden_size,
             stride=3,
             gather_output=False,
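
Note: the strided ColumnParallelLinear above fuses the Q, K and V projections into one matmul; stride=3 makes the column split interleave the three projections so every model-parallel partition holds a matched Q/K/V shard. A single-GPU sketch of the fused projection without the parallelism (sizes are hypothetical):

import torch

hidden_size = 1024                                   # hypothetical size
x = torch.randn(2, 16, hidden_size)                  # [batch, seq, hidden]
query_key_value = torch.nn.Linear(hidden_size, 3 * hidden_size)
q, k, v = query_key_value(x).chunk(3, dim=-1)        # split the fused output
assert q.shape == k.shape == v.shape == (2, 16, hidden_size)
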
@@ -125,16 +125,16 @@ class SelfAttention(torch.nn.Module):
             query_layer = self._transpose_for_scores(mixed_query_layer)
             key_layer = self._transpose_for_scores(mixed_key_layer)
             value_layer = self._transpose_for_scores(mixed_value_layer)
-            
+
             context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn)
             context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
             new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
             context_layer = context_layer.view(*new_context_layer_shape)
             output = self.dense(context_layer)
-            
+
             if self.training:
                 output = self.output_dropout(output)
-            
+
             return output, None
 
 
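Note: after standard_attention the context comes back as [batch, heads, seq, head_dim]; the permute/view pair in the hunk above restores [batch, seq, hidden] by merging the head dimensions. A sketch of that round trip (sizes hypothetical):

import torch

batch, seq, heads, head_dim = 2, 16, 8, 64           # hypothetical sizes
x = torch.randn(batch, seq, heads * head_dim)
# _transpose_for_scores: [batch, seq, hidden] -> [batch, heads, seq, head_dim]
per_head = x.view(batch, seq, heads, head_dim).permute(0, 2, 1, 3)
# ... attention runs per head here ...
# inverse, as in the hunk: permute back and merge the head dims
merged = per_head.permute(0, 2, 1, 3).contiguous().view(batch, seq, heads * head_dim)
assert torch.equal(x, merged)
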
@@ -169,7 +169,7 @@ class MLP(torch.nn.Module):
             intermediate_parallel = self.dense_h_to_4h(hidden_states)
             intermediate_parallel = gelu(intermediate_parallel)
             output = self.dense_4h_to_h(intermediate_parallel)
-            
+
         if self.training:
             output = self.dropout(output)
         return output
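
Note: the MLP visible in this hunk is the usual transformer feed-forward: expand to 4x hidden, gelu, project back, then dropout in training. A single-GPU sketch without the parallel linear layers (hidden_size hypothetical):

import torch
import torch.nn.functional as F

hidden_size = 1024                                   # hypothetical size
dense_h_to_4h = torch.nn.Linear(hidden_size, 4 * hidden_size)
dense_4h_to_h = torch.nn.Linear(4 * hidden_size, hidden_size)
x = torch.randn(2, 16, hidden_size)
output = dense_4h_to_h(F.gelu(dense_h_to_4h(x)))     # [2, 16, hidden_size]
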
@@ -226,7 +226,7 @@ class BaseTransformerLayer(torch.nn.Module):
             output_layer_init_method=output_layer_init_method,
             hooks=hooks
         )
-    
+
     def forward(self, hidden_states, mask, **kw_args):
         '''
             hidden_states: [batch, seq_len, hidden_size]
@@ -277,20 +277,20 @@ class BaseTransformer(torch.nn.Module):
                  hooks={}
                  ):
         super(BaseTransformer, self).__init__()
-        
+
         # recording parameters
         self.parallel_output = parallel_output
         self.checkpoint_activations = checkpoint_activations
         self.checkpoint_num_layers = checkpoint_num_layers
         self.max_sequence_length = max_sequence_length
         self.hooks = copy.copy(hooks) # hooks will be updated each forward
-        
+
         # create embedding parameters
         self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
-        
+
         self.word_embeddings = VocabParallelEmbedding(
             vocab_size, hidden_size, init_method=unscaled_init_method(0.02))
-        
+
         self.position_embeddings = torch.nn.Embedding(max_sequence_length, hidden_size)
         torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
 
@@ -316,9 +316,10 @@ class BaseTransformer(torch.nn.Module):
         # Final layer norm before output.
         self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
 
-    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, **kw_args):
+    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, output_hidden_states=False,
+                **kw_args):
         # sanity check 
-        assert len(input_ids.shape) == 2 
+        assert len(input_ids.shape) == 2
         batch_size, query_length = input_ids.shape
         assert len(attention_mask.shape) == 2 or \
             len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
@@ -332,16 +333,17 @@ class BaseTransformer(torch.nn.Module):
             hidden_states = self.hooks['word_embedding_forward'](input_ids, **kw_args)
         else: # default
             hidden_states = self.word_embeddings(input_ids)
-            
+
         if 'position_embedding_forward' in self.hooks:
             position_embeddings = self.hooks['position_embedding_forward'](position_ids, **kw_args)
         else:
             assert len(position_ids.shape) <= 2
             assert position_ids.shape[-1] == query_length
-            position_embeddings = self.position_embeddings(position_ids)    
+            position_embeddings = self.position_embeddings(position_ids)
         hidden_states = hidden_states + position_embeddings
         hidden_states = self.embedding_dropout(hidden_states)
-        
+
+        hidden_states_outputs = [hidden_states] if output_hidden_states else []
         # branch related embedding
         if branch_input is None and 'branch_embedding_forward' in self.hooks:
             branch_input = self.hooks['branch_embedding_forward'](branch_input, **kw_args)
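
Note: the default embedding path sums token and position embeddings and applies dropout; with the new flag, the result also seeds hidden_states_outputs when output_hidden_states is set. A sketch of that default path (sizes hypothetical, hooks ignored):

import torch

vocab_size, max_len, hidden_size = 1000, 128, 64     # hypothetical sizes
word_embeddings = torch.nn.Embedding(vocab_size, hidden_size)
position_embeddings = torch.nn.Embedding(max_len, hidden_size)
embedding_dropout = torch.nn.Dropout(0.1)

input_ids = torch.randint(0, vocab_size, (2, 16))
position_ids = torch.arange(16).unsqueeze(0).expand(2, -1)
hidden_states = embedding_dropout(
    word_embeddings(input_ids) + position_embeddings(position_ids))
# with output_hidden_states=True, this embedding output seeds the list
hidden_states_outputs = [hidden_states]
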
@@ -352,10 +354,10 @@ class BaseTransformer(torch.nn.Module):
             def custom(start, end):
                 def custom_forward(*inputs):
                     layers_ = self.layers[start:end]
-                    x_, mask = inputs[0], inputs[1]    
+                    x_, mask = inputs[0], inputs[1]
                     if len(inputs) > 2: # have branch_input
-                        branch_ = inputs[2]     
-                    output_per_layers_part = []               
+                        branch_ = inputs[2]
+                    output_per_layers_part = []
                     for i, layer in enumerate(layers_):
                         if len(inputs) > 2:
                             x_, branch_, output_this_layer = self.hooks['layer_forward'](
@@ -370,7 +372,7 @@ class BaseTransformer(torch.nn.Module):
                         output_per_layers_part.append(output_this_layer)
                     return x_, output_per_layers_part
                 return custom_forward
-        
+
             l, num_layers = 0, len(self.layers)
             chunk_length = self.checkpoint_num_layers
             while l < num_layers:
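
Note: the loop here walks the layer list in chunks of checkpoint_num_layers, checkpointing each chunk so only chunk-boundary activations are kept alive and the interior is recomputed during backward. The real code uses a checkpoint from its own mpu utilities and also threads branch_input, hooks, and per-layer outputs through; this stripped-down sketch with torch.utils.checkpoint shows just the chunking pattern:

import torch
from torch.utils.checkpoint import checkpoint

def run_chunked(layers, hidden_states, mask, chunk_length):
    def custom(start, end):
        def custom_forward(x, mask):
            for layer in layers[start:end]:
                x = layer(x, mask)
            return x
        return custom_forward
    l, num_layers = 0, len(layers)
    while l < num_layers:
        # only each chunk's inputs are stored; the interior is recomputed
        hidden_states = checkpoint(custom(l, l + chunk_length), hidden_states, mask)
        l += chunk_length
    return hidden_states
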
@@ -379,6 +381,8 @@ class BaseTransformer(torch.nn.Module):
                     hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args, branch_input)
                 else:
                     hidden_states, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args)
+                if output_hidden_states:
+                    hidden_states_outputs.append(hidden_states)
                 output_per_layers.extend(output_per_layers_part)
                 l += chunk_length
         else:
@@ -390,26 +394,32 @@ class BaseTransformer(torch.nn.Module):
                     hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), **kw_args)
                 else:
                     hidden_states, output_this_layer = layer(*args, **kw_args)
-                output_per_layers.append(output_this_layer) 
+                if output_hidden_states:
+                    hidden_states_outputs.append(hidden_states)
+                output_per_layers.append(output_this_layer)
 
         # Final layer norm.
         logits = self.final_layernorm(hidden_states)
-        
+
         if 'final_forward' in self.hooks:
             logits_parallel = self.hooks['final_forward'](logits, **kw_args)
         else:
             logits_parallel = copy_to_model_parallel_region(logits)
             logits_parallel = F.linear(logits_parallel, self.word_embeddings.weight)
-            
+
         # branch related embedding
         if branch_input is None and 'branch_final_forward' in self.hooks:
             branch_input = self.hooks['branch_final_forward'](branch_input, **kw_args)
 
         if not self.parallel_output:
             logits_parallel = gather_from_model_parallel_region(logits_parallel)
-            
+
+        outputs = [logits_parallel]
         if branch_input is not None:
-            return (logits_parallel, branch_input, *output_per_layers)
-        
-        return (logits_parallel, *output_per_layers)
-        
+            outputs.append(branch_input)
+        if output_hidden_states:
+            outputs.append(hidden_states_outputs)
+        outputs.extend(output_per_layers)
+
+        return outputs
+
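
Overall, the functional change in this diff is the new output_hidden_states flag on BaseTransformer.forward: when set, the embedding output and each layer's (or, under activation checkpointing, each chunk's) hidden states are collected and returned as one list, inserted after logits_parallel (and after branch_input, if present) and before the per-layer outputs. Note the return type also changes from a tuple to a list. A hypothetical call site, assuming a constructed `model: BaseTransformer`, prepared inputs, and no branch_input:

outputs = model(input_ids, position_ids, attention_mask,
                output_hidden_states=True)
logits = outputs[0]
hidden_states = outputs[1]   # embedding output + one entry per layer (or per
                             # checkpointed chunk when checkpoint_activations)
per_layer = outputs[2:]      # the per-layer outputs, as before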