From fdcf47352a97f7e793e53ecdb83e2a2e69259805 Mon Sep 17 00:00:00 2001
From: duzx16 <zx-du20@mails.tsinghua.edu.cn>
Date: Fri, 10 Dec 2021 11:24:52 +0800
Subject: [PATCH] Update script for backward test

---
 examples/t5/inference_t5.py | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/examples/t5/inference_t5.py b/examples/t5/inference_t5.py
index 38a008a..b902f24 100644
--- a/examples/t5/inference_t5.py
+++ b/examples/t5/inference_t5.py
@@ -29,6 +29,34 @@ from SwissArmyTransformer.generation.utils import timed_name, generate_continual
 from SwissArmyTransformer.training.deepspeed_training import setup_model_and_optimizer
 
 
+def decoder_shift_right(input_ids, args):
+    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+    shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+    shifted_input_ids[..., 0] = args.decoder_start_token_id
+    return shifted_input_ids
+
+
+def get_batch(data, args):
+    keys = ['text', 'loss_mask', 'target', 'attention_mask']
+    datatype = torch.int64
+
+    # Broadcast data.
+    data_b = mpu.broadcast_data(keys, data, datatype)
+    # Unpack.
+    tokens = data_b['text'].long()
+    labels = data_b['target'].long()
+    decoder_tokens = decoder_shift_right(labels, args)
+    attention_mask = data_b['attention_mask'].long()
+    loss_mask = data_b['loss_mask'].float()
+
+    # Convert
+    if args.fp16:
+        attention_mask = attention_mask.half()
+    elif args.bf16:
+        attention_mask = attention_mask.bfloat16()
+    return tokens, decoder_tokens, labels, loss_mask, attention_mask
+
+
 def get_masks_and_position_ids_glm(seq, mask_position, context_length):
     tokens = seq.unsqueeze(0)
 
--
GitLab
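
Note on the change: `decoder_shift_right` follows the standard T5-style teacher-forcing convention, where the decoder input is the target sequence shifted one position to the right with `args.decoder_start_token_id` placed in slot 0. A minimal standalone sketch of that behavior is below; the `SimpleNamespace` stand-in for `args`, the start-token id of 0 (T5 conventionally reuses the pad id), and the sample token ids are illustrative assumptions, not taken from the patch.

    import torch
    from types import SimpleNamespace


    def decoder_shift_right(input_ids, args):
        # Same logic as in the patch: move every label one slot to the
        # right and place the decoder start token at position 0.
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = args.decoder_start_token_id
        return shifted_input_ids


    # Hypothetical args object; 0 mirrors T5's use of the pad id as the
    # decoder start token.
    args = SimpleNamespace(decoder_start_token_id=0)
    labels = torch.tensor([[284, 1176, 1]])  # e.g. "... </s>"
    print(decoder_shift_right(labels, args))  # tensor([[   0,  284, 1176]])

On labels [[284, 1176, 1]] this yields [[0, 284, 1176]], so at step t the decoder conditions on the gold token from step t-1 while the loss is still computed against the unshifted labels.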