diff --git a/examples/t5/config/config_t5_large.json b/examples/t5/config/config_t5_large.json
new file mode 100644
index 0000000000000000000000000000000000000000..25d7bf7ac71485cda4fba8500944712a06b775ab
--- /dev/null
+++ b/examples/t5/config/config_t5_large.json
@@ -0,0 +1,34 @@
+{
+  "train_micro_batch_size_per_gpu": 16,
+  "gradient_accumulation_steps": 1,
+  "steps_per_print": 100,
+  "gradient_clipping": 1.0,
+  "zero_optimization": {
+    "stage": 2,
+    "contiguous_gradients": false,
+    "overlap_comm": true,
+    "reduce_scatter": true,
+    "reduce_bucket_size": 50000000,
+    "allgather_bucket_size": 500000000
+  },
+  "bfloat16": {
+    "enabled": true
+  },
+  "optimizer": {
+    "type": "Adam",
+    "params": {
+      "lr": 0.0002,
+      "weight_decay": 0.1,
+      "betas": [
+        0.9,
+        0.98
+      ],
+      "eps": 1e-6
+    }
+  },
+  "activation_checkpointing": {
+    "partition_activations": false,
+    "contiguous_memory_optimization": false
+  },
+  "wall_clock_breakdown": false
+}
\ No newline at end of file
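A quick sanity check for the new DeepSpeed config before wiring it into the launch script (a minimal sketch; the path assumes the repo root as the working directory):

```python
import json

# Load the new DeepSpeed config and confirm the settings the run depends on:
# ZeRO stage 2, bf16 enabled, and the per-GPU micro batch size.
with open("examples/t5/config/config_t5_large.json") as f:
    cfg = json.load(f)

assert cfg["zero_optimization"]["stage"] == 2
assert cfg["bfloat16"]["enabled"] is True
print("micro batch per GPU:", cfg["train_micro_batch_size_per_gpu"])
print("optimizer:", cfg["optimizer"]["type"], "lr =", cfg["optimizer"]["params"]["lr"])
```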
diff --git a/examples/t5/inference_t5.py b/examples/t5/inference_t5.py
index c2af4e74c9cb30a541a03c04176f2891f27a53fc..38a008a8e8ead6ff4a0f4fda3938333ffebaf9df 100644
--- a/examples/t5/inference_t5.py
+++ b/examples/t5/inference_t5.py
@@ -26,6 +26,7 @@ from SwissArmyTransformer.model.mixins import CachedAutoregressiveMixin
 from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence, evaluate_perplexity
 from SwissArmyTransformer.generation.sampling_strategies import BeamSearchStrategy, BaseStrategy
 from SwissArmyTransformer.generation.utils import timed_name, generate_continually
+from SwissArmyTransformer.training.deepspeed_training import setup_model_and_optimizer
 
 
 def get_masks_and_position_ids_glm(seq, mask_position, context_length):
@@ -49,28 +50,37 @@ def main(args):
     args.do_train = False
     initialize_distributed(args)
     tokenizer = get_tokenizer(args)
-    # build model
-    model = T5Model(args)
-    if args.fp16:
-        model = model.half()
-    model = model.to(args.device)
     # load_checkpoint(model, args)
     set_random_seed(args.seed)
-    missing_keys, unexpected_keys = model.load_state_dict(
+
+    # Build the model and wrap it, together with its optimizer, in a DeepSpeed engine.
+    model_cls = T5Model
+    model, optimizer = setup_model_and_optimizer(args, model_cls=model_cls)
+
+    missing_keys, unexpected_keys = model.module.load_state_dict(
         torch.load("/dataset/fd5061f6/yanan/huggingface_models/t5-large/model_states.pt")["module"])
-    from SwissArmyTransformer.model.encoder_decoder_model import EncoderFinalMixin
+    optimizer.refresh_fp32_params()  # re-sync the optimizer's fp32 master copy with the freshly loaded weights
     model.eval()
     input_ids = tokenizer.EncodeAsIds("The <extra_id_0> walks in <extra_id_1> park").tokenization
     input_ids = input_ids + [tokenizer.get_command("eos").Id]
-    input_ids = torch.cuda.LongTensor([input_ids])
-    # input_ids = torch.cuda.LongTensor([[37, 32099, 10681, 16, 32098, 2447, 1]])
+    input_ids = torch.LongTensor([input_ids])
     decoder_input_ids = tokenizer.EncodeAsIds('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>').tokenization
     decoder_input_ids = decoder_input_ids + [tokenizer.get_command("eos").Id]
-    decoder_input_ids = torch.cuda.LongTensor([decoder_input_ids])
-    # decoder_input_ids = torch.cuda.LongTensor([[32099, 5295, 1782, 32098, 8, 32097, 1]])
+    decoder_input_ids = torch.LongTensor([decoder_input_ids])
+    data = {'text': input_ids, 'loss_mask': decoder_input_ids.new_ones(decoder_input_ids.shape),  # mask over the decoder targets
+            'target': decoder_input_ids, 'attention_mask': input_ids.new_ones(input_ids.shape)}
+    tokens, decoder_tokens, labels, loss_mask, attention_mask = get_batch(data, args)  # assumes a get_batch helper (as in the T5 training example) that moves the batch to the right device and dtype
+    encoder_outputs, logits, *_ = model(enc_input_ids=tokens, dec_input_ids=decoder_tokens,
+                                        enc_attention_mask=attention_mask)
+    losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), labels)  # per-token cross entropy, computed across the vocab-parallel partitions
+    loss_mask = loss_mask.view(-1)
+    loss = torch.sum(losses.view(-1) * loss_mask)
+    if loss_mask.sum().item() > 0:
+        loss = loss / loss_mask.sum()
+    loss.backward()
+
     breakpoint()
-    output = model(enc_input_ids=input_ids, dec_input_ids=decoder_input_ids)
-    print(output)
+
     end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id]
     # define function for each query
     if args.sampling_strategy == 'BaseStrategy':
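The loss block added above reduces per-token cross entropy with a 0/1 mask and averages over the unmasked positions. For single-GPU sanity checks, the same reduction can be written with plain torch in place of mpu.vocab_parallel_cross_entropy (a sketch, equivalent only when model parallelism is off):

```python
import torch.nn.functional as F

def masked_lm_loss(logits, labels, loss_mask):
    """Masked-mean cross entropy, mirroring the reduction in inference_t5.py.

    logits:    [batch, seq, vocab]  (the diff uses mpu.vocab_parallel_cross_entropy
               here so the vocab dimension can be sharded across model-parallel ranks)
    labels:    [batch, seq] target token ids
    loss_mask: [batch, seq] 1.0 where a position contributes to the loss
    """
    losses = F.cross_entropy(
        logits.float().view(-1, logits.size(-1)),
        labels.view(-1),
        reduction="none",
    )
    loss_mask = loss_mask.view(-1).float()
    loss = (losses * loss_mask).sum()
    return loss / loss_mask.sum() if loss_mask.sum().item() > 0 else loss
```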
diff --git a/examples/t5/scripts/generate_t5.sh b/examples/t5/scripts/generate_t5.sh
index c01f7723655d62f191b4229f86d23c4ca93893b9..c4bc602d8babfca6a51b29c99e9ccfdd1d209558 100644
--- a/examples/t5/scripts/generate_t5.sh
+++ b/examples/t5/scripts/generate_t5.sh
@@ -15,10 +15,11 @@ TOPP=0
 script_path=$(realpath $0)
 script_dir=$(dirname $script_path)
 
-config_json="$script_dir/ds_config.json"
+config_json="$script_dir/../config/config_t5_large.json"
 
 python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTER_PORT inference_t5.py \
        --deepspeed \
+       --deepspeed-config ${config_json} \
        --mode inference \
        --model-parallel-size $MPSIZE \
        $MODEL_ARGS \
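The new --deepspeed-config flag ends up as args.deepspeed_config, which setup_model_and_optimizer hands to deepspeed.initialize. A minimal sketch of that hand-off (the toy Linear stands in for T5Model; recent DeepSpeed versions accept config= directly, while older ones read the path from args):

```python
import torch
import deepspeed

# Stand-in module; in the example this is the T5Model built inside
# setup_model_and_optimizer from SwissArmyTransformer.
model = torch.nn.Linear(8, 8)

engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="examples/t5/config/config_t5_large.json",  # ${config_json} in the script
)
```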
diff --git a/examples/t5/test_t5.py b/examples/t5/test_t5.py
index 805f22811c7dfdcd198203c35c1c816680b1a834..692ed918f39ddebbcbc23cbf194e0999e2a44aa6 100644
--- a/examples/t5/test_t5.py
+++ b/examples/t5/test_t5.py
@@ -1,7 +1,10 @@
 from transformers import T5Model, T5ForConditionalGeneration, T5Tokenizer
 tokenizer = T5Tokenizer.from_pretrained("t5-large")
-model = T5ForConditionalGeneration.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-large")
-input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
-decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
-output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+model = T5ForConditionalGeneration.from_pretrained("/dataset/fd5061f6/yanan/huggingface_models/t5-large")  # the labels= call below needs the LM head for output.loss
+model = model.to('cuda')
+model.eval()
+input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids.to('cuda')
+decoder_input_ids = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids.to('cuda')
+output = model(input_ids=input_ids, labels=decoder_input_ids)
+output.loss.backward()
 breakpoint()
\ No newline at end of file
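To make the HF reference easier to compare against the SwissArmyTransformer run, the loss can be recomputed by hand from the logits (a sketch continuing from test_t5.py above; HF's T5 loss is token-mean cross entropy with ignore_index=-100, and no -100 labels appear here):

```python
import torch
import torch.nn.functional as F

# Recompute the loss from the logits and check it matches output.loss.
with torch.no_grad():
    out = model(input_ids=input_ids, labels=decoder_input_ids)
    manual = F.cross_entropy(
        out.logits.view(-1, out.logits.size(-1)),
        decoder_input_ids.view(-1),
    )
    assert torch.allclose(out.loss, manual, atol=1e-5)
```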