diff --git a/SwissArmyTransformer/model/encoder_decoder_model.py b/SwissArmyTransformer/model/encoder_decoder_model.py index 7ea50b1815641e007935aed2cb892c3574c78668..dbc7cbcd3f469b0ac7563395ba8a784c2b04dabb 100644 --- a/SwissArmyTransformer/model/encoder_decoder_model.py +++ b/SwissArmyTransformer/model/encoder_decoder_model.py @@ -92,6 +92,7 @@ class DecoderModel(BaseModel): super().__init__(dec_args, transformer=transformer) self.add_mixin('cross_attention', CrossAttentionMixin( + dec_args.num_layers, dec_args.hidden_size, dec_args.num_attention_heads, dec_args.attention_dropout, dec_args.hidden_dropout, self.transformer.init_method,