diff --git a/config/model_glm_10B.sh b/config/model_glm_10B.sh
index 5ca0df0ed1a753f49baeae6818b0872268b35e91..4253bf25f786ed18459e62f4c9d56c7793103ced 100644
--- a/config/model_glm_10B.sh
+++ b/config/model_glm_10B.sh
@@ -8,5 +8,4 @@ MODEL_ARGS="--block-lm \
             --max-sequence-length 1025 \
             --tokenizer-model-type gpt2 \
             --tokenizer-type GPT2BPETokenizer \
-            --old-checkpoint \
             --load ${CHECKPOINT_PATH}/blocklm-10b-1024"
\ No newline at end of file
diff --git a/config/model_glm_10B_chinese.sh b/config/model_glm_10B_chinese.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f01347008d7959bbff14b965e48055b7f598567a
--- /dev/null
+++ b/config/model_glm_10B_chinese.sh
@@ -0,0 +1,11 @@
+MODEL_TYPE="blocklm-10B-chinese"
+MODEL_ARGS="--block-lm \
+            --cloze-eval \
+            --task-mask \
+            --num-layers 48 \
+            --hidden-size 4096 \
+            --num-attention-heads 64 \
+            --max-sequence-length 1025 \
+            --tokenizer-type glm_ChineseSPTokenizer \
+            --tokenizer-model-type glm-10b \
+            --load /dataset/fd5061f6/english_data/checkpoints/blocklm-10b-chinese07-08-15-28"
\ No newline at end of file
diff --git a/config/model_glm_large_chinese.sh b/config/model_glm_large_chinese.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c1067636fcede65b37e0eb4b2881d4d868eda91c
--- /dev/null
+++ b/config/model_glm_large_chinese.sh
@@ -0,0 +1,11 @@
+MODEL_TYPE="blocklm-large-chinese"
+MODEL_ARGS="--block-lm \
+            --cloze-eval \
+            --task-mask \
+            --num-layers 24 \
+            --hidden-size 1024 \
+            --num-attention-heads 16 \
+            --max-sequence-length 1025 \
+            --tokenizer-type glm_ChineseSPTokenizer \
+            --tokenizer-model-type glm-large \
+            --load /dataset/fd5061f6/english_data/checkpoints/blocklm-large-chinese"
\ No newline at end of file
diff --git a/config/model_glm_roberta_large.sh b/config/model_glm_roberta_large.sh
index 5c4eaf200fb2d6b91550db9112beae7f56e1ba98..6fb7216ad23e370ab5cccaeadd626d19e5d84364 100644
--- a/config/model_glm_roberta_large.sh
+++ b/config/model_glm_roberta_large.sh
@@ -7,5 +7,4 @@ MODEL_ARGS="--block-lm \
             --max-sequence-length 513 \
             --tokenizer-model-type roberta \
             --tokenizer-type GPT2BPETokenizer \
-            --old-checkpoint \
             --load ${CHECKPOINT_PATH}/blocklm-roberta-large-blank"
\ No newline at end of file
diff --git a/inference_glm.py b/inference_glm.py
index 66dabaeb1258064a74802742e819d083cddb5c2a..66c33f4c2f9220d7907bd4d66d879df30779f501 100644
--- a/inference_glm.py
+++ b/inference_glm.py
@@ -27,6 +27,7 @@ from generation.autoregressive_sampling import filling_sequence
 from generation.sampling_strategies import BeamSearchStrategy, BaseStrategy
 from generation.utils import timed_name, generate_continually
 
+
 def get_masks_and_position_ids_glm(seq, mask_position, context_length):
     tokens = seq.unsqueeze(0)
 
@@ -45,6 +46,7 @@ def get_masks_and_position_ids_glm(seq, mask_position, context_length):
 
 
 def main(args):
+    args.do_train = False
     initialize_distributed(args)
     tokenizer = prepare_tokenizer(args)
     # build model 
@@ -77,14 +79,14 @@ def main(args):
             raise ValueError('text too long.')
         
         # find mask tokens positions
-        mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
-        mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
-        mask_positions = []
-        context_tokens_tensor = torch.tensor(seq, dtype=torch.long, device=args.device)
-        for token in mask_tokens:
-            mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist()
-        mask_positions.sort()
-        
+        # mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
+        # mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
+        # mask_positions = []
+        # context_tokens_tensor = torch.tensor(seq, dtype=torch.long, device=args.device)
+        # for token in mask_tokens:
+        #     mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist()
+        # mask_positions.sort()
+
         # generation
         mbz = args.max_inference_batch_size
         assert args.batch_size < mbz or args.batch_size % mbz == 0
@@ -107,7 +109,9 @@ def main(args):
             get_func = partial(get_masks_and_position_ids_glm, mask_position=mask_position, context_length=len(seq))
             output_list = []
             for tim in range(max(args.batch_size // mbz, 1)):
-                input_seq = torch.cuda.LongTensor(seq + [tokenizer.get_command('sop').Id] + [-1] * (args.out_seq_length-len(seq)-1), device=args.device)
+                input_seq = torch.cuda.LongTensor(
+                    seq + [tokenizer.get_command('sop').Id] + [-1] * (args.out_seq_length - len(seq) - 1),
+                    device=args.device)
                 output, _mems = filling_sequence(model, input_seq,
                         batch_size=min(args.batch_size, mbz),
                         strategy=strategy,
@@ -116,7 +120,7 @@ def main(args):
                         ) # we don't use mems, fill back
                 if isinstance(output, torch.Tensor): # different strategies
                     output = list(output)
-                
+
                 output_list.extend(output)
 
             # clip -1s and fill back generated things into seq
@@ -126,10 +130,10 @@ def main(args):
                     unfinished = output.index(-1)
                 except ValueError:
                     unfinished = len(output)
-                if output[unfinished-1] in end_tokens:
+                if output[unfinished - 1] in end_tokens:
                     unfinished -= 1
                 bog = output.index(tokenizer.get_command('sop').Id)
-                output_list[i] = output[:mask_position] + output[bog+1:unfinished] + output[mask_position+1:bog]
+                output_list[i] = output[:mask_position] + output[bog + 1:unfinished] + output[mask_position + 1:bog]
 
         # decoding
         txts = []
@@ -147,11 +151,12 @@ def main(args):
         with open(full_path, 'w') as fout:
             for txt in txts:
                 fout.write(txt + '\n')
-        os.chmod(full_path, stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU)
+        os.chmod(full_path, stat.S_IRWXO + stat.S_IRWXG + stat.S_IRWXU)
 
     os.makedirs(args.output_path, exist_ok=True)
     generate_continually(process, args.input_source)
 
+
 if __name__ == "__main__":
     py_parser = argparse.ArgumentParser(add_help=False)
 
@@ -160,4 +165,4 @@ if __name__ == "__main__":
     args = argparse.Namespace(**vars(args), **vars(known))
     
     with torch.no_grad():
-        main(args)
\ No newline at end of file
+        main(args)
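
The generation loop in inference_glm.py above appends a '<|startofpiece|>' ('sop') token and -1 placeholders to the context, lets filling_sequence generate into the placeholders, and then splices the generated span back in place of the mask token. Below is a minimal sketch of that fill-back step on plain Python lists; SOP, END_TOKENS and the token values are made-up placeholders for illustration, not the repo's real vocabulary ids.

```python
# Illustrative sketch of the clip-and-splice step in inference_glm.py, with toy ids.
SOP = 50006           # assumed id for '<|startofpiece|>'
END_TOKENS = {50001}  # assumed id(s) for 'eop' / 'eos'

def splice_generation(output, mask_position):
    # clip the trailing -1 padding
    try:
        unfinished = output.index(-1)
    except ValueError:
        unfinished = len(output)
    if output[unfinished - 1] in END_TOKENS:
        unfinished -= 1
    bog = output.index(SOP)  # beginning of the generated piece
    # context before the mask + generated piece + context after the mask
    return output[:mask_position] + output[bog + 1:unfinished] + output[mask_position + 1:bog]

# e.g. context [7, MASK, 8] with MASK at position 1, generation [SOP, 21, 22, eop, -1]:
print(splice_generation([7, 99, 8, SOP, 21, 22, 50001, -1], mask_position=1))
# -> [7, 21, 22, 8]  (mask token dropped, generated span inserted in its place)
```
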
diff --git a/scripts/generate_glm.sh b/scripts/generate_glm.sh
index e007900381dd94deb395f94b401ab46a81aa789e..41ab3f3ea6c303dddd801c7bab08b411dfb927d1 100755
--- a/scripts/generate_glm.sh
+++ b/scripts/generate_glm.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-CHECKPOINT_PATH=pretrained/glm
+#CHECKPOINT_PATH=/workspace/dm/SwissArmyTransformer/pretrained/glm
 
 # MODEL_ARGS="--block-lm \
 #             --cloze-eval \
@@ -11,19 +11,20 @@ CHECKPOINT_PATH=pretrained/glm
 #             --tokenizer-type glm_GPT2BPETokenizer \
 #             --load ${CHECKPOINT_PATH}/glm-roberta-large-blank"
 
-MODEL_TYPE="blocklm-10B"
-MODEL_ARGS="--block-lm \
-            --cloze-eval \
-            --task-mask \
-            --num-layers 48 \
-            --hidden-size 4096 \
-            --num-attention-heads 64 \
-            --max-sequence-length 1025 \
-            --tokenizer-model-type gpt2 \
-            --tokenizer-type glm_GPT2BPETokenizer \
-            --old-checkpoint \
-            --load ${CHECKPOINT_PATH}/glm-en-10b"
+#MODEL_TYPE="blocklm-10B"
+#MODEL_ARGS="--block-lm \
+#            --cloze-eval \
+#            --task-mask \
+#            --num-layers 48 \
+#            --hidden-size 4096 \
+#            --num-attention-heads 64 \
+#            --max-sequence-length 1025 \
+#            --tokenizer-model-type gpt2 \
+#            --tokenizer-type glm_GPT2BPETokenizer \
+#            --old-checkpoint \
+#            --load ${CHECKPOINT_PATH}/glm-en-10b"
 
+source $1
 MPSIZE=1
 MAXSEQLEN=512
 MASTER_PORT=$(shuf -n 1 -i 10000-65535)
@@ -52,5 +53,5 @@ python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTE
        --top_k $TOPK \
        --output-path glm_text \
        --batch-size 1 \
-       --out-seq-length 100 \
+       --out-seq-length 512 \
        --mode inference
diff --git a/tokenization/__init__.py b/tokenization/__init__.py
index 427d98eb5e1dc4494b00f51fd9913b960874bd82..4be0ef41c2d2ed0eef7139df8b6aa2b05a6a397c 100644
--- a/tokenization/__init__.py
+++ b/tokenization/__init__.py
@@ -34,7 +34,7 @@ def get_tokenizer(args=None):
                 get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs)
             elif args.tokenizer_type == "glm_ChineseSPTokenizer":
                 from .text import ChineseSPTokenizer
-                get_tokenizer.tokenizer = ChineseSPTokenizer(**kwargs)
+                get_tokenizer.tokenizer = ChineseSPTokenizer(args.tokenizer_model_type, **kwargs)
         else:
             assert args.vocab_size > 0
             get_tokenizer.tokenizer = FakeTokenizer(args.vocab_size)
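
Since get_tokenizer now forwards args.tokenizer_model_type to ChineseSPTokenizer, the Chinese tokenizer can also be built directly. A hedged sketch, assuming the optional flags below are what the surrounding argument parser forwards through kwargs, and that the pretrained sentencepiece model is available locally:

```python
# Hedged sketch of what get_tokenizer() now does for glm_ChineseSPTokenizer.
# The keyword flags mirror ChineseSPTokenizer.__init__ in
# tokenization/text/tokenization.py; treat the exact values as assumptions.
from tokenization.text import ChineseSPTokenizer

tokenizer = ChineseSPTokenizer(
    "glm-large",             # model_type_or_path; "glm-10b" selects the 10B id layout
    add_block_symbols=True,  # registers '<|startofpiece|>' / '<|endofpiece|>'
    add_task_mask=True,      # registers '[sMASK]' / '[gMASK]' (order depends on model type)
)
print(tokenizer.get_command('gMASK').Id)
```
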
diff --git a/tokenization/text/tokenization.py b/tokenization/text/tokenization.py
index 51d2009152f5ec796e92fa47b8ef9a84f9496d08..ae0cf5da8aacf731c42293dea2db1d36c4f60e77 100644
--- a/tokenization/text/tokenization.py
+++ b/tokenization/text/tokenization.py
@@ -1157,29 +1157,57 @@ class GPT2BPETokenizer(Tokenizer):
 
 
 class ChineseSPTokenizer(Tokenizer):
-    def __init__(self, add_block_symbols=False, **kwargs):
+    def __init__(self, model_type_or_path, add_block_symbols=False, add_task_mask=False, add_decoder_mask=False,
+                 **kwargs):
         self.text_tokenizer = sp_tokenizer.from_pretrained()
 
-        self.num_command_tokens = 2
+        self.num_command_tokens = 0
         self.num_text_tokens = self.text_tokenizer.sp.vocab_size()
-        self.num_tokens = self.num_text_tokens + 1
+        self.num_tokens = self.num_text_tokens
         self.num_type_tokens = 2
 
         self._command_tokens = [
             CommandToken('pad', '<|endoftext|>', self.num_text_tokens),
             CommandToken('eos', '<|endoftext|>', self.num_text_tokens),
+            CommandToken('sep', '[SEP]', self.num_text_tokens + 1),
+            CommandToken('ENC', '[CLS]', self.num_text_tokens + 2),
+            CommandToken('MASK', '[MASK]', self.num_text_tokens + 3, lstrip=True),
+            CommandToken('unk', '[UNK]', self.num_text_tokens + 4)
         ]
+        self.num_tokens += 5
+        self.num_command_tokens += 6
         if add_block_symbols:
             self._command_tokens.extend([
-                CommandToken('sop', '<|startofpiece|>', self.num_text_tokens + 1),
-                CommandToken('eop', '<|endofpiece|>', self.num_text_tokens + 2)
+                CommandToken('sop', '<|startofpiece|>', self.num_tokens + 1),
+                CommandToken('eop', '<|endofpiece|>', self.num_tokens + 2)
             ])
-            self.num_tokens += 2
+            if model_type_or_path == 'glm-large':
+                self.num_tokens += 3
+            else:
+                self.num_tokens += 2
             self.num_command_tokens += 2
+            if add_task_mask:
+                if model_type_or_path == 'glm-large':
+                    self._command_tokens.extend([
+                        CommandToken('sMASK', '[sMASK]', self.num_tokens, lstrip=True),
+                        CommandToken('gMASK', '[gMASK]', self.num_tokens + 1, lstrip=True)
+                    ])
+                else:
+                    self._command_tokens.extend([
+                        CommandToken('gMASK', '[gMASK]', self.num_tokens, lstrip=True),
+                        CommandToken('sMASK', '[sMASK]', self.num_tokens + 1, lstrip=True)
+                    ])
+                self.num_tokens += 2
+                self.num_command_tokens += 2
+            if add_decoder_mask:
+                self._command_tokens.extend([
+                    CommandToken('dBLOCK', '[dBLOCK]', self.num_tokens)
+                ])
+                self.num_tokens += 1
+                self.num_command_tokens += 1
         self.command_name_map = {tok.name: tok for tok in self._command_tokens}
         self.command_token_map = {tok.token: tok for tok in self._command_tokens}
         self.command_id_map = {tok.Id: tok for tok in self._command_tokens}
-
         self.type_tokens = [
             TypeToken('str0', '<str0>', 0),
             TypeToken('str1', '<str1>', 1),
@@ -1224,7 +1252,7 @@ class ChineseSPTokenizer(Tokenizer):
         elif Id in self.type_id_map:
             return self.type_id_map[Id].token
         else:
-            return self.text_tokenizer.convert_id_to_token(Id)
+            return self.text_tokenizer.convert_id_to_token(int(Id))
 
     def TokenToId(self, token, type_token=False):
         if isinstance(token, (TypeToken, CommandToken)):
@@ -1238,13 +1266,22 @@ class ChineseSPTokenizer(Tokenizer):
             return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
         if isinstance(Ids, Tokenization):
             Ids = Ids.tokenization
-        try:
-            first_eos = Ids.index(self.get_command('eos').Id)
-            eos_count = len(Ids) - first_eos
-            Ids = Ids[:first_eos]
-        except ValueError:
-            eos_count = 0
-        return " ".join((self.text_tokenizer.decode(Ids), *(['<|endoftext|>'] * eos_count)))
+        Ids = list(map(int, Ids))
+        pieces = []
+        last = 0
+        for i, token_id in enumerate(Ids):
+            if token_id in self.command_id_map:
+                pieces.append(Ids[last: i])
+                pieces.append(token_id)
+                last = i + 1
+        pieces.append(Ids[last:])
+        text = ""
+        for piece in pieces:
+            if isinstance(piece, int):
+                text += self.command_id_map[piece].token
+            elif piece:
+                text += self.text_tokenizer.decode(piece)
+        return text
 
     def DecodeTokens(self, Tokens, type_token=False):
         if type_token:
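
The rewritten DecodeIds splits the id sequence at every command-token id, decodes each plain-text run with the sentencepiece tokenizer, and re-inserts the command tokens' surface forms, instead of truncating at the first 'eos' as before. The same segmentation logic in isolation, with fake ids and a stub decoder purely for illustration:

```python
# Standalone illustration of the DecodeIds segmentation above, using fake ids.
command_id_map = {100: '[gMASK]', 101: '<|startofpiece|>'}  # hypothetical command ids

def decode_pieces(ids, decode_text):
    """Split `ids` at command tokens; decode the plain-text runs with `decode_text`."""
    ids = list(map(int, ids))
    pieces, last = [], 0
    for i, token_id in enumerate(ids):
        if token_id in command_id_map:
            pieces.append(ids[last:i])   # text run before the command token
            pieces.append(token_id)      # the command token itself
            last = i + 1
    pieces.append(ids[last:])
    text = ""
    for piece in pieces:
        if isinstance(piece, int):
            text += command_id_map[piece]
        elif piece:
            text += decode_text(piece)
    return text

# With a stub decoder that just joins the ids:
print(decode_pieces([7, 8, 100, 9], lambda run: "/".join(map(str, run))))
# -> "7/8[gMASK]9"
```
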
diff --git a/training/model_io.py b/training/model_io.py
index 18151863ae66a0d6a5e640f71c875d8a61cea92b..ddb6d26aac0bb6b2dc56d3717ad7ae16210a7c5e 100644
--- a/training/model_io.py
+++ b/training/model_io.py
@@ -1,4 +1,3 @@
-
 # -*- encoding: utf-8 -*-
 '''
 @File    :   model_io.py
@@ -18,6 +17,7 @@ import numpy as np
 import mpu
 from .utils import print_rank_0
 
+
 def get_checkpoint_name(checkpoints_path, iteration, release=False, zero=False):
     if release:
         d = 'release'
@@ -28,6 +28,7 @@ def get_checkpoint_name(checkpoints_path, iteration, release=False, zero=False):
         d += '_zero_dp_rank_{}'.format(dp_rank)
     return os.path.join(checkpoints_path, d, 'mp_rank_{:02d}_model_states.pt'.format(mpu.get_model_parallel_rank()))
 
+
 def get_checkpoint_tracker_filename(checkpoints_path, old_checkpoint=False):
     if old_checkpoint:
         return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
@@ -70,9 +71,9 @@ def save_ds_checkpoint(iteration, model, lr_scheduler, args):
         sd['cuda_rng_state'] = torch.cuda.get_rng_state()
         sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()
     save_ds_checkpoint_no_optim(model, args.save, str(iteration), client_state=sd)
-    
+
+
 def save_ds_checkpoint_no_optim(model, save_dir, tag=None, client_state={}, save_latest=True):
-    
     os.makedirs(save_dir, exist_ok=True)
     # Ensure tag is a string
     tag = str(tag)
@@ -112,6 +113,7 @@ def get_checkpoint_iteration(args):
 
     return iteration, release, True
 
+
 def load_checkpoint(model, args):
     """Load a model checkpoint."""
 
@@ -131,10 +133,17 @@ def load_checkpoint(model, args):
     else: # inference without deepspeed
         module = model
 
+    # sd['module']['transformer.word_embeddings.weight'] = sd['module']['word_embeddings.weight']
+    # del sd['module']['word_embeddings.weight']
+    # sd['module']['mixins.block_position_embedding.block_position_embeddings.weight'] = sd['module'][
+    #     'transformer.block_position_embeddings.weight']
+    # del sd['module']['transformer.block_position_embeddings.weight']
+
     # only load module, other hyperparameters are just for recording.
     missing_keys, unexpected_keys = module.load_state_dict(sd['module'], strict=False)
     if len(unexpected_keys) > 0:
-        print_rank_0(f'Will continue but found unexpected_keys! Check whether you are loading correct checkpoints: {unexpected_keys}.')
+        print_rank_0(
+            f'Will continue but found unexpected_keys! Check whether you are loading correct checkpoints: {unexpected_keys}.')
     if len(missing_keys) > 0:
         if not args.do_train:
             raise ValueError(f'Missing keys for inference: {missing_keys}.')
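
The commented-out block in load_checkpoint sketches a key rename from an older checkpoint layout ('word_embeddings.weight' into 'transformer.word_embeddings.weight', and the block position embeddings into the mixin) before the strict=False load. A hedged, standalone version of that remap follows; the two entries mirror the commented lines and are not claimed to be a complete mapping.

```python
# Sketch of a key-renaming pass over sd['module'] before load_state_dict(strict=False).
# The two renames mirror the commented-out lines in load_checkpoint(); any further
# old->new pairs would depend on the particular checkpoint.
def remap_old_keys(module_sd):
    renames = {
        'word_embeddings.weight':
            'transformer.word_embeddings.weight',
        'transformer.block_position_embeddings.weight':
            'mixins.block_position_embedding.block_position_embeddings.weight',
    }
    for old, new in renames.items():
        if old in module_sd and new not in module_sd:
            module_sd[new] = module_sd.pop(old)
    return module_sd

# usage (inside load_checkpoint, before module.load_state_dict):
# sd['module'] = remap_old_keys(sd['module'])
```
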