Commit 2721c2fa authored by Ming Ding

fix shared crossattn bug

parent 233d9495
@@ -18,17 +18,23 @@ from .common_layers import CrossAttention, LayerNorm
 class CrossAttentionMixin(BaseMixin):
-    def __init__(self, hidden_size, num_attention_heads,
+    def __init__(self, num_layers, hidden_size, num_attention_heads,
             attention_dropout_prob, output_dropout_prob,
             init_method, enc_hidden_size=None, inner_hidden_size=None, output_layer_init_method=None):
         super().__init__()
-        self.cross_attention = CrossAttention(
-            hidden_size, num_attention_heads,
-            attention_dropout_prob, output_dropout_prob,
-            init_method, enc_hidden_size=enc_hidden_size, inner_hidden_size=inner_hidden_size,
-            output_layer_init_method=output_layer_init_method
-        ) # Just copy args
-        self.cross_ln = LayerNorm(hidden_size, 1e-5)
+        self.cross_attentions = torch.nn.ModuleList(
+            [CrossAttention(
+                hidden_size, num_attention_heads,
+                attention_dropout_prob, output_dropout_prob,
+                init_method, enc_hidden_size=enc_hidden_size, inner_hidden_size=inner_hidden_size,
+                output_layer_init_method=output_layer_init_method
+            ) for layer_id in range(num_layers)]
+        ) # Just copy args
+        self.cross_lns = torch.nn.ModuleList(
+            [LayerNorm(hidden_size, 1e-5)
+             for layer_id in range(num_layers)]
+        )
 
     def layer_forward(self, hidden_states, mask, layer_id, **kw_args):
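The bug being fixed: before this hunk, the mixin held a single CrossAttention and a single LayerNorm, so every transformer layer reused the same weights instead of having its own cross-attention parameters. Below is a minimal sketch of the difference (my own illustration, not code from this repository) using plain nn.Linear modules to stand in for CrossAttention:

    # Pre-fix pattern: one shared module -> one set of weights for all layers.
    # Post-fix pattern: a ModuleList -> independent weights per layer.
    import torch.nn as nn

    num_layers, hidden_size = 4, 64

    shared = nn.Linear(hidden_size, hidden_size)
    shared_params = sum(p.numel() for p in shared.parameters())

    per_layer = nn.ModuleList(
        [nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)]
    )
    per_layer_params = sum(p.numel() for p in per_layer.parameters())

    print(shared_params)     # 4160  -> the same weights are applied at every layer
    print(per_layer_params)  # 16640 -> each of the 4 layers gets its own weights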
@@ -49,8 +55,8 @@ class CrossAttentionMixin(BaseMixin):
         hidden_states = hidden_states + attention_output
         # Cross attention.
-        layernorm_output = self.cross_ln(hidden_states)
-        cross_attn_output = self.cross_attention(
+        layernorm_output = self.cross_lns[layer_id](hidden_states)
+        cross_attn_output = self.cross_attentions[layer_id](
             layernorm_output,
             torch.ones(1, 1, device=hidden_states.device, dtype=hidden_states.dtype),
             encoder_outputs
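The second hunk switches the forward pass to index the new ModuleLists by layer_id, so each layer runs through its own cross-attention and LayerNorm. A self-contained sketch of that per-layer indexing pattern (a hypothetical ToyDecoder using torch.nn.MultiheadAttention, not the SwissArmyTransformer API):

    import torch
    import torch.nn as nn

    class ToyDecoder(nn.Module):
        def __init__(self, num_layers, hidden_size, num_heads):
            super().__init__()
            # One cross-attention and one pre-attention LayerNorm per layer.
            self.cross_attentions = nn.ModuleList(
                [nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
                 for _ in range(num_layers)]
            )
            self.cross_lns = nn.ModuleList(
                [nn.LayerNorm(hidden_size, eps=1e-5) for _ in range(num_layers)]
            )

        def forward(self, hidden_states, encoder_outputs):
            for layer_id in range(len(self.cross_attentions)):
                # Each layer uses its own modules, selected by layer_id.
                layernorm_output = self.cross_lns[layer_id](hidden_states)
                cross_attn_output, _ = self.cross_attentions[layer_id](
                    layernorm_output, encoder_outputs, encoder_outputs
                )
                hidden_states = hidden_states + cross_attn_output
            return hidden_states

    x = torch.randn(2, 5, 64)    # (batch, target_len, hidden)
    enc = torch.randn(2, 7, 64)  # (batch, source_len, hidden)
    print(ToyDecoder(num_layers=3, hidden_size=64, num_heads=4)(x, enc).shape)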