diff --git a/examples/t5/inference_t5.py b/examples/t5/inference_t5.py
index 38a008a8e8ead6ff4a0f4fda3938333ffebaf9df..b902f248e0c6ebe6dfccba3805d4112923eff6fa 100644
--- a/examples/t5/inference_t5.py
+++ b/examples/t5/inference_t5.py
@@ -29,6 +29,34 @@
 from SwissArmyTransformer.generation.utils import timed_name, generate_continually
 from SwissArmyTransformer.training.deepspeed_training import setup_model_and_optimizer
 
 
+def decoder_shift_right(input_ids, args):
+    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+    shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+    shifted_input_ids[..., 0] = args.decoder_start_token_id
+    return shifted_input_ids
+
+
+def get_batch(data, args):
+    keys = ['text', 'loss_mask', 'target', 'attention_mask']
+    datatype = torch.int64
+
+    # Broadcast data.
+    data_b = mpu.broadcast_data(keys, data, datatype)
+    # Unpack.
+    tokens = data_b['text'].long()
+    labels = data_b['target'].long()
+    decoder_tokens = decoder_shift_right(labels, args)
+    attention_mask = data_b['attention_mask'].long()
+    loss_mask = data_b['loss_mask'].float()
+
+    # Convert
+    if args.fp16:
+        attention_mask = attention_mask.half()
+    elif args.bf16:
+        attention_mask = attention_mask.bfloat16()
+    return tokens, decoder_tokens, labels, loss_mask, attention_mask
+
+
 def get_masks_and_position_ids_glm(seq, mask_position, context_length):
     tokens = seq.unsqueeze(0)
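
Note (not part of the patch): the added `decoder_shift_right` builds the teacher-forcing decoder input for T5 by prepending the decoder start token and dropping the final label. Below is a minimal, self-contained sketch of that behavior on a toy batch; the token ids are arbitrary and `decoder_start_token_id=0` is only an assumption (it matches the usual T5 convention), not something taken from this diff.

```python
import torch
from types import SimpleNamespace


def decoder_shift_right(input_ids, args):
    # Same logic as the patched helper: position t of the decoder input
    # is label t-1, and position 0 is the decoder start token.
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
    shifted_input_ids[..., 0] = args.decoder_start_token_id
    return shifted_input_ids


# Hypothetical values for illustration only.
args = SimpleNamespace(decoder_start_token_id=0)
labels = torch.tensor([[32099, 8, 1782, 32098, 1]])
print(decoder_shift_right(labels, args))
# tensor([[    0, 32099,     8,  1782, 32098]])
```

With this shift, the decoder sees the start token plus all labels except the last, so at every step it is trained to predict the corresponding entry of `labels`; `get_batch` then returns this shifted tensor alongside the original labels and the loss mask.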