From 2721c2faeff7d1a85e52ad072b9a1442d347fcbc Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Thu, 25 Nov 2021 16:21:45 +0000
Subject: [PATCH] fix shared crossattn bug

---
 .../model/encoder_decoder_model.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/SwissArmyTransformer/model/encoder_decoder_model.py b/SwissArmyTransformer/model/encoder_decoder_model.py
index 708bc69..7ea50b1 100644
--- a/SwissArmyTransformer/model/encoder_decoder_model.py
+++ b/SwissArmyTransformer/model/encoder_decoder_model.py
@@ -18,17 +18,23 @@ from .common_layers import CrossAttention, LayerNorm
 
 
 class CrossAttentionMixin(BaseMixin):
-    def __init__(self, hidden_size, num_attention_heads,
+    def __init__(self, num_layers, hidden_size, num_attention_heads,
                  attention_dropout_prob, output_dropout_prob,
                  init_method, enc_hidden_size=None, inner_hidden_size=None,
                  output_layer_init_method=None):
         super().__init__()
-        self.cross_attention = CrossAttention(
+
+        self.cross_attentions = torch.nn.ModuleList(
+            [CrossAttention(
             hidden_size, num_attention_heads,
             attention_dropout_prob, output_dropout_prob,
             init_method, enc_hidden_size=enc_hidden_size, inner_hidden_size=inner_hidden_size,
             output_layer_init_method=output_layer_init_method
-        ) # Just copy args
-        self.cross_ln = LayerNorm(hidden_size, 1e-5)
+            ) for layer_id in range(num_layers)]
+        ) # Just copy args
+        self.cross_lns = torch.nn.ModuleList(
+            [LayerNorm(hidden_size, 1e-5)
+             for layer_id in range(num_layers)]
+        )
 
     def layer_forward(self, hidden_states, mask, layer_id, **kw_args):
@@ -49,8 +55,8 @@ class CrossAttentionMixin(BaseMixin):
             hidden_states = hidden_states + attention_output
 
         # Cross attention.
-        layernorm_output = self.cross_ln(hidden_states)
-        cross_attn_output = self.cross_attention(
+        layernorm_output = self.cross_lns[layer_id](hidden_states)
+        cross_attn_output = self.cross_attentions[layer_id](
             layernorm_output,
             torch.ones(1, 1, device=hidden_states.device, dtype=hidden_states.dtype),
             encoder_outputs
-- 
GitLab