From 2721c2faeff7d1a85e52ad072b9a1442d347fcbc Mon Sep 17 00:00:00 2001
From: Ming Ding <dm_thu@qq.com>
Date: Thu, 25 Nov 2021 16:21:45 +0000
Subject: [PATCH] fix shared crossattn bug

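Previously CrossAttentionMixin built a single CrossAttention module and a single
LayerNorm, so every decoder layer reused the same cross-attention weights. This
change allocates one CrossAttention and one LayerNorm per layer (indexed by
layer_id in layer_forward), which requires passing num_layers to the mixin.

A minimal usage sketch, assuming init_method is an ordinary tensor-initializer
(the exact callable is not shown in this patch; argument values are illustrative):

    import torch
    from SwissArmyTransformer.model.encoder_decoder_model import CrossAttentionMixin

    def init_method(tensor):
        # assumed weight initializer passed through to CrossAttention
        return torch.nn.init.normal_(tensor, mean=0.0, std=0.02)

    mixin = CrossAttentionMixin(
        num_layers=6, hidden_size=1024, num_attention_heads=16,
        attention_dropout_prob=0.1, output_dropout_prob=0.1,
        init_method=init_method,
    )
    # after the fix, each layer owns its own cross-attention parameters
    assert mixin.cross_attentions[0] is not mixin.cross_attentions[1]
    assert mixin.cross_lns[0] is not mixin.cross_lns[1]
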
---
 .../model/encoder_decoder_model.py             | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/SwissArmyTransformer/model/encoder_decoder_model.py b/SwissArmyTransformer/model/encoder_decoder_model.py
index 708bc69..7ea50b1 100644
--- a/SwissArmyTransformer/model/encoder_decoder_model.py
+++ b/SwissArmyTransformer/model/encoder_decoder_model.py
@@ -18,17 +18,23 @@ from .common_layers import CrossAttention, LayerNorm
 
 
 class CrossAttentionMixin(BaseMixin):
-    def __init__(self, hidden_size, num_attention_heads,
+    def __init__(self, num_layers, hidden_size, num_attention_heads,
                 attention_dropout_prob, output_dropout_prob,
                 init_method, enc_hidden_size=None, inner_hidden_size=None, output_layer_init_method=None):
         super().__init__()
-        self.cross_attention = CrossAttention(
+
+        self.cross_attentions = torch.nn.ModuleList(
+            [CrossAttention(
                 hidden_size, num_attention_heads,
                 attention_dropout_prob, output_dropout_prob,
                 init_method, enc_hidden_size=enc_hidden_size, inner_hidden_size=inner_hidden_size, 
                 output_layer_init_method=output_layer_init_method
-            ) # Just copy args
-        self.cross_ln = LayerNorm(hidden_size, 1e-5)
+            ) for layer_id in range(num_layers)]
+        ) # Just copy args
+        self.cross_lns = torch.nn.ModuleList(
+            [LayerNorm(hidden_size, 1e-5)
+            for layer_id in range(num_layers)]
+        )
         
 
     def layer_forward(self, hidden_states, mask, layer_id, **kw_args):
@@ -49,8 +55,8 @@ class CrossAttentionMixin(BaseMixin):
         hidden_states = hidden_states + attention_output
 
         # Cross attention.
-        layernorm_output = self.cross_ln(hidden_states)
-        cross_attn_output = self.cross_attention(
+        layernorm_output = self.cross_lns[layer_id](hidden_states)
+        cross_attn_output = self.cross_attentions[layer_id](
             layernorm_output, 
             torch.ones(1, 1, device=hidden_states.device, dtype=hidden_states.dtype), 
             encoder_outputs
-- 
GitLab