diff --git a/SwissArmyTransformer/model/encoder_decoder_model.py b/SwissArmyTransformer/model/encoder_decoder_model.py
index dbc7cbcd3f469b0ac7563395ba8a784c2b04dabb..e70245368e68e6d43ee06d97a959868194b712d6 100644
--- a/SwissArmyTransformer/model/encoder_decoder_model.py
+++ b/SwissArmyTransformer/model/encoder_decoder_model.py
@@ -29,11 +29,11 @@ class CrossAttentionMixin(BaseMixin):
                 attention_dropout_prob, output_dropout_prob, init_method,
                 enc_hidden_size=enc_hidden_size, inner_hidden_size=inner_hidden_size,
                 output_layer_init_method=output_layer_init_method
-                )] for layer_id in range(num_layers)
+                ) for layer_id in range(num_layers)]
             )  # Just copy args
         self.cross_lns = torch.nn.ModuleList(
-            [LayerNorm(hidden_size, 1e-5)]
-            for layer_id in range(num_layers)
+            [LayerNorm(hidden_size, 1e-5)
+             for layer_id in range(num_layers)]
         )
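Note on the fix: in both spots the closing `]` was misplaced, so each expression was a generator yielding one-element *lists* of modules rather than a list comprehension over modules. `torch.nn.ModuleList` registers each item via `add_module`, which requires an `nn.Module` instance, so construction raised a `TypeError`. Below is a minimal standalone sketch of the before/after patterns; it uses `torch.nn.LayerNorm` as a stand-in for the repo's own `LayerNorm`, and the names (`n`, `broken`, `fixed`) are illustrative only:

```python
import torch
from torch import nn

n = 3

# Pre-fix pattern: the stray `]` makes this a generator of one-element
# lists, so ModuleList tries to register a `list`, not an nn.Module.
try:
    broken = nn.ModuleList([nn.LayerNorm(8)] for _ in range(n))
except TypeError as e:
    print(e)  # e.g. "list is not a Module subclass"

# Post-fix pattern: a list comprehension of modules, which ModuleList
# registers as submodules (visible to .parameters(), .to(), etc.).
fixed = nn.ModuleList([nn.LayerNorm(8) for _ in range(n)])
print(len(fixed))  # 3
```

The same reasoning applies to the `cross_attentions` hunk: moving `]` after `range(num_layers)` turns the expression into a single list of `CrossAttention` modules, which is what `ModuleList` expects.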