From 91afa917ed43cee124a609b469b4f6c005d1b5c7 Mon Sep 17 00:00:00 2001
From: duzx16 <zx-du20@mails.tsinghua.edu.cn>
Date: Sun, 12 Dec 2021 22:53:55 +0800
Subject: [PATCH] Add support for untied (non-shared) word embeddings for T5

---
 SwissArmyTransformer/model/t5_model.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index ddcc8fd..fb7c7cb 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -177,6 +177,10 @@ class T5DecoderFinalMixin(BaseMixin):
     def __init__(self, vocab_size, hidden_size, tie_word_embeddings=True):
         super().__init__()
         self.hidden_size = hidden_size
+        self.tie_word_embeddings = tie_word_embeddings
+        if not tie_word_embeddings:
+            self.lm_head = VocabParallelEmbedding(
+                vocab_size, hidden_size, init_method=unscaled_init_method(0.02))
 
     def final_forward(self, logits, **kwargs):
         logits_parallel = copy_to_model_parallel_region(logits)
-- 
GitLab