diff --git a/SwissArmyTransformer/model/encoder_decoder_model.py b/SwissArmyTransformer/model/encoder_decoder_model.py
index 708bc694492643f2d8787968bb3430d1752a0416..7ea50b1815641e007935aed2cb892c3574c78668 100644
--- a/SwissArmyTransformer/model/encoder_decoder_model.py
+++ b/SwissArmyTransformer/model/encoder_decoder_model.py
@@ -18,17 +18,21 @@ from .common_layers import CrossAttention, LayerNorm
 
 
 class CrossAttentionMixin(BaseMixin):
-    def __init__(self, hidden_size, num_attention_heads,
+    def __init__(self, num_layers, hidden_size, num_attention_heads,
                 attention_dropout_prob, output_dropout_prob,
                 init_method, enc_hidden_size=None, inner_hidden_size=None, output_layer_init_method=None):
         super().__init__()
-        self.cross_attention = CrossAttention(
+        self.cross_attentions = torch.nn.ModuleList(
+            [CrossAttention(
                 hidden_size, num_attention_heads,
                 attention_dropout_prob, output_dropout_prob,
                 init_method, enc_hidden_size=enc_hidden_size, inner_hidden_size=inner_hidden_size, 
                 output_layer_init_method=output_layer_init_method
-            ) # Just copy args
-        self.cross_ln = LayerNorm(hidden_size, 1e-5)
+            ) for layer_id in range(num_layers)]
+        ) # one CrossAttention per decoder layer; the same args are copied for each
+        self.cross_lns = torch.nn.ModuleList(
+            [LayerNorm(hidden_size, 1e-5) for layer_id in range(num_layers)]
+        )
         
 
     def layer_forward(self, hidden_states, mask, layer_id, **kw_args):
@@ -49,8 +53,8 @@ class CrossAttentionMixin(BaseMixin):
         hidden_states = hidden_states + attention_output
 
         # Cross attention.
-        layernorm_output = self.cross_ln(hidden_states)
-        cross_attn_output = self.cross_attention(
+        layernorm_output = self.cross_lns[layer_id](hidden_states)
+        cross_attn_output = self.cross_attentions[layer_id](
             layernorm_output, 
             torch.ones(1, 1, device=hidden_states.device, dtype=hidden_states.dtype), 
             encoder_outputs
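
Usage note (not part of the patch): a minimal sketch of how the changed __init__ signature is exercised, assuming the fixed list comprehensions above. The dimensions and the stub init_method are illustrative assumptions for the sketch, not values taken from SwissArmyTransformer; only the keyword names come from the signature in the diff.

import torch
from SwissArmyTransformer.model.encoder_decoder_model import CrossAttentionMixin

def init_method(tensor):
    # stand-in initializer for this sketch; real models pass their own
    return torch.nn.init.normal_(tensor, mean=0.0, std=0.02)

mixin = CrossAttentionMixin(
    num_layers=12, hidden_size=768, num_attention_heads=12,
    attention_dropout_prob=0.1, output_dropout_prob=0.1,
    init_method=init_method, enc_hidden_size=768,
)

# ModuleList now registers one module per decoder layer, so layer_forward
# can index self.cross_attentions[layer_id] / self.cross_lns[layer_id]
# instead of sharing a single cross-attention block across all layers.
assert len(mixin.cross_attentions) == 12
assert len(mixin.cross_lns) == 12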