diff --git a/SwissArmyTransformer/model/encoder_decoder_model.py b/SwissArmyTransformer/model/encoder_decoder_model.py
index 7ea50b1815641e007935aed2cb892c3574c78668..dbc7cbcd3f469b0ac7563395ba8a784c2b04dabb 100644
--- a/SwissArmyTransformer/model/encoder_decoder_model.py
+++ b/SwissArmyTransformer/model/encoder_decoder_model.py
@@ -92,6 +92,7 @@ class DecoderModel(BaseModel):
         super().__init__(dec_args, transformer=transformer)
         self.add_mixin('cross_attention',
             CrossAttentionMixin(
+                dec_args.num_layers,
                 dec_args.hidden_size, dec_args.num_attention_heads,
                 dec_args.attention_dropout, dec_args.hidden_dropout,
                 self.transformer.init_method,