diff --git a/arguments.py b/arguments.py
index dbdeea61f4199fc940f6988fa0d4897f66d561a2..ebefb2fb78ef1af4c57ff506c29654be8391926d 100755
--- a/arguments.py
+++ b/arguments.py
@@ -250,6 +250,7 @@ def add_glm_args(parser):
     group.add_argument('--random-position', action='store_true',
                        help="Use random start position to cover all the position embeddings")
     group.add_argument('--cloze-eval', action='store_true', help='Evaluation dataset with cloze task')
+    group.add_argument('--old-checkpoint', action='store_true', help="Load a checkpoint saved by the old library")
     return parser
 
 
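The --old-checkpoint switch above is a plain store_true flag, so the rest of the code only ever reads args.old_checkpoint. A minimal sketch of how the flag surfaces on the parsed namespace, assuming a bare argparse parser rather than the project's full add_glm_args() setup:

    import argparse

    # Stand-in for add_glm_args(); only the flag added above is reproduced.
    parser = argparse.ArgumentParser()
    parser.add_argument('--old-checkpoint', action='store_true',
                        help='Load a checkpoint saved by the old library')

    args = parser.parse_args(['--old-checkpoint'])
    assert args.old_checkpoint  # argparse maps the dashes to an underscore
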
diff --git a/config/model_glm_10B.sh b/config/model_glm_10B.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5ca0df0ed1a753f49baeae6818b0872268b35e91
--- /dev/null
+++ b/config/model_glm_10B.sh
@@ -0,0 +1,12 @@
+MODEL_TYPE="blocklm-10B"
+MODEL_ARGS="--block-lm \
+            --cloze-eval \
+            --task-mask \
+            --num-layers 48 \
+            --hidden-size 4096 \
+            --num-attention-heads 64 \
+            --max-sequence-length 1025 \
+            --tokenizer-model-type gpt2 \
+            --tokenizer-type GPT2BPETokenizer \
+            --old-checkpoint \
+            --load ${CHECKPOINT_PATH}/blocklm-10b-1024"
\ No newline at end of file
diff --git a/config/model_glm_roberta_large.sh b/config/model_glm_roberta_large.sh
index 6fb7216ad23e370ab5cccaeadd626d19e5d84364..5c4eaf200fb2d6b91550db9112beae7f56e1ba98 100644
--- a/config/model_glm_roberta_large.sh
+++ b/config/model_glm_roberta_large.sh
@@ -7,4 +7,5 @@ MODEL_ARGS="--block-lm \
             --max-sequence-length 513 \
             --tokenizer-model-type roberta \
             --tokenizer-type GPT2BPETokenizer \
+            --old-checkpoint \
             --load ${CHECKPOINT_PATH}/blocklm-roberta-large-blank"
\ No newline at end of file
diff --git a/inference_glm.py b/inference_glm.py
index 376773d3b5f13f3fcfc35fb51f0f4123851bb881..0456cc550df2065cc954d9509fd3f370fa1aa26d 100644
--- a/inference_glm.py
+++ b/inference_glm.py
@@ -21,7 +21,7 @@ from model.glm_model import GLMModel
 from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
 from tokenization import get_tokenizer
 from generation.sampling_strategies import BaseStrategy
-from generation.autoregressive_sampling import filling_sequence
+from generation.autoregressive_sampling import update_mems
 from generation.utils import timed_name, save_multiple_images, generate_continually
 
 
@@ -80,19 +80,19 @@ def read_context(tokenizer, args, output=None):
     return terminate_runs, raw_text, context_tokens_tensor, context_length
 
 
-def get_batch(context_tokens, device, args):
+def get_batch(context_tokens, args):
     tokens = context_tokens
     tokens = tokens.view(1, -1).contiguous()
-    tokens = tokens.to(device)
+    tokens = tokens.to('cuda')
 
     # Get the masks and position ids.
     if args.block_lm:
-        attention_mask = torch.ones(1, 1, tokens.size(1), tokens.size(1), device=device, dtype=torch.long)
+        attention_mask = torch.ones(tokens.size(1), tokens.size(1), device='cuda', dtype=torch.long)
         if args.fp16:
             attention_mask = attention_mask.half()
-        position_ids = torch.arange(tokens.size(1), device=device, dtype=torch.long)
+        position_ids = torch.arange(tokens.size(1), device='cuda', dtype=torch.long)
         if not args.no_block_position:
-            block_position_ids = torch.zeros(tokens.size(1), device=device, dtype=torch.long)
+            block_position_ids = torch.zeros(tokens.size(1), device='cuda', dtype=torch.long)
             position_ids = torch.stack((position_ids, block_position_ids), dim=0)
         position_ids = position_ids.unsqueeze(0)
     else:
@@ -129,12 +129,8 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
     return logits
 
 
-def sample_sequence(model, tokenizer, context_tokens, context_length, args, device, mems=None, end_tokens=None):
-    if not args.block_lm:
-        context_tokens, attention_mask, position_ids = get_batch(context_tokens, device, args)
-        tokens = torch.empty((args.num_beams, 0), device=context_tokens.device, dtype=torch.long)
-    else:
-        tokens = context_tokens.new_full((1, 1), tokenizer.get_command('sop').Id)
+def sample_sequence(model, tokenizer, context_tokens, context_length, args, mems=None, end_tokens=None):
+    tokens = context_tokens.new_full((1, 1), tokenizer.get_command('sop').Id)
     counter = 0
     if mems is None:
         mems = []
@@ -152,24 +148,22 @@ def sample_sequence(model, tokenizer, context_tokens, context_length, args, devi
         beam_scores = torch.zeros(1, dtype=torch.float, device=context_tokens.device)
     last_beam_num = 1
     while counter < args.out_seq_length:
-        if counter == 0 and not args.block_lm:
-            next_token_logits, *mems = model(context_tokens, position_ids, attention_mask, *mems)
-        else:
-            if args.block_lm:
-                if args.no_block_position:
-                    position_ids = context_tokens.new_full((last_beam_num, 1), context_length + counter)
-                else:
-                    position_ids = context_tokens.new_ones(last_beam_num, 2, 1)
-                    position_ids[:, 0] = context_length
-                    position_ids[:, 1] = counter + 1
-                attention_mask = context_tokens.new_ones(1, context_length + counter, device=context_tokens.device,
-                                                         dtype=torch.long)
+        if args.block_lm:
+            if args.no_block_position:
+                position_ids = context_tokens.new_full((last_beam_num, 1), context_length + counter)
             else:
-                position_ids = context_tokens.new_ones((last_beam_num, 1)) * (context_length + counter - 1)
-                attention_mask = context_tokens.new_ones(last_beam_num, 1, 1, args.mem_length + 1,
-                                                         device=context_tokens.device, dtype=torch.float)
-            last_token = tokens[:, -1:]
-            next_token_logits, *mems = model(last_token, position_ids, attention_mask, *mems)
+                position_ids = context_tokens.new_ones(last_beam_num, 2, 1)
+                position_ids[:, 0] = context_length
+                position_ids[:, 1] = counter + 1
+            attention_mask = context_tokens.new_ones(1, context_length + counter, device=context_tokens.device,
+                                                     dtype=torch.long)
+        else:
+            position_ids = context_tokens.new_ones((last_beam_num, 1)) * (context_length + counter - 1)
+            attention_mask = context_tokens.new_ones(last_beam_num, 1, 1, args.mem_length + 1,
+                                                     device=context_tokens.device, dtype=torch.float)
+        last_token = tokens[:, -1:]
+        next_token_logits, *mem_kvs = model(last_token, position_ids, attention_mask, *mems)
+        mems = update_mems(mem_kvs, mems, max_memory_length=1000000)
         next_token_logits = next_token_logits[:, -1]
         if args.num_beams > 1:
             next_token_scores = F.log_softmax(next_token_logits, dim=-1)
@@ -228,7 +222,7 @@ def sample_sequence(model, tokenizer, context_tokens, context_length, args, devi
     return torch.cat((context_tokens, tokens), dim=1), mems
 
 
-def generate_samples(model, tokenizer, args, device):
+def generate_samples(model, tokenizer, args):
     model.eval()
     output_path = "./samples"
     if not os.path.exists(output_path):
@@ -237,14 +231,13 @@ def generate_samples(model, tokenizer, args, device):
     with torch.no_grad(), open(output_path, "w") as output:
         while True:
             torch.distributed.barrier(group=mpu.get_model_parallel_group())
-
             terminate_runs, raw_text, context_tokens_tensor, context_length = read_context(tokenizer, args, output)
             if terminate_runs == 1:
                 return
             start_time = time.time()
             if args.block_lm:
                 mems = []
-                tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, device, args)
+                tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args)
                 mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
                 mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
                 end_tokens = [tokenizer.get_command('eop').Id, args.eod_token]
@@ -261,10 +254,8 @@ def generate_samples(model, tokenizer, args, device):
                         position = position_ids[0, mask_position].item()
                     else:
                         position = mask_position
-                    tokens, mems = sample_sequence(model, tokenizer, tokens, position,
-                                                   args, device, mems=mems, end_tokens=end_tokens)
-            else:
-                tokens, _ = sample_sequence(model, tokenizer, context_tokens_tensor, context_length, args, device)
+                    tokens, mems = sample_sequence(model, tokenizer, tokens, position, args, mems=mems,
+                                                   end_tokens=end_tokens)
             output_tokens_list = tokens.view(-1).contiguous()
             if mpu.get_model_parallel_rank() == 0:
                 os.system('clear')
@@ -290,7 +281,7 @@ def main(args):
     load_checkpoint(model, args)
     set_random_seed(args.seed)
     model.eval()
-    generate_samples(model, tokenizer, args, torch.cuda.current_device())
+    generate_samples(model, tokenizer, args)
 
 
 if __name__ == "__main__":
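
With filling_sequence no longer imported, sample_sequence drives incremental decoding directly: each step feeds only the last generated token together with the cached memories, and the per-layer tensors returned by the model (mem_kvs) are folded back into the cache via update_mems. The sketch below shows only the cache update in isolation; the concatenate-then-truncate behaviour is an assumption about what update_mems does, not a copy of the real implementation in generation.autoregressive_sampling:

    import torch

    def update_mems_sketch(mem_kvs, mems, max_memory_length=1000000):
        # Assumed behaviour: append this step's per-layer hidden states to the
        # cache along the sequence axis, keeping at most max_memory_length positions.
        if not mems:
            return [kv[:, -max_memory_length:] for kv in mem_kvs]
        return [torch.cat((m, kv), dim=1)[:, -max_memory_length:]
                for m, kv in zip(mems, mem_kvs)]

    # Toy shapes only: two layers, batch 1, hidden size 8, one new position per step.
    mems = []
    for step in range(3):
        mem_kvs = [torch.randn(1, 1, 8) for _ in range(2)]
        mems = update_mems_sketch(mem_kvs, mems)
    print([tuple(m.shape) for m in mems])  # [(1, 3, 8), (1, 3, 8)]
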
diff --git a/mpu/transformer.py b/mpu/transformer.py
index 975167234ab05efb85dbdcabf0b8432313b09529..bad1a9ce47aa4dfdb167ff090cbdb3055ec44703 100755
--- a/mpu/transformer.py
+++ b/mpu/transformer.py
@@ -52,9 +52,9 @@ def standard_attention(query_layer, key_layer, value_layer, attention_mask,
     if log_attention_weights is not None:
         attention_scores += log_attention_weights
     
-    if attention_mask.shape[-2] > 1: # if auto-regressive, skip
-        attention_scores = torch.mul(attention_scores, attention_mask) - \
-                    10000.0 * (1.0 - attention_mask)
+    # if attention_mask.shape[-2] > 1: # if auto-regressive, skip
+    #     attention_scores = torch.mul(attention_scores, attention_mask) - \
+    #                 10000.0 * (1.0 - attention_mask)
 
     attention_probs = F.softmax(attention_scores, dim=-1)
 
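The lines commented out above are the multiplicative mask that standard_attention used to apply whenever the mask spanned more than one query position: scores for disallowed positions are pushed to roughly -10000 so the softmax assigns them near-zero weight. For reference, a self-contained recreation of that formula with illustrative shapes:

    import torch
    import torch.nn.functional as F

    def masked_softmax(attention_scores, attention_mask):
        # attention_mask is 1 where attention is allowed and 0 elsewhere;
        # masked positions receive a large negative score before the softmax.
        scores = torch.mul(attention_scores, attention_mask) - 10000.0 * (1.0 - attention_mask)
        return F.softmax(scores, dim=-1)

    scores = torch.randn(1, 1, 4, 4)
    mask = torch.tril(torch.ones(4, 4)).view(1, 1, 4, 4)  # causal mask
    probs = masked_softmax(scores, mask)
    print(probs[0, 0, 0])  # only the first key position gets non-zero weight
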
diff --git a/scripts/generate_glm.sh b/scripts/generate_glm.sh
index af5a59a418c6c9ebe56e3e0990bcfef74394aebf..30e6d9eb1a367b61acb782012cd90889fb1ef2f6 100644
--- a/scripts/generate_glm.sh
+++ b/scripts/generate_glm.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-CHECKPOINT_PATH=./checkpoints
+CHECKPOINT_PATH=/dataset/fd5061f6/english_data/checkpoints
 
 source $1
 
diff --git a/training/model_io.py b/training/model_io.py
index df6b7524f6ac8334243af401f4147c9a34dfffae..8de74ea100da51f0212a53b64bf2ebd306cf786f 100644
--- a/training/model_io.py
+++ b/training/model_io.py
@@ -28,8 +28,12 @@ def get_checkpoint_name(checkpoints_path, iteration, release=False, zero=False):
         d += '_zero_dp_rank_{}'.format(dp_rank)
     return os.path.join(checkpoints_path, d, 'mp_rank_{:02d}_model_states.pt'.format(mpu.get_model_parallel_rank()))
 
-def get_checkpoint_tracker_filename(checkpoints_path):
-    return os.path.join(checkpoints_path, 'latest')
+def get_checkpoint_tracker_filename(checkpoints_path, old_checkpoint=False):
+    if old_checkpoint:
+        return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
+    else:
+        return os.path.join(checkpoints_path, 'latest')
+
 
 def save_checkpoint(iteration, model, optimizer,
                     lr_scheduler, args):
@@ -85,7 +89,7 @@ def save_ds_checkpoint_no_optim(model, save_dir, tag=None, client_state={}, save
 
 def get_checkpoint_iteration(args):
     # Read the tracker file and set the iteration.
-    tracker_filename = get_checkpoint_tracker_filename(args.load)
+    tracker_filename = get_checkpoint_tracker_filename(args.load, old_checkpoint=args.old_checkpoint)
     if not os.path.isfile(tracker_filename):
         print_rank_0('WARNING: could not find the metadata file {} '.format(
             tracker_filename))
@@ -126,7 +130,12 @@ def load_checkpoint(model, args):
         module = model.module
     else: # inference without deepspeed
         module = model
-        
+
+    # Process the checkpoint for GLM
+    if args.block_lm and args.old_checkpoint:
+        sd['module']['transformer.word_embeddings.weight'] = sd['module']['word_embeddings.weight']
+        del sd['module']['word_embeddings.weight']
+
     # only load module, other hyperparameters are just for recording.
     missing_keys, unexpected_keys = module.load_state_dict(sd['module'], strict=False)
     if len(unexpected_keys) > 0:
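
For checkpoints written by the old library, load_checkpoint now does two extra things: get_checkpoint_tracker_filename reads the Megatron-style latest_checkpointed_iteration.txt instead of latest, and the word embedding weight is moved under the transformer prefix the current GLM model expects. A minimal sketch of that key remapping in isolation, using the same state-dict keys as the hunk above (everything else is illustrative):

    import torch

    def remap_old_glm_state_dict(sd):
        # Old checkpoints keep the embedding outside the transformer module;
        # the current model expects it as 'transformer.word_embeddings.weight'.
        module_sd = sd['module']
        if 'word_embeddings.weight' in module_sd:
            module_sd['transformer.word_embeddings.weight'] = module_sd.pop('word_embeddings.weight')
        return sd

    sd = {'module': {'word_embeddings.weight': torch.zeros(10, 4)}}
    print(list(remap_old_glm_state_dict(sd)['module']))
    # ['transformer.word_embeddings.weight']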