From 91afa917ed43cee124a609b469b4f6c005d1b5c7 Mon Sep 17 00:00:00 2001 From: duzx16 <zx-du20@mails.tsinghua.edu.cn> Date: Sun, 12 Dec 2021 22:53:55 +0800 Subject: [PATCH] Add no share embeddings for T5 --- SwissArmyTransformer/model/t5_model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py index ddcc8fd..fb7c7cb 100644 --- a/SwissArmyTransformer/model/t5_model.py +++ b/SwissArmyTransformer/model/t5_model.py @@ -177,6 +177,10 @@ class T5DecoderFinalMixin(BaseMixin): def __init__(self, vocab_size, hidden_size, tie_word_embeddings=True): super().__init__() self.hidden_size = hidden_size + self.tie_word_embeddings = tie_word_embeddings + if not tie_word_embeddings: + self.lm_head = VocabParallelEmbedding( + vocab_size, hidden_size, init_method=unscaled_init_method(0.02)) def final_forward(self, logits, **kwargs): logits_parallel = copy_to_model_parallel_region(logits) -- GitLab