diff --git a/SwissArmyTransformer/model/glm_model.py b/SwissArmyTransformer/model/glm_model.py
index 076ee44c58f526ecf39343f0b127c0c2920e0a17..01aa07554d69cc9caee0479f99efbe5098aa8f3c 100644
--- a/SwissArmyTransformer/model/glm_model.py
+++ b/SwissArmyTransformer/model/glm_model.py
@@ -19,8 +19,9 @@ class BlockPositionEmbeddingMixin(BaseMixin):
         return position_embeddings + block_position_embeddings
 
 class GLMModel(BaseModel):
-    def __init__(self, args, transformer=None):
-        super().__init__(args, transformer=transformer)
+    def __init__(self, args, transformer=None, parallel_output=True):
+        super().__init__(args, transformer=transformer, parallel_output=parallel_output
+        )
         self.add_mixin('block_position_embedding',
             BlockPositionEmbeddingMixin(args.max_sequence_length, args.hidden_size)
         )