diff --git a/SwissArmyTransformer/model/encoder_decoder_model.py b/SwissArmyTransformer/model/encoder_decoder_model.py
index 708bc694492643f2d8787968bb3430d1752a0416..7ea50b1815641e007935aed2cb892c3574c78668 100644
--- a/SwissArmyTransformer/model/encoder_decoder_model.py
+++ b/SwissArmyTransformer/model/encoder_decoder_model.py
@@ -18,17 +18,23 @@ from .common_layers import CrossAttention, LayerNorm
 
 
 class CrossAttentionMixin(BaseMixin):
-    def __init__(self, hidden_size, num_attention_heads,
+    def __init__(self, num_layers, hidden_size, num_attention_heads,
                  attention_dropout_prob, output_dropout_prob,
                  init_method, enc_hidden_size=None, inner_hidden_size=None, output_layer_init_method=None):
         super().__init__()
-        self.cross_attention = CrossAttention(
+
+        self.cross_attentions = torch.nn.ModuleList(
+            [CrossAttention(
                 hidden_size, num_attention_heads,
                 attention_dropout_prob, output_dropout_prob,
                 init_method, enc_hidden_size=enc_hidden_size, inner_hidden_size=inner_hidden_size,
                 output_layer_init_method=output_layer_init_method
-        ) # Just copy args
-        self.cross_ln = LayerNorm(hidden_size, 1e-5)
+            ) for layer_id in range(num_layers)]
+        ) # Just copy args
+        self.cross_lns = torch.nn.ModuleList(
+            [LayerNorm(hidden_size, 1e-5)
+             for layer_id in range(num_layers)]
+        )
 
 
     def layer_forward(self, hidden_states, mask, layer_id, **kw_args):
@@ -49,8 +55,8 @@ class CrossAttentionMixin(BaseMixin):
         hidden_states = hidden_states + attention_output
 
         # Cross attention.
-        layernorm_output = self.cross_ln(hidden_states)
-        cross_attn_output = self.cross_attention(
+        layernorm_output = self.cross_lns[layer_id](hidden_states)
+        cross_attn_output = self.cross_attentions[layer_id](
             layernorm_output,
             torch.ones(1, 1, device=hidden_states.device, dtype=hidden_states.dtype),
             encoder_outputs
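
For context, a minimal standalone sketch of the per-layer pattern this change introduces: one cross-attention block and one LayerNorm per decoder layer, stored in torch.nn.ModuleList containers and selected by layer_id inside layer_forward. The names below (PerLayerCrossAttentionSketch, and torch.nn.MultiheadAttention standing in for the repository's CrossAttention) are illustrative assumptions, not SwissArmyTransformer's actual classes.

# Minimal sketch (not the repository's API): per-layer cross-attention and
# LayerNorm held in ModuleLists and indexed by layer_id, mirroring
# self.cross_attentions[layer_id] / self.cross_lns[layer_id] in the diff.
# torch.nn.MultiheadAttention stands in for the repo's CrossAttention here.
import torch

class PerLayerCrossAttentionSketch(torch.nn.Module):
    def __init__(self, num_layers, hidden_size, num_attention_heads):
        super().__init__()
        # One cross-attention module per decoder layer.
        self.cross_attentions = torch.nn.ModuleList(
            [torch.nn.MultiheadAttention(hidden_size, num_attention_heads, batch_first=True)
             for _ in range(num_layers)]
        )
        # One LayerNorm per decoder layer.
        self.cross_lns = torch.nn.ModuleList(
            [torch.nn.LayerNorm(hidden_size, eps=1e-5) for _ in range(num_layers)]
        )

    def layer_forward(self, hidden_states, encoder_outputs, layer_id):
        # Each layer selects its own norm and cross-attention by layer_id.
        normed = self.cross_lns[layer_id](hidden_states)
        attn_out, _ = self.cross_attentions[layer_id](normed, encoder_outputs, encoder_outputs)
        return hidden_states + attn_out

# Usage: every decoder layer passes its own layer_id.
mixin = PerLayerCrossAttentionSketch(num_layers=2, hidden_size=64, num_attention_heads=4)
decoder_states = torch.randn(1, 5, 64)
encoder_states = torch.randn(1, 7, 64)
out = mixin.layer_forward(decoder_states, encoder_states, layer_id=0)

Compared with the previous single shared cross_attention/cross_ln pair, separate modules per layer give each decoder layer its own cross-attention parameters, which is why layer_forward now indexes both ModuleLists by layer_id.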