diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index ddcc8fdeb8f9368d55a4df59f4dd589b4ab0e024..fb7c7cb3d08e8e7d083faccdb3b0ae4551d4ce30 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -177,6 +177,10 @@ class T5DecoderFinalMixin(BaseMixin):
     def __init__(self, vocab_size, hidden_size, tie_word_embeddings=True):
         super().__init__()
         self.hidden_size = hidden_size
+        self.tie_word_embeddings = tie_word_embeddings
+        if not tie_word_embeddings:
+            self.lm_head = VocabParallelEmbedding(
+                vocab_size, hidden_size, init_method=unscaled_init_method(0.02))
 
     def final_forward(self, logits, **kwargs):
         logits_parallel = copy_to_model_parallel_region(logits)