diff --git a/SwissArmyTransformer/tokenization/__init__.py b/SwissArmyTransformer/tokenization/__init__.py
index 7e0485274deeae3f095456f73a0d9c842a026e46..ffdf5c354cfe6985d1aff9415079ab1c02dcacd7 100644
--- a/SwissArmyTransformer/tokenization/__init__.py
+++ b/SwissArmyTransformer/tokenization/__init__.py
@@ -29,7 +29,9 @@ def _export_vocab_size_to_args(args, original_num_tokens):
     print_rank_0('> padded vocab (size: {}) with {} dummy '
                  'tokens (new size: {})'.format(
         before, after - before, after))
-    args.vocab_size = after
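+    # respect a user-specified --vocab-size; only fill it in from the padded tokenizer size when unset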
+    if not args.vocab_size:
+        args.vocab_size = after
     print_rank_0("prepare tokenizer done")
     return tokenizer
 
@@ -63,6 +64,12 @@ def get_tokenizer(args=None, outer_tokenizer=None):
             elif args.tokenizer_type == "glm_ChineseSPTokenizer":
                 from .glm import ChineseSPTokenizer
                 get_tokenizer.tokenizer = ChineseSPTokenizer(args.tokenizer_model_type, **kwargs)
+        elif args.tokenizer_type.startswith('hf'):
+            from .hf_tokenizer import HFT5Tokenizer
+            if args.tokenizer_type == "hf_T5Tokenizer":
+                get_tokenizer.tokenizer = HFT5Tokenizer(args.tokenizer_model_type)
+            else:
+                raise NotImplementedError(f"unsupported HF tokenizer type: {args.tokenizer_type}")
         else:
             assert args.vocab_size > 0
             get_tokenizer.tokenizer = FakeTokenizer(args.vocab_size)
diff --git a/SwissArmyTransformer/tokenization/glm/tokenization.py b/SwissArmyTransformer/tokenization/glm/tokenization.py
index 67be818a0f1d42b0d35696c24577b5e5391ff3b4..9b9a8abc0202b46a95efc8a8338e90e0da9fd60f 100644
--- a/SwissArmyTransformer/tokenization/glm/tokenization.py
+++ b/SwissArmyTransformer/tokenization/glm/tokenization.py
@@ -312,11 +312,11 @@ class Tokenizer(object):
         tokenization.tokenization = [self.IdToToken(idx) for idx in tokenization.tokenization]
         return tokenization
 
-    def IdToToken(self, Id):
+    def IdToToken(self, idx):
         """convert Id to token accounting for command tokens"""
-        if isinstance(Id, CommandToken):
-            return Id.token
-        return self.tokens[Id]
+        if isinstance(idx, CommandToken):
+            return idx.token
+        return self.tokens[idx]
 
     def TokenToId(self, token):
         """convert token to Id accounting for command tokens"""
@@ -324,16 +324,16 @@ class Tokenizer(object):
             return token.Id
         return self.vocab[token]
 
-    def DecodeIds(self, Ids):
+    def DecodeIds(self, ids):
         """
         convert Ids to tokens accounting for command tokens, tokens
         are joined and returned as a string.
         """
         rtn_strs = []
         current_str = []
-        if isinstance(Ids, Tokenization):
-            Ids = Ids.tokenization
-        for Id in Ids:
+        if isinstance(ids, Tokenization):
+            ids = ids.tokenization
+        for Id in ids:
             if isinstance(Id, CommandToken):
                 rtn_strs.append(self._decode(current_str))
                 current_str = []
@@ -353,11 +353,11 @@ class Tokenizer(object):
         output = self.clean_up_tokenization(output)
         return output
 
-    def DecodeTokens(self, Tokens):
+    def DecodeTokens(self, tokens):
         """
         convert tokens to a string accounting for command and type tokens.
         """
-        Ids = [self.TokenToId(token) for token in Tokens]
+        Ids = [self.TokenToId(token) for token in tokens]
         return self.DecodeIds(Ids)
 
 
diff --git a/SwissArmyTransformer/tokenization/hf_tokenizer.py b/SwissArmyTransformer/tokenization/hf_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d67197662cf200acbda1299bce1e9199c1e34862
--- /dev/null
+++ b/SwissArmyTransformer/tokenization/hf_tokenizer.py
@@ -0,0 +1,76 @@
+from transformers import T5Tokenizer
+from .glm.tokenization import Tokenization, CommandToken
+
+
+class HFTokenizer:
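+    """Adapter exposing a HuggingFace tokenizer through the GLM ``Tokenizer``
+    interface (EncodeAsIds/DecodeIds/IdToToken/TokenToId plus command tokens)."""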
+    def __init__(self, model_cls, model_type_or_path=None, cache_dir=None, command_tokens=None):
+        self.text_tokenizer = model_cls.from_pretrained(model_type_or_path, cache_dir=cache_dir)
+        self.num_tokens = len(self.text_tokenizer)
+        self._command_tokens = []
+        self.command_name_map = {}
+        self.command_token_map = {}
+        self.command_id_map = {}
+        if command_tokens:
+            self.command_tokens = command_tokens
+
+    @property
+    def command_tokens(self):
+        return self._command_tokens
+
+    @command_tokens.setter
+    def command_tokens(self, command_tokens):
+        self._command_tokens = command_tokens
+        self.command_name_map = {tok.name: tok for tok in self.command_tokens}
+        self.command_token_map = {tok.token: tok for tok in self.command_tokens}
+        self.command_id_map = {tok.Id: tok for tok in self.command_tokens}
+
+    def get_command(self, name):
+        """get command token corresponding to `name`"""
+        return self.command_name_map[name]
+
+    def EncodeAsIds(self, text, process_fn=None):
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
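+        # no special tokens added here; callers append eos/sentinels explicitly via command tokens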
+        ids = self.text_tokenizer.encode(processed_text, add_special_tokens=False)
+        tokenization = Tokenization(ids, processed_text, text)
+        return tokenization
+
+    def DecodeIds(self, ids):
+        if isinstance(ids, Tokenization):
+            ids = ids.tokenization
+        return self.text_tokenizer.decode(ids)
+
+    def DecodeTokens(self, tokens):
+        return self.text_tokenizer.convert_tokens_to_string(tokens)
+
+    def IdToToken(self, Id):
+        if isinstance(Id, CommandToken):
+            return Id.token
+        return self.text_tokenizer.convert_ids_to_tokens(Id)
+
+    def TokenToId(self, token):
+        if isinstance(token, CommandToken):
+            return token.Id
+        return self.text_tokenizer.convert_tokens_to_ids(token)
+
+
+class HFT5Tokenizer(HFTokenizer):
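+    """HFTokenizer backed by T5Tokenizer: exposes </s>, <pad> and the 100
+    <extra_id_*> sentinel tokens as command tokens (eos/pad/MASK0..MASK99)."""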
+    def __init__(self, model_type_or_path=None, cache_dir=None):
+        super().__init__(T5Tokenizer, model_type_or_path=model_type_or_path, cache_dir=cache_dir)
+        command_tokens = [
+            CommandToken('eos', '</s>', self.TokenToId("</s>")),
+            # assumption: alias 'eop' to </s> so callers written against the GLM
+            # tokenizers (e.g. end_tokens in inference_t5.py) do not KeyError
+            CommandToken('eop', '</s>', self.TokenToId("</s>")),
+            CommandToken('pad', '<pad>', self.TokenToId("<pad>")),
+        ]
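+        # T5 reserves 100 sentinel tokens <extra_id_0> ... <extra_id_99> that mark masked spans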
+        for i in range(100):
+            command_tokens.append(CommandToken(f'MASK{i}', f'<extra_id_{i}>', self.TokenToId(f'<extra_id_{i}>')))
+        self.command_tokens = command_tokens
diff --git a/examples/t5/config/model_t5_large.sh b/examples/t5/config/model_t5_large.sh
index ce4e27741ea749955534d7049b96b04f8b91e0ef..0c20d559447a87030f5d07a701d7a125b4f79c97 100644
--- a/examples/t5/config/model_t5_large.sh
+++ b/examples/t5/config/model_t5_large.sh
@@ -10,6 +10,6 @@ MODEL_ARGS="--block-lm \
             --max-sequence-length 513 \
             --relative-attention-num-buckets 32 \
             --layernorm-epsilon 1e-6 \
-            --tokenizer-model-type roberta \
-            --tokenizer-type glm_GPT2BPETokenizer \
+            --tokenizer-type hf_T5Tokenizer \
+            --tokenizer-model-type t5-large \
             --load ${CHECKPOINT_PATH}/glm-large-en-blank"
\ No newline at end of file
diff --git a/inference_t5.py b/inference_t5.py
index 8e938f851b0757c6cb178d0cdbd49dc5da46cdde..c2af4e74c9cb30a541a03c04176f2891f27a53fc 100644
--- a/inference_t5.py
+++ b/inference_t5.py
@@ -48,7 +48,7 @@ def get_masks_and_position_ids_glm(seq, mask_position, context_length):
 def main(args):
     args.do_train = False
     initialize_distributed(args)
-    # tokenizer = get_tokenizer(args)
+    tokenizer = get_tokenizer(args)
     # build model
     model = T5Model(args)
     if args.fp16:
@@ -60,9 +60,16 @@ def main(args):
         torch.load("/dataset/fd5061f6/yanan/huggingface_models/t5-large/model_states.pt")["module"])
     from SwissArmyTransformer.model.encoder_decoder_model import EncoderFinalMixin
     model.eval()
-    input_ids = torch.cuda.LongTensor([[37, 32099, 10681, 16, 32098, 2447, 1]])
-    decoder_input_ids = torch.cuda.LongTensor([[32099, 5295, 1782, 32098, 8, 32097, 1]])
-    output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
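+    # encode with the HF tokenizer and append </s> by hand (EncodeAsIds adds no special tokens)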
+    input_ids = tokenizer.EncodeAsIds("The <extra_id_0> walks in <extra_id_1> park").tokenization
+    input_ids = input_ids + [tokenizer.get_command("eos").Id]
+    input_ids = torch.cuda.LongTensor([input_ids])
+    # input_ids = torch.cuda.LongTensor([[37, 32099, 10681, 16, 32098, 2447, 1]])
+    decoder_input_ids = tokenizer.EncodeAsIds('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>').tokenization
+    decoder_input_ids = decoder_input_ids + [tokenizer.get_command("eos").Id]
+    decoder_input_ids = torch.cuda.LongTensor([decoder_input_ids])
+    # decoder_input_ids = torch.cuda.LongTensor([[32099, 5295, 1782, 32098, 8, 32097, 1]])
+    output = model(enc_input_ids=input_ids, dec_input_ids=decoder_input_ids)
     print(output)
     end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id]
     # define function for each query