From b4eaba16b27e76f4b89889d8fc964d7e90b6b66f Mon Sep 17 00:00:00 2001
From: Zhengxiao Du <zx-du20@mails.tsinghua.edu.cn>
Date: Mon, 8 Nov 2021 11:16:05 +0800
Subject: [PATCH] Add parallel_output argument for GLMModel

---
 SwissArmyTransformer/model/glm_model.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/SwissArmyTransformer/model/glm_model.py b/SwissArmyTransformer/model/glm_model.py
index 076ee44..01aa075 100644
--- a/SwissArmyTransformer/model/glm_model.py
+++ b/SwissArmyTransformer/model/glm_model.py
@@ -19,8 +19,9 @@ class BlockPositionEmbeddingMixin(BaseMixin):
         return position_embeddings + block_position_embeddings
 
 class GLMModel(BaseModel):
-    def __init__(self, args, transformer=None):
-        super().__init__(args, transformer=transformer)
+    def __init__(self, args, transformer=None, parallel_output=True):
+        super().__init__(args, transformer=transformer, parallel_output=parallel_output
+                         )
         self.add_mixin('block_position_embedding', 
             BlockPositionEmbeddingMixin(args.max_sequence_length, args.hidden_size)
         )
-- 
GitLab