diff --git a/SwissArmyTransformer/generation/autoregressive_sampling.py b/SwissArmyTransformer/generation/autoregressive_sampling.py
index 4eada99a05d6fe244fe8f6c342322e6f6c7c485f..eb5acf42c303db3340e64249108f7f8e697293b4 100644
--- a/SwissArmyTransformer/generation/autoregressive_sampling.py
+++ b/SwissArmyTransformer/generation/autoregressive_sampling.py
@@ -100,7 +100,7 @@ def filling_sequence(
         else:
             log_attention_weights_part = None
 
-        logits, *mem_kv = model(
+        logits, *output_per_layers = model(
             tokens[:, index:], 
             position_ids[..., index: counter+1],
             attention_mask[..., index: counter+1, :counter+1], # TODO memlen
@@ -108,6 +108,7 @@ def filling_sequence(
             log_attention_weights=log_attention_weights_part,
             **kw_args
         )
+        mem_kv = [o['mem_kv'] for o in output_per_layers]
         mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
         counter += 1
         index = counter
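
For reference, the model call now returns the logits followed by one output dict per layer, and the sampling loop picks the cached keys/values out of those dicts by name instead of receiving them positionally. A self-contained toy of the consumption pattern (the fake model below is illustrative; only the `'mem_kv'` key and the unpacking mirror the real code):

```python
import torch

# Toy stand-in for the new return convention: [logits, {'mem_kv': ...}, ...]
def fake_model_call(num_layers=2, b=1, seq=4, flat_kv=64):
    logits = torch.randn(b, seq, 100)
    per_layer = [{'mem_kv': torch.randn(b, seq, flat_kv)} for _ in range(num_layers)]
    return [logits] + per_layer

logits, *output_per_layers = fake_model_call()
mem_kv = [o['mem_kv'] for o in output_per_layers]   # same extraction as above
print(len(mem_kv), mem_kv[0].shape)                 # 2 torch.Size([1, 4, 64])
```
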
diff --git a/SwissArmyTransformer/model/base_model.py b/SwissArmyTransformer/model/base_model.py
index 1b0569c61a0527b3a438428bd25ff1d3afc64351..61543f5d30ed922921b6cd9804b08550e894ca00 100644
--- a/SwissArmyTransformer/model/base_model.py
+++ b/SwissArmyTransformer/model/base_model.py
@@ -39,9 +39,9 @@ class BaseMixin(torch.nn.Module):
     # E.g.,
     # 
     # @non_conflict
-    # def attention_fn(q, k, v, mask, dropout_fn, old_impl=standard_attention, **kwargs):
+    # def attention_fn(q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args):
     #     new_q, new_k, new_v = pre_hack(q, k, v)
-    #     attn_result = old_impl(q, k, v, mask, dropout_fn, **kwargs)
+    #     attn_result = old_impl(q, k, v, mask, dropout_fn, **kw_args)
     #     attn_result = post_hack(attn_result)
     #     return attn_result
 
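
The docstring above sketches the hook shape; below is a self-contained toy of the same pre-hack/post-hack composition, with a simplified stand-in for `standard_attention` (not the library's real implementation):

```python
import math
import torch

# Simplified stand-in for standard_attention (illustrative only).
def standard_attention(q, k, v, mask, dropout_fn, **kw_args):
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.shape[-1])
    scores = scores.masked_fill(mask == 0, float('-inf'))
    probs = torch.softmax(scores, dim=-1)
    if dropout_fn is not None:
        probs = dropout_fn(probs)
    return torch.matmul(probs, v)

# A mixin-style attention_fn: tweak the inputs, delegate to whatever
# implementation was registered before (old_impl), then tweak the result.
def attention_fn(q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args):
    q = q * 1.0                                       # pre_hack placeholder
    attn_result = old_impl(q, k, v, mask, dropout_fn, **kw_args)
    return attn_result                                # post_hack placeholder

b, nh, s, d = 1, 2, 3, 4
out = attention_fn(torch.randn(b, nh, s, d), torch.randn(b, nh, s, d),
                   torch.randn(b, nh, s, d), torch.ones(1, 1, s, s), None)
```
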
diff --git a/SwissArmyTransformer/model/cached_autoregressive_model.py b/SwissArmyTransformer/model/cached_autoregressive_model.py
index 296a811ea17662c2d4432b029412017fd865210a..ed25fe48aaa3ef83b6c784429b229d40a22b7020 100755
--- a/SwissArmyTransformer/model/cached_autoregressive_model.py
+++ b/SwissArmyTransformer/model/cached_autoregressive_model.py
@@ -18,38 +18,24 @@ from SwissArmyTransformer.mpu.transformer import standard_attention, split_tenso
 
 class CachedAutoregressiveMixin(BaseMixin):
     def __init__(self):
-        super().__init__()
-        
-    def attention_forward(self, hidden_states, mask, mems=None, layer_id=None, log_attention_weights=None, **kwargs):
-        attn_module = self.transformer.layers[layer_id].attention
-        mem = mems[layer_id] if mems is not None else None
-        
-        mixed_raw_layer = attn_module.query_key_value(hidden_states)
-        (mixed_query_layer,
-            mixed_key_layer,
-            mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
+        super().__init__()
+
+    @non_conflict
+    def attention_fn(self, q, k, v, mask, dropout_fn, mems=None, old_impl=standard_attention, **kw_args):
+        mem = mems[kw_args['layer_id']] if mems is not None else None # (batch, mem_len, 2 * nh * hidden_size_per_head): flattened k/v cache for this layer
+        b, nh, seq_len, hidden_size = k.shape
+
+        cache_kv = torch.stack((k, v)).permute(1, 3, 0, 2, 4).detach().contiguous().view(b, seq_len, nh * hidden_size * 2)
+        kw_args['output_this_layer']['mem_kv'] = cache_kv
         
         if mem is not None: # the first time, mem is None
-            b = mixed_key_layer.shape[0] # might change batch_size
-            memk, memv = split_tensor_along_last_dim(mem.expand(b, -1, -1), 2)
-            mixed_key_layer = torch.cat((memk, mixed_key_layer), dim=1)
-            mixed_value_layer = torch.cat((memv, mixed_value_layer), dim=1)
+            # the batch size may have changed since the cache was written, so expand mem to match
+            mem = mem.expand(b, -1, -1).reshape(b, mem.shape[1], 2, nh, hidden_size).permute(2, 0, 3, 1, 4) 
+            memk, memv = mem[0], mem[1]
+            k = torch.cat((memk, k), dim=2)
+            v = torch.cat((memv, v), dim=2)
+        return old_impl(q, k, v, mask, dropout_fn, **kw_args)
 
-        # same as training
-        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
-        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
-        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
-        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, None, log_attention_weights=log_attention_weights)
-        
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
-        context_layer = context_layer.view(*new_context_layer_shape)
-        output = attn_module.dense(context_layer)
-        
-        # new mem this layer
-        new_mem = mixed_raw_layer.detach()[..., -(mixed_raw_layer.shape[-1] // 3 * 2):].contiguous()
-            
-        return output, new_mem
 
 class CachedAutoregressiveModel(BaseModel):
     def __init__(self, args, transformer=None):
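
The pack/unpack pair in `attention_fn` above is a pure reshape round trip between per-head tensors and the flat cache layout stored in `mems`. A standalone shape check with made-up sizes:

```python
import torch

b, nh, seq_len, hidden_size = 2, 4, 5, 8            # hidden_size is per attention head
k = torch.randn(b, nh, seq_len, hidden_size)
v = torch.randn(b, nh, seq_len, hidden_size)

# pack: (2, b, nh, seq, d) -> (b, seq, 2, nh, d) -> (b, seq, 2*nh*d)
cache_kv = torch.stack((k, v)).permute(1, 3, 0, 2, 4).contiguous().view(b, seq_len, nh * hidden_size * 2)

# unpack: (b, mem_len, 2*nh*d) -> (2, b, nh, mem_len, d)
mem = cache_kv.reshape(b, seq_len, 2, nh, hidden_size).permute(2, 0, 3, 1, 4)
memk, memv = mem[0], mem[1]
assert torch.equal(memk, k) and torch.equal(memv, v)
```
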
diff --git a/SwissArmyTransformer/model/cuda2d_model.py b/SwissArmyTransformer/model/cuda2d_model.py
index cda027c5d8bcdb46cb0d343f4c0a0338dfccb5e0..4fdc87d3bee12f1b132f25f979d33c2582d36fa8 100644
--- a/SwissArmyTransformer/model/cuda2d_model.py
+++ b/SwissArmyTransformer/model/cuda2d_model.py
@@ -88,7 +88,7 @@ class Cuda2dModel(BaseModel):
         output_1 = dense_plus(context_layer1)
         output = torch.cat((output_0, output_1), dim=1)
         
-        return output, None
+        return output
     
     def disable_untrainable_params(self):
         self.transformer.requires_grad_(False)
diff --git a/SwissArmyTransformer/mpu/transformer.py b/SwissArmyTransformer/mpu/transformer.py
index 229e21abbcc941a8808c4c0a5580ef956311dddb..c553cc21daf4e96c7ce79be41dd08e1cb5cb1c47 100755
--- a/SwissArmyTransformer/mpu/transformer.py
+++ b/SwissArmyTransformer/mpu/transformer.py
@@ -122,7 +122,7 @@ class SelfAttention(torch.nn.Module):
 
     def forward(self, hidden_states, mask, *args, **kw_args):
         if 'attention_forward' in self.hooks:
-            return self.hooks['attention_forward'](hidden_states, mask, **kw_args, layer_id=self.layer_id)
+            return self.hooks['attention_forward'](hidden_states, mask, **kw_args)
         else:
             attention_fn = standard_attention
             if 'attention_fn' in self.hooks:
@@ -139,7 +139,7 @@ class SelfAttention(torch.nn.Module):
             key_layer = self._transpose_for_scores(mixed_key_layer)
             value_layer = self._transpose_for_scores(mixed_value_layer)
 
-            context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, layer_id=self.layer_id, **kw_args)
+            context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kw_args)
 
             context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
             new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
@@ -149,7 +149,7 @@ class SelfAttention(torch.nn.Module):
             if self.training:
                 output = self.output_dropout(output)
 
-            return output, None
+            return output
 
 
 class CrossAttention(torch.nn.Module):
@@ -266,7 +266,7 @@ class MLP(torch.nn.Module):
 
     def forward(self, hidden_states, **kw_args):
         if 'mlp_forward' in self.hooks:
-            output = self.hooks['mlp_forward'](hidden_states, **kw_args, layer_id=self.layer_id)
+            output = self.hooks['mlp_forward'](hidden_states, **kw_args)
         else:
             intermediate_parallel = self.dense_h_to_4h(hidden_states)
             intermediate_parallel = self.activation_func(intermediate_parallel)
@@ -367,7 +367,7 @@ class BaseTransformerLayer(torch.nn.Module):
         # Layer norm at the beginning of the transformer layer.
         layernorm_output1 = self.input_layernorm(hidden_states)
         # Self attention.
-        attention_output, output_this_layer = self.attention(layernorm_output1, mask, **kw_args)
+        attention_output = self.attention(layernorm_output1, mask, **kw_args)
 
         # Third LayerNorm
         if self.sandwich_ln:
@@ -381,7 +381,7 @@ class BaseTransformerLayer(torch.nn.Module):
         if self.is_decoder:
             encoder_outputs = kw_args['encoder_outputs']
             if encoder_outputs is not None:
-                cross_attention_mask=kw_args['cross_attention_mask']
+                cross_attention_mask = kw_args['cross_attention_mask']
                 # Cross attention
                 attention_output = self.cross_attention(layernorm_output, cross_attention_mask, encoder_outputs, **kw_args)
                 # Residual connection.
@@ -399,7 +399,7 @@ class BaseTransformerLayer(torch.nn.Module):
         # Second residual connection.
         output = layernorm_input + mlp_output
 
-        return output, output_this_layer  # temporally, output_this_layer is only from attention
+        return output, kw_args['output_this_layer'], kw_args['output_cross_layer']
 
 
 class BaseTransformer(torch.nn.Module):
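
With this change a layer (or a `layer_forward` hook) may return just the hidden states, a `(hidden_states, output_this_layer)` pair, or a `(hidden_states, output_this_layer, output_cross_layer)` triple; `BaseTransformer.forward` below normalizes all three cases, once inside checkpointing and once outside. A small helper sketch of that normalization (the helper name is hypothetical; the branches mirror the real code):

```python
import torch

def normalize_layer_ret(layer_ret):
    """Normalize a layer's return value to (hidden_states, output_this_layer, output_cross_layer)."""
    if torch.is_tensor(layer_ret):                    # only hidden_states
        return layer_ret, {}, {}
    if len(layer_ret) == 2:                           # hidden_states & output_this_layer
        hidden_states, output_this_layer = layer_ret
        return hidden_states, output_this_layer, {}
    hidden_states, output_this_layer, output_cross_layer = layer_ret
    return hidden_states, output_this_layer, output_cross_layer
```
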
@@ -475,7 +475,7 @@ class BaseTransformer(torch.nn.Module):
         # Final layer norm before output.
         self.final_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)
 
-    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, 
+    def forward(self, input_ids, position_ids, attention_mask, *,
                 output_hidden_states=False, **kw_args):
         # sanity check
         assert len(input_ids.shape) == 2
@@ -486,11 +486,7 @@ class BaseTransformer(torch.nn.Module):
             ) # None means full attention
         assert len(attention_mask.shape) == 2 or \
                len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
-        assert branch_input is None or 'layer_forward' in self.hooks and isinstance(branch_input, torch.Tensor)
-        # branch_input is a new part of input need layer-by-layer update,
-        #   but with different hidden_dim and computational routine.
-        #   In most cases, you can just ignore it.
-
+
         # embedding part
         if 'word_embedding_forward' in self.hooks:
             hidden_states = self.hooks['word_embedding_forward'](input_ids, **kw_args)
@@ -507,64 +503,126 @@ class BaseTransformer(torch.nn.Module):
             hidden_states = hidden_states + position_embeddings
         hidden_states = self.embedding_dropout(hidden_states)
 
-        hidden_states_outputs = [hidden_states] if output_hidden_states else []
-        # branch related embedding
-        if branch_input is None and 'branch_embedding_forward' in self.hooks:
-            branch_input = self.hooks['branch_embedding_forward'](branch_input, **kw_args)
-
-        # define custom_forward for checkpointing
+        # initialize output_cross_layer from the cross_layer_embedding hook, if any
+        if 'cross_layer_embedding_forward' in self.hooks:
+            output_cross_layer = self.hooks['cross_layer_embedding_forward'](hidden_states, **kw_args)
+        else:
+            output_cross_layer = {}
+
         output_per_layers = []
         if self.checkpoint_activations:
-            def custom(start, end):
+            # define custom_forward for checkpointing
+            def custom(start, end, kw_args_index, cross_layer_index):
                 def custom_forward(*inputs):
                     layers_ = self.layers[start:end]
                     x_, mask = inputs[0], inputs[1]
-                    if len(inputs) > 2: # have branch_input
-                        branch_ = inputs[2]
+
+                    # recover kw_args and output_cross_layer
+                    flat_inputs = inputs[2:]
+                    kw_args, output_cross_layer = {}, {}
+                    for k, idx in kw_args_index.items():
+                        kw_args[k] = flat_inputs[idx]
+                    for k, idx in cross_layer_index.items():
+                        output_cross_layer[k] = flat_inputs[idx]
+                    # -----------------
+                    
                     output_per_layers_part = []
                     for i, layer in enumerate(layers_):
-                        if len(inputs) > 2:
-                            x_, branch_, output_this_layer = self.hooks['layer_forward'](
-                                x_, mask, layer_id=layer.layer_id, branch_input=branch_, **kw_args
-                            )
-                        elif 'layer_forward' in self.hooks:
-                            x_, output_this_layer = self.hooks['layer_forward'](
-                                x_, mask, layer_id=layer.layer_id, **kw_args
+                        if 'layer_forward' in self.hooks:
+                            layer_ret = self.hooks['layer_forward'](
+                                x_, mask, layer_id=layer.layer_id, 
+                                **kw_args, **output_cross_layer, 
+                                output_this_layer={}, output_cross_layer={}
                             )
                         else:
-                            x_, output_this_layer = layer(x_, mask, **kw_args)
+                            layer_ret = layer(
+                                x_, mask, layer_id=layer.layer_id, 
+                                **kw_args, **output_cross_layer, 
+                                output_this_layer={}, output_cross_layer={}
+                            )
+                        if torch.is_tensor(layer_ret): # only hidden_states
+                            x_, output_this_layer, output_cross_layer = layer_ret, {}, {}
+                        elif len(layer_ret) == 2: # hidden_states & output_this_layer
+                            x_, output_this_layer = layer_ret
+                            output_cross_layer = {}
+                        elif len(layer_ret) == 3: 
+                            x_, output_this_layer, output_cross_layer = layer_ret
+                        assert isinstance(output_this_layer, dict)
+                        assert isinstance(output_cross_layer, dict)
+                        if output_hidden_states:
+                            output_this_layer['hidden_states'] = x_
                         output_per_layers_part.append(output_this_layer)
-                    return x_, output_per_layers_part
+
+                    # flatten the keyword outputs so they can be re-aggregated after checkpointing
+                    flat_outputs = []
+                    for output_this_layer in output_per_layers_part:
+                        for k in output_this_layer:
+                            # TODO add warning for depth>=2 grad tensors
+                            flat_outputs.append(output_this_layer[k])
+                            output_this_layer[k] = len(flat_outputs) - 1
+                    for k in output_cross_layer:
+                        flat_outputs.append(output_cross_layer[k])
+                        output_cross_layer[k] = len(flat_outputs) - 1
+                    # --------------------
+
+                    return x_, output_per_layers_part, output_cross_layer, flat_outputs
                 return custom_forward
 
             # prevent losing requires_grad in checkpointing.
             # To save memory when finetuning only the final layers, don't use checkpointing.
             if self.training:
                 hidden_states.requires_grad_(True)
-
+
             l, num_layers = 0, len(self.layers)
             chunk_length = self.checkpoint_num_layers
             while l < num_layers:
                 args = [hidden_states, attention_mask]
-                if branch_input is not None:
-                    hidden_states, branch_input, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args, branch_input)
-                else:
-                    hidden_states, output_per_layers_part = checkpoint(custom(l, l + chunk_length), *args)
-                if output_hidden_states:
-                    hidden_states_outputs.append(hidden_states)
+                # flatten kw_args and output_cross_layer
+                flat_inputs, kw_args_index, cross_layer_index = [], {}, {}
+                for k, v in kw_args.items():
+                    flat_inputs.append(v)
+                    kw_args_index[k] = len(flat_inputs) - 1
+                for k, v in output_cross_layer.items():
+                    flat_inputs.append(v)
+                    cross_layer_index[k] = len(flat_inputs) - 1
+                # --------------------
+
+                hidden_states, output_per_layers_part, output_cross_layer, flat_outputs = \
+                    checkpoint(custom(l, l + chunk_length, kw_args_index, cross_layer_index), *args, *flat_inputs)
+
+                # recover output_per_layers_part, output_cross_layer
+                for output_this_layer in output_per_layers_part:
+                    for k in output_this_layer:
+                        output_this_layer[k] = flat_outputs[output_this_layer[k]]
+                for k in output_cross_layer:
+                    output_cross_layer[k] = flat_outputs[output_cross_layer[k]]
+                # --------------------
+
                 output_per_layers.extend(output_per_layers_part)
                 l += chunk_length
         else:
             for i, layer in enumerate(self.layers):
                 args = [hidden_states, attention_mask]
-                if branch_input is not None: # customized layer_forward with branch_input
-                    hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), branch_input=branch_input, **kw_args)
-                elif 'layer_forward' in self.hooks: # customized layer_forward
-                    hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), **kw_args)
+
+                if 'layer_forward' in self.hooks: # customized layer_forward
+                    layer_ret = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), 
+                        **kw_args, 
+                        **output_cross_layer, 
+                        output_this_layer={}, output_cross_layer={}
+                    )
                 else:
-                    hidden_states, output_this_layer = layer(*args, **kw_args)
+                    layer_ret = layer(*args, layer_id=torch.tensor(i), **kw_args, **output_cross_layer, 
+                        output_this_layer={}, output_cross_layer={})
+                if torch.is_tensor(layer_ret): # only hidden_states
+                    hidden_states, output_this_layer, output_cross_layer = layer_ret, {}, {}
+                elif len(layer_ret) == 2: # hidden_states & output_this_layer
+                    hidden_states, output_this_layer = layer_ret
+                    output_cross_layer = {}
+                elif len(layer_ret) == 3: 
+                    hidden_states, output_this_layer, output_cross_layer = layer_ret
+
                 if output_hidden_states:
-                    hidden_states_outputs.append(hidden_states)
+                    output_this_layer['hidden_states'] = hidden_states
                 output_per_layers.append(output_this_layer)
 
         # Final layer norm.
@@ -576,18 +634,10 @@ class BaseTransformer(torch.nn.Module):
             logits_parallel = copy_to_model_parallel_region(logits)
             logits_parallel = F.linear(logits_parallel, self.word_embeddings.weight)
 
-        # branch related embedding
-        if branch_input is None and 'branch_final_forward' in self.hooks:
-            branch_input = self.hooks['branch_final_forward'](branch_input, **kw_args)
-
         if not self.parallel_output:
             logits_parallel = gather_from_model_parallel_region(logits_parallel)
 
         outputs = [logits_parallel]
-        if branch_input is not None:
-            outputs.append(branch_input)
-        if output_hidden_states:
-            outputs.append(hidden_states_outputs)
         outputs.extend(output_per_layers)
 
         return outputs
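
The flatten/recover bookkeeping above exists because activation checkpointing generally only tracks gradients through tensors passed (and returned) positionally, while the hooks want to exchange dict-valued inputs and outputs. A standalone sketch of the same index-map trick, using `torch.utils.checkpoint` to stay self-contained (this file uses the module's own `checkpoint`; `block` and `checkpointed_block` are illustrative names):

```python
import torch
from torch.utils.checkpoint import checkpoint

def checkpointed_block(block, x, kw_tensors):
    """Run block(x, **kw_tensors) -> (out, extras_dict) under checkpointing,
    ferrying the dicts as flat positional tensors plus an index map."""
    flat_inputs, index = [], {}
    for k, v in kw_tensors.items():                   # dict -> positional tensors
        flat_inputs.append(v)
        index[k] = len(flat_inputs) - 1

    def custom_forward(x_, *flat):
        kw = {k: flat[i] for k, i in index.items()}   # recover the input dict
        out, extras = block(x_, **kw)
        flat_out = []                                 # dict values -> flat list, keep indices
        for k in extras:
            flat_out.append(extras[k])
            extras[k] = len(flat_out) - 1
        return out, extras, flat_out

    out, extras, flat_out = checkpoint(custom_forward, x, *flat_inputs, use_reentrant=False)
    for k in extras:                                  # map indices back to tensors
        extras[k] = flat_out[extras[k]]
    return out, extras

# toy usage
def block(x, scale):
    return x * scale, {'stat': x.mean()}

x = torch.randn(2, 3, requires_grad=True)
out, extras = checkpointed_block(block, x, {'scale': torch.tensor(2.0)})
out.sum().backward()
```
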