diff --git a/SwissArmyTransformer/arguments.py b/SwissArmyTransformer/arguments.py
index 5eafe851c4d6064d802532c2b822efe4cefc42a1..804bae08edbd705097b8fb2828443b87066e8a9c 100755
--- a/SwissArmyTransformer/arguments.py
+++ b/SwissArmyTransformer/arguments.py
@@ -300,25 +300,28 @@ def get_args(args_list=None):
         print('using world size: {} and model-parallel size: {} '.format(
             args.world_size, args.model_parallel_size))
 
-    if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_config is not None:
-        with open(args.deepspeed_config) as file:
-            deepspeed_config = json.load(file)
-        if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]:
-            args.fp16 = True
-        else:
-            args.fp16 = False
+    if hasattr(args, "deepspeed") and args.deepspeed:
         if args.checkpoint_activations:
             args.deepspeed_activation_checkpointing = True
-        if "train_micro_batch_size_per_gpu" in deepspeed_config:
-            args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
-        if "gradient_accumulation_steps" in deepspeed_config:
-            args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"]
         else:
-            args.gradient_accumulation_steps = None
-        if "optimizer" in deepspeed_config:
-            optimizer_params_config = deepspeed_config["optimizer"].get("params", {})
-            args.lr = optimizer_params_config.get("lr", args.lr)
-            args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay)
+            args.deepspeed_activation_checkpointing = False
+        if args.deepspeed_config is not None:
+            with open(args.deepspeed_config) as file:
+                deepspeed_config = json.load(file)
+            if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]:
+                args.fp16 = True
+            else:
+                args.fp16 = False
+            if "train_micro_batch_size_per_gpu" in deepspeed_config:
+                args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
+            if "gradient_accumulation_steps" in deepspeed_config:
+                args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"]
+            else:
+                args.gradient_accumulation_steps = None
+            if "optimizer" in deepspeed_config:
+                optimizer_params_config = deepspeed_config["optimizer"].get("params", {})
+                args.lr = optimizer_params_config.get("lr", args.lr)
+                args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay)
     return args
diff --git a/SwissArmyTransformer/model/encoder_decoder_model.py b/SwissArmyTransformer/model/encoder_decoder_model.py
index 0c6336c66eb887792e50872d9af768445c5ce030..341a800083d5a317dfc55cde1592915a8296c4ca 100644
--- a/SwissArmyTransformer/model/encoder_decoder_model.py
+++ b/SwissArmyTransformer/model/encoder_decoder_model.py
@@ -64,11 +64,22 @@ class EncoderDecoderModel(torch.nn.Module):
         return encoder_outputs
 
     def decode(self, input_ids, position_ids, attention_mask, encoder_outputs,cross_attention_mask=None, **kw_args):
+        if attention_mask is None:
+            batch_size, seq_length = input_ids.size()[:2]
+            seq_ids = torch.arange(seq_length, device=input_ids.device)
+            attention_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+            attention_mask = attention_mask.to(self.decoder.transformer.word_embeddings.weight.dtype)
+            attention_mask = attention_mask[:, None, :, :]
         # If no context, please explicitly pass ``encoder_outputs=None''
         return self.decoder(input_ids, position_ids, attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args)
 
-    def forward(self, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids,dec_attention_mask, *, enc_attention_mask=None, cross_attention_mask=None, **kw_args):
+    def forward(self, enc_input_ids, enc_position_ids, dec_input_ids, dec_position_ids, *, enc_attention_mask=None, dec_attention_mask=None, cross_attention_mask=None, **kw_args):
         # Please use self.decoder for auto-regressive generation.
+        batch_size, seq_length = enc_input_ids.size()[:2]
+        if enc_attention_mask is None:
+            enc_attention_mask = torch.ones(1, 1, 1, seq_length, dtype=self.encoder.transformer.word_embeddings.weight.dtype, device=enc_input_ids.device)
+        if cross_attention_mask is None:
+            cross_attention_mask = enc_attention_mask
         encoder_outputs = self.encode(enc_input_ids, enc_position_ids, enc_attention_mask, **kw_args)
         decoder_outputs, *mems = self.decode(dec_input_ids, dec_position_ids, dec_attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args)
         return encoder_outputs, decoder_outputs, *mems
diff --git a/SwissArmyTransformer/model/t5_model.py b/SwissArmyTransformer/model/t5_model.py
index 944a133371fe591e8f20c6f47c9350040df519a5..531ec54b0062fc616d46a42652a8d486a4659e33 100644
--- a/SwissArmyTransformer/model/t5_model.py
+++ b/SwissArmyTransformer/model/t5_model.py
@@ -180,7 +180,7 @@ class T5DecoderFinalMixin(BaseMixin):
 
 class T5Model(EncoderDecoderModel):
     def __init__(self, args, **kwargs):
-        super().__init__(args, tie_word_embeddings=True, **kwargs, use_bias=False, 
+        super().__init__(args, tie_word_embeddings=True, **kwargs, use_bias=False,
                          layernorm=T5LayerNorm, activation_func=torch.nn.functional.relu)
         self.encoder.add_mixin(
             "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, args.num_attention_heads)
@@ -205,3 +205,25 @@ class T5Model(EncoderDecoderModel):
     def add_model_specific_args(cls, parser):
         super().add_model_specific_args(parser)
         parser.add_argument("--relative-attention-num-buckets", type=int, default=None)
+
+    def encode(self, input_ids, attention_mask=None, **kw_args):
+        return super().encode(input_ids, None, attention_mask, **kw_args)
+
+    def decode(self, input_ids, attention_mask=None, encoder_outputs=None, cross_attention_mask=None, **kw_args):
+        return super().decode(input_ids, None, attention_mask, encoder_outputs=encoder_outputs,
+                              cross_attention_mask=cross_attention_mask, **kw_args)
+
+    def forward(self, enc_input_ids, dec_input_ids, *, enc_attention_mask=None,
+                dec_attention_mask=None, cross_attention_mask=None, **kw_args):
+        batch_size, seq_length = enc_input_ids.size()[:2]
+        if enc_attention_mask is None:
+            enc_attention_mask = torch.ones(1, 1, 1, seq_length,
+                                            dtype=self.encoder.transformer.word_embeddings.weight.dtype,
+                                            device=enc_input_ids.device)
+        if cross_attention_mask is None:
+            cross_attention_mask = enc_attention_mask
+        encoder_outputs = self.encode(enc_input_ids, enc_attention_mask, **kw_args)
+        decoder_outputs, *mems = self.decode(dec_input_ids, dec_attention_mask,
+                                             encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask,
+                                             **kw_args)
+        return encoder_outputs, decoder_outputs, *mems
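A minimal standalone sketch of the default masks these hunks construct, for reference only (not part of the patch; it assumes only `torch`, and the `dtype` placeholder stands in for `word_embeddings.weight.dtype`):

    import torch

    batch_size, seq_length = 2, 5
    dtype = torch.float32  # placeholder for word_embeddings.weight.dtype

    # Causal mask built in decode() when attention_mask is None:
    # entry [b, i, j] is 1 iff position j may be attended from position i (j <= i).
    seq_ids = torch.arange(seq_length)
    causal = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
    causal = causal.to(dtype)[:, None, :, :]             # shape (batch, 1, seq, seq)

    # Encoder / cross-attention mask built in forward() when none is given:
    # an all-ones mask that broadcasts over batch and heads.
    full = torch.ones(1, 1, 1, seq_length, dtype=dtype)  # shape (1, 1, 1, seq)

    print(causal.shape, full.shape)  # torch.Size([2, 1, 5, 5]) torch.Size([1, 1, 1, 5])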