From b4eaba16b27e76f4b89889d8fc964d7e90b6b66f Mon Sep 17 00:00:00 2001 From: Zhengxiao Du <zx-du20@mails.tsinghua.edu.cn> Date: Mon, 8 Nov 2021 11:16:05 +0800 Subject: [PATCH] Add parallel_output argument for GLMModel --- SwissArmyTransformer/model/glm_model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/SwissArmyTransformer/model/glm_model.py b/SwissArmyTransformer/model/glm_model.py index 076ee44..01aa075 100644 --- a/SwissArmyTransformer/model/glm_model.py +++ b/SwissArmyTransformer/model/glm_model.py @@ -19,8 +19,9 @@ class BlockPositionEmbeddingMixin(BaseMixin): return position_embeddings + block_position_embeddings class GLMModel(BaseModel): - def __init__(self, args, transformer=None): - super().__init__(args, transformer=transformer) + def __init__(self, args, transformer=None, parallel_output=True): + super().__init__(args, transformer=transformer, parallel_output=parallel_output + ) self.add_mixin('block_position_embedding', BlockPositionEmbeddingMixin(args.max_sequence_length, args.hidden_size) ) -- GitLab