diff --git a/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy.py b/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy.py
index b30926294ce26e4a890b51b89896b82de88b285a..9b996fb645dfbe485b65aa0f588e89ef65f4449a 100644
--- a/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy.py
+++ b/SwissArmyTransformer/generation/sampling_strategies/beam_search_strategy.py
@@ -11,7 +11,7 @@ import torch
 import torch.nn.functional as F
 
 class BeamSearchStrategy:
-    def __init__(self, num_beams, length_penalty=1., return_only_end=False,
+    def __init__(self, num_beams, length_penalty=1., consider_end=False,
                 end_tokens=[], invalid_slices=[], no_repeat_ngram_size=0, min_tgt_length=0):
         self.num_beams = num_beams
         self.length_penalty = length_penalty
@@ -19,7 +19,7 @@ class BeamSearchStrategy:
         self.ngram = no_repeat_ngram_size
         self.min_tgt_length = min_tgt_length
         self.invalid_slices = invalid_slices
-        self.return_only_end = return_only_end
+        self.consider_end = consider_end
         self._init_cache()
 
     def _init_cache(self):
@@ -34,10 +34,10 @@ class BeamSearchStrategy:
         for i in range(len(self.end_beams), -1, -1):
             if i == 0 or score < self.end_beams_penalized_scores[i-1]:
                 break
-        self.num_beams.insert(i, beam)
+        self.end_beams.insert(i, beam)
         self.end_beams_penalized_scores.insert(i, score)
 
-        self.num_beams = self.num_beams[:self.num_beams]
+        self.end_beams = self.end_beams[:self.num_beams]
         self.end_beams_penalized_scores = self.end_beams_penalized_scores[:self.num_beams]
 
     def forward(self, logits, tokens, mems):
@@ -52,11 +52,14 @@ class BeamSearchStrategy:
         if self.ngram > 0 and seq_len > self.ngram:
             for i in range(batch_size):
                 ngram_prefix = tokens[i, -(self.ngram-1):].tolist() # TODO ngram=1
-                for banned_index in self.cached_beam_ngram_bans.get(tuple(ngram_prefix), default=[]):
+                for banned_index in self.cached_beam_ngram_bans[i].get(tuple(ngram_prefix), []):
                     logits[i, banned_index] = -65504
         
         next_token_scores = F.log_softmax(logits, dim=-1) # [batch_size, vocab_size]
-        next_token_scores = next_token_scores + self.cached_beam_scores[:, None].expand_as(next_token_scores)
+        prev_scores = self.cached_beam_scores
+        if isinstance(self.cached_beam_scores, torch.Tensor):
+            prev_scores = prev_scores[:, None].expand_as(next_token_scores)
+        next_token_scores = next_token_scores + prev_scores
         
         next_token_scores = next_token_scores.view(batch_size * vocab_size)
 
@@ -70,12 +73,14 @@ class BeamSearchStrategy:
         next_tokens = next_tokens % vocab_size
 
         # select out end beams or continue beams
+        if mems.shape[1] < batch_size:
+            mems = mems.expand(-1, batch_size, -1, -1)
         beam_continue = []
         scores_continue = []
         bans_continue = []
         mems_contiue = []
         for i in range(len(next_tokens)):
-            beam = torch.cat(tokens[next_indices[i]], next_tokens[i:i+1])
+            beam = torch.cat((tokens[next_indices[i]], next_tokens[i:i+1]))
             if int(next_tokens[i]) in self.end_tokens:
                 self._add_end_beams(next_token_scores[i], beam)
             elif len(beam_continue) < batch_size:
@@ -98,10 +103,13 @@ class BeamSearchStrategy:
         # TODO is_done
         return tokens, mems
 
-    def finalize(self, tokens):
-        if not self.return_only_end:
+    def finalize(self, tokens, mems):
+        if self.consider_end:
             for i in range(tokens.shape[0]):
                 self._add_end_beams(self.cached_beam_scores[i], tokens[i])
-        ret = self.end_beams
+            mems = None
+            ret = self.end_beams
+        else:
+            ret = tokens
         self._init_cache()
-        return ret
+        return ret, mems
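The hunk above fixes `_add_end_beams`, which previously wrote into `self.num_beams` (an int) instead of `self.end_beams` (the list of finished beams). A standalone sketch of the intended insert-and-truncate behavior, using illustrative names rather than repo code:

```python
# Illustrative re-implementation of the ordered insert _add_end_beams performs
# after the fix: finished beams stay sorted by penalized score (descending)
# and only the best num_beams entries are kept.
def add_end_beam(end_beams, end_scores, beam, score, num_beams):
    for i in range(len(end_beams), -1, -1):
        if i == 0 or score < end_scores[i - 1]:
            break
    end_beams.insert(i, beam)
    end_scores.insert(i, score)
    return end_beams[:num_beams], end_scores[:num_beams]

beams, scores = [], []
for b, s in [("beam1", -1.0), ("beam2", -0.5), ("beam3", -2.0)]:
    beams, scores = add_end_beam(beams, scores, b, s, num_beams=2)
print(scores)  # [-0.5, -1.0]: the two best finished beams survive
```

Because `finalize` now also accepts and returns `mems`, the call sites updated later in this patch index `filling_sequence(...)[0]` instead of unpacking `output, _mems`.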
diff --git a/examples/cogview2/inference_cogview2.py b/examples/cogview2/inference_cogview2.py
index 823d4918cde2b1317edf5869c90abbfa2c034f87..9a423d3d7daa08c3a65c88a8cb38d10f952c1e03 100644
--- a/examples/cogview2/inference_cogview2.py
+++ b/examples/cogview2/inference_cogview2.py
@@ -67,11 +67,12 @@ def main(args):
         assert args.batch_size < mbz or args.batch_size % mbz == 0
         output_list = []
         for tim in range(max(args.batch_size // mbz, 1)):
-            output0, mems = filling_sequence(model0, seq.clone(),
+            output0 = filling_sequence(model0, seq.clone(),
                     batch_size=min(args.batch_size, mbz),
                     strategy=strategy0,
                     log_attention_weights=log_attention_weights
-                    )
+                    )[0]
+            # mems is deliberately not kept, so it is freed automatically to save CUDA memory
             imgs = [tr(tokenizer.img_tokenizer.DecodeIds(x[-1025:-1].tolist())) for x in output0]
             blur64 = tokenizer.img_tokenizer.EncodeAsIds(torch.cat(imgs, dim=0).to(args.device), add_normalization=True) # [batch_size, 4096]
             len_tim = output0.shape[0]
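Indexing `[0]` instead of unpacking `output0, mems` matters for memory: when `mems` is never bound to a name, CPython drops its last reference as soon as the call returns, so the cache can be freed immediately. A minimal sketch of the pattern, with placeholder shapes rather than CogView2's real ones:

```python
import torch

def fake_filling_sequence():
    # stand-in for filling_sequence(); returns (tokens, mems)
    tokens = torch.zeros(4, 16, dtype=torch.long)
    mems = torch.zeros(8, 4, 64, 128)  # pretend this is a large cached-activation tensor
    return tokens, mems

output0 = fake_filling_sequence()[0]  # the (tokens, mems) tuple and mems lose their
                                      # last reference right here and are freed
# versus:
# output0, mems = fake_filling_sequence()  # mems stays referenced until reassigned or deleted
```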
diff --git a/examples/cogview2/scripts/large_scale_text2image_cogview2.sh b/examples/cogview2/scripts/large_scale_text2image_cogview2.sh
index fb1eb174a23ceaa394b0181ffe0f09764fc64e32..91aa80eed30b14e0abf8394227ba001df906e8c8 100755
--- a/examples/cogview2/scripts/large_scale_text2image_cogview2.sh
+++ b/examples/cogview2/scripts/large_scale_text2image_cogview2.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-NUM_WORKERS=4
+NUM_WORKERS=5
 NUM_GPUS_PER_WORKER=8
 MP_SIZE=1
 
@@ -8,7 +8,7 @@ OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
 HOST_FILE_PATH="hostfile"
 # HOST_FILE_PATH="hostfile_single"
 
-CHECKPOINT_PATH=pretrained/cogview/cogview2-base
+CHECKPOINT_PATH=/dataset/fd5061f6/sat_pretrained/cogview/cogview2-base
 NLAYERS=48
 NHIDDEN=2560
 NATT=40
@@ -25,7 +25,7 @@ script_dir=$(dirname $script_path)
 
 gpt_options=" \
        --tokenizer-type cogview \
-       --img-tokenizer-path pretrained/vqvae/l1+ms-ssim+revd_percep.pt \
+       --img-tokenizer-path /dataset/fd5061f6/sat_pretrained/vqvae/l1+ms-ssim+revd_percep.pt \
        --mode inference \
        --distributed-backend nccl \
        --max-sequence-length 1089 \
@@ -42,7 +42,7 @@ gpt_options=" \
        --input-source ./coco30k.txt \
        --output-path coco_samples \
        --batch-size 60 \
-       --max-inference-batch-size 12 \
+       --max-inference-batch-size 6 \
        --with-id \
     "
 
diff --git a/examples/glm/inference_glm.py b/examples/glm/inference_glm.py
index bfc0d6957ce13b7024afba795db7f9d12bd212ab..a1f5b832d5e04dcda0c3a94e787c38b07c0a08d1 100644
--- a/examples/glm/inference_glm.py
+++ b/examples/glm/inference_glm.py
@@ -61,7 +61,12 @@ def main(args):
 
     end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id]
     # define function for each query
-    strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k,end_tokens=end_tokens)
+    if args.sampling_strategy == 'BaseStrategy':
+        strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, end_tokens=end_tokens)
+    elif args.sampling_strategy == 'BeamSearchStrategy':
+        strategy = BeamSearchStrategy(args.batch_size, length_penalty=args.length_penalty, consider_end=True, end_tokens=end_tokens, no_repeat_ngram_size=args.no_repeat_ngram_size, min_tgt_length=args.min_tgt_length)
+    else:
+        raise ValueError(f'unknown strategy {args.sampling_strategy}')
     
     def process(raw_text):
         if args.with_id:
@@ -77,15 +82,6 @@ def main(args):
         print('raw text: {}\n'.format(raw_text))
         if len(seq) > args.max_sequence_length:
             raise ValueError('text too long.')
-        
-        # find mask tokens positions
-        # mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
-        # mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
-        # mask_positions = []
-        # context_tokens_tensor = torch.tensor(seq, dtype=torch.long, device=args.device)
-        # for token in mask_tokens:
-        #     mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist()
-        # mask_positions.sort()
 
         # generation
         mbz = args.max_inference_batch_size
@@ -112,12 +108,12 @@ def main(args):
                 input_seq = torch.cuda.LongTensor(
                     seq + [tokenizer.get_command('sop').Id] + [-1] * (args.out_seq_length - len(seq) - 1),
                     device=args.device)
-                output, _mems = filling_sequence(model, input_seq,
+                output = filling_sequence(model, input_seq,
                         batch_size=min(args.batch_size, mbz),
                         strategy=strategy,
                         log_attention_weights=None,
                         get_masks_and_position_ids=get_func
-                        ) # we don't use mems, fill back
+                        )[0] # we don't use mems here; just fill the outputs back
                 if isinstance(output, torch.Tensor): # different strategies
                     output = list(output)
 
@@ -159,7 +155,7 @@ def main(args):
 
 if __name__ == "__main__":
     py_parser = argparse.ArgumentParser(add_help=False)
-
+    py_parser.add_argument('--sampling-strategy', type=str, default='BaseStrategy', help='sampling strategy to use: BaseStrategy or BeamSearchStrategy')
     known, args_list = py_parser.parse_known_args()
     args = get_args(args_list)
     args = argparse.Namespace(**vars(args), **vars(known))
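The new `--sampling-strategy` flag rides on the `parse_known_args()` pattern already used in this script: script-local flags are parsed first, everything else is forwarded to the framework's `get_args()`, and the two namespaces are merged. A minimal sketch of that pattern, with a plain `ArgumentParser` standing in for `get_args()`:

```python
import argparse

py_parser = argparse.ArgumentParser(add_help=False)
py_parser.add_argument('--sampling-strategy', type=str, default='BaseStrategy')

framework_parser = argparse.ArgumentParser()  # stand-in for SwissArmyTransformer's get_args()
framework_parser.add_argument('--temperature', type=float, default=1.0)

cmdline = ['--sampling-strategy', 'BeamSearchStrategy', '--temperature', '0.9']
known, rest = py_parser.parse_known_args(cmdline)  # pick out the script-local flag
args = framework_parser.parse_args(rest)           # the framework parses the rest
args = argparse.Namespace(**vars(args), **vars(known))
print(args.sampling_strategy, args.temperature)    # BeamSearchStrategy 0.9
```

With this in place, `generate_glm.sh` below only needs to append `--sampling-strategy BeamSearchStrategy` to the launch command.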
diff --git a/examples/glm/scripts/generate_glm.sh b/examples/glm/scripts/generate_glm.sh
index eea0ed2742e8762c9294a406f64dcc670540552f..9e1c1631005dbe9ea681e56b28b871d570faf07a 100755
--- a/examples/glm/scripts/generate_glm.sh
+++ b/examples/glm/scripts/generate_glm.sh
@@ -1,29 +1,6 @@
 #!/bin/bash
 CHECKPOINT_PATH=/dataset/fd5061f6/sat_pretrained/glm
 
-# MODEL_ARGS="--block-lm \
-#             --cloze-eval \
-#             --num-layers 24 \
-#             --hidden-size 1024 \
-#             --num-attention-heads 16 \
-#             --max-sequence-length 513 \
-#             --tokenizer-model-type roberta \
-#             --tokenizer-type glm_GPT2BPETokenizer \
-#             --load ${CHECKPOINT_PATH}/glm-roberta-large-blank"
-
-#MODEL_TYPE="blocklm-10B"
-#MODEL_ARGS="--block-lm \
-#            --cloze-eval \
-#            --task-mask \
-#            --num-layers 48 \
-#            --hidden-size 4096 \
-#            --num-attention-heads 64 \
-#            --max-sequence-length 1025 \
-#            --tokenizer-model-type gpt2 \
-#            --tokenizer-type glm_GPT2BPETokenizer \
-#            --old-checkpoint \
-#            --load ${CHECKPOINT_PATH}/glm-en-10b"
-
 source $1
 MPSIZE=1
 MAXSEQLEN=512
@@ -52,6 +29,8 @@ python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTE
        --temperature $TEMP \
        --top_k $TOPK \
        --output-path samples_glm \
-       --batch-size 1 \
+       --batch-size 2 \
        --out-seq-length 200 \
-       --mode inference
+       --mode inference \
+       --input-source ./input.txt \
+       --sampling-strategy BeamSearchStrategy