diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index ddcc8fdeb8f9368d55a4df59f4dd589b4ab0e024..fb7c7cb3d08e8e7d083faccdb3b0ae4551d4ce30 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -177,6 +177,10 @@ class T5DecoderFinalMixin(BaseMixin):
     def __init__(self, vocab_size, hidden_size, tie_word_embeddings=True):
         super().__init__()
         self.hidden_size = hidden_size
+        self.tie_word_embeddings = tie_word_embeddings
+        if not tie_word_embeddings:
+            self.lm_head = VocabParallelEmbedding(
+                vocab_size, hidden_size, init_method=unscaled_init_method(0.02))
 
     def final_forward(self, logits, **kwargs):
         logits_parallel = copy_to_model_parallel_region(logits)