diff --git a/examples/t5/inference_t5.py b/examples/t5/inference_t5.py
index 38a008a8e8ead6ff4a0f4fda3938333ffebaf9df..b902f248e0c6ebe6dfccba3805d4112923eff6fa 100644
--- a/examples/t5/inference_t5.py
+++ b/examples/t5/inference_t5.py
@@ -29,6 +29,48 @@ from SwissArmyTransformer.generation.utils import timed_name, generate_continual
 from SwissArmyTransformer.training.deepspeed_training import setup_model_and_optimizer
 
 
+def decoder_shift_right(input_ids, args):
+    """Shift the target ids one position to the right and prepend the
+    decoder start token, building decoder inputs for teacher forcing."""
+    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+    shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
+    shifted_input_ids[..., 0] = args.decoder_start_token_id
+    return shifted_input_ids
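+# Example of the shift (illustrative; assumes args.decoder_start_token_id == 0):
+#   decoder_shift_right(torch.tensor([[5, 6, 7]]), args)
+#   -> tensor([[0, 5, 6]])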
+
+
+def get_batch(data, args):
+    """Unpack one batch dict and broadcast it to all model-parallel ranks."""
+    keys = ['text', 'loss_mask', 'target', 'attention_mask']
+    datatype = torch.int64
+
+    # Broadcast data.
+    data_b = mpu.broadcast_data(keys, data, datatype)
+    # Unpack.
+    tokens = data_b['text'].long()
+    labels = data_b['target'].long()
+    decoder_tokens = decoder_shift_right(labels, args)
+    attention_mask = data_b['attention_mask'].long()
+    loss_mask = data_b['loss_mask'].float()
+
+    # Cast the attention mask to the model's compute dtype.
+    if args.fp16:
+        attention_mask = attention_mask.half()
+    elif args.bf16:
+        attention_mask = attention_mask.bfloat16()
+    return tokens, decoder_tokens, labels, loss_mask, attention_mask
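+
+
+# Usage sketch (illustrative; `data_iterator` and the batch dict layout are
+# assumptions, not defined in this file):
+#
+#     data = next(data_iterator)
+#     tokens, decoder_tokens, labels, loss_mask, attention_mask = \
+#         get_batch(data, args)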
+
+
 def get_masks_and_position_ids_glm(seq, mask_position, context_length):
     tokens = seq.unsqueeze(0)