Commit fdcf4735 authored by duzx16

Update script for backward test

parent adf4133f
@@ -29,6 +29,34 @@
from SwissArmyTransformer.generation.utils import timed_name, generate_continually
from SwissArmyTransformer.training.deepspeed_training import setup_model_and_optimizer
def decoder_shift_right(input_ids, args):
    # Build the decoder input by shifting the targets one position to the
    # right and prepending the decoder start token (teacher forcing).
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
    shifted_input_ids[..., 0] = args.decoder_start_token_id
    return shifted_input_ids
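
A minimal sketch of what the shift does (not part of the commit; assumes a decoder_start_token_id of 0, with SimpleNamespace standing in for the real argparse args):

import torch
from types import SimpleNamespace

args = SimpleNamespace(decoder_start_token_id=0)  # assumed start token id
labels = torch.tensor([[5, 6, 7, 8]])
# Targets are shifted right by one and the start token is prepended,
# yielding the teacher-forcing decoder input.
print(decoder_shift_right(labels, args))  # tensor([[0, 5, 6, 7]])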
def get_batch(data, args):
    keys = ['text', 'loss_mask', 'target', 'attention_mask']
    datatype = torch.int64
    # Broadcast the batch from the source rank to the other model-parallel ranks.
    data_b = mpu.broadcast_data(keys, data, datatype)
    # Unpack the individual tensors.
    tokens = data_b['text'].long()
    labels = data_b['target'].long()
    decoder_tokens = decoder_shift_right(labels, args)
    attention_mask = data_b['attention_mask'].long()
    loss_mask = data_b['loss_mask'].float()
    # Cast the attention mask to the training precision.
    if args.fp16:
        attention_mask = attention_mask.half()
    elif args.bf16:
        attention_mask = attention_mask.bfloat16()
    return tokens, decoder_tokens, labels, loss_mask, attention_mask
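
For context, a hedged sketch of how these tensors would typically feed a forward/backward step (the model call signature and loss wiring here are illustrative assumptions, not taken from this commit):

import torch.nn.functional as F

tokens, decoder_tokens, labels, loss_mask, attention_mask = get_batch(data, args)
logits = model(tokens, decoder_tokens, attention_mask)  # hypothetical signature
# Per-token cross entropy over the vocabulary, shape (batch, seq_len).
losses = F.cross_entropy(logits.transpose(1, 2), labels, reduction='none')
# Mask out padding/prompt positions so only target tokens contribute.
loss = (losses * loss_mask).sum() / loss_mask.sum()
loss.backward()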
def get_masks_and_position_ids_glm(seq, mask_position, context_length):
    tokens = seq.unsqueeze(0)
...