diff --git a/arguments.py b/arguments.py
index ebefb2fb78ef1af4c57ff506c29654be8391926d..9d87a2344d6e8998deaeb0e64fa5ba4e1a20fb27 100755
--- a/arguments.py
+++ b/arguments.py
@@ -158,6 +158,7 @@ def add_text_generate_args(parser):
     group.add_argument("--top_p", type=float, default=0.0)
     group.add_argument("--top_k", type=int, default=0)
     group.add_argument("--num-beams", type=int, default=1)
+    group.add_argument("--length-penalty", type=float, default=0.0)
     group.add_argument("--out-seq-length", type=int, default=256)
     group.add_argument('--input-source', type=str, default='interactive',
                        help='what input mode to use, interactive or path')
diff --git a/generation/glm_sampling.py b/generation/glm_sampling.py
index 704b54d6ed696aca4f58868a994bfccf0bd52d6a..f024ce326fa8481d9ff613a54be0f2b707365622 100644
--- a/generation/glm_sampling.py
+++ b/generation/glm_sampling.py
@@ -5,34 +5,6 @@ from .autoregressive_sampling import update_mems
 from .sampling_strategies.beam_search_strategy import BeamSearchScorer
 
 
-def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
-    # This function has been mostly taken from huggingface conversational ai code at
-    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
-
-    if top_k > 0:
-        # Remove all tokens with a probability less than the last token of the top-k
-        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-        logits[indices_to_remove] = filter_value
-
-    if top_p > 0.0:
-        # convert to 1D
-        logits = logits.view(logits.size()[1]).contiguous()
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-
-        # Remove tokens with cumulative probability above the threshold
-        sorted_indices_to_remove = cumulative_probs > top_p
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = 0
-        indices_to_remove = sorted_indices[sorted_indices_to_remove]
-        logits[indices_to_remove] = filter_value
-        # going back to 2D
-        logits = logits.view(1, -1).contiguous()
-
-    return logits
-
-
 def filling_sequence_glm(model, tokenizer, mask_position, strategy, args, mems=None, end_tokens=None, device='cuda'):
     tokens = torch.full((1, 1), tokenizer.get_command('sop').Id, device=device, dtype=torch.long)
     counter = 0
@@ -40,7 +12,7 @@ def filling_sequence_glm(model, tokenizer, mask_position, strategy, args, mems=N
         mems = []
     # if end_tokens is None:
     #     end_tokens = [tokenizer.get_command('eos').Id]
-    while counter < args.out_seq_length:
+    while counter < args.out_seq_length - 1:
         last_beam_num = tokens.size(0)
         if args.block_lm:
             if args.no_block_position:
@@ -73,5 +45,6 @@ def filling_sequence_glm(model, tokenizer, mask_position, strategy, args, mems=N
         #     prev = prev.view(1, 1)
         #     tokens = prev if tokens is None else torch.cat((tokens, prev), dim=1)
         counter += 1
-    strategy.finalize(tokens, mems)
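+    # let the strategy post-process the result (e.g. beam search returns its best hypothesis)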
+    tokens, mems = strategy.finalize(tokens, mems)
     return tokens, mems
diff --git a/generation/sampling_strategies/__init__.py b/generation/sampling_strategies/__init__.py
index 2f71e09703c38106088167808d3be758ae8c9b24..2e6b4f6f481d2dbe24aa6c899656173deeb0b163 100644
--- a/generation/sampling_strategies/__init__.py
+++ b/generation/sampling_strategies/__init__.py
@@ -1,2 +1,3 @@
 from .base_strategy import BaseStrategy
-from .iterative_entfilter_strategy import IterativeEntfilterStrategy
\ No newline at end of file
+from .iterative_entfilter_strategy import IterativeEntfilterStrategy
+from .beam_search_strategy import BeamSearchStrategy
diff --git a/generation/sampling_strategies/base_strategy.py b/generation/sampling_strategies/base_strategy.py
index e46a8ca4e505c2891bae2599ebe10746cc54d876..20339941e16802cb60c88613a5bafee76db75b74 100644
--- a/generation/sampling_strategies/base_strategy.py
+++ b/generation/sampling_strategies/base_strategy.py
@@ -14,26 +14,67 @@ import random
 import torch
 import torch.nn.functional as F
 
-def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')):
-    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-    logits[indices_to_remove] = filter_value     
+
+def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+    # This function has been mostly taken from huggingface conversational ai code at
+    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
+
+    if top_k > 0:
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
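+    # NOTE: the top-p filtering below flattens the logits, so it assumes a batch size of 1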
+    if top_p > 0.0:
+        # convert to 1D
+        logits = logits.view(logits.size()[1]).contiguous()
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+        # Remove tokens with cumulative probability above the threshold
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Shift the indices to the right to keep also the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+        indices_to_remove = sorted_indices[sorted_indices_to_remove]
+        logits[indices_to_remove] = filter_value
+        # going back to 2D
+        logits = logits.view(1, -1).contiguous()
+
     return logits
 
+
 class BaseStrategy:
-    def __init__(self, invalid_slices=[], temperature=1., topk=200, eps=1e-4):
+    def __init__(self, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None):
         self.invalid_slices = invalid_slices
         self.temperature = temperature
-        self.topk = topk
+        self.topk = top_k
+        self.top_p = top_p
         self.eps = eps
+        if end_tokens is None:
+            end_tokens = []
+        self.end_tokens = end_tokens
+        self._is_done = False
+
+    @property
+    def is_done(self) -> bool:
+        return self._is_done
+
     def forward(self, logits, tokens, mems, temperature=None):
         if temperature is None:
-            temperature = self.temperature 
+            temperature = self.temperature
         logits = logits / temperature
         for invalid_slice in self.invalid_slices:
             logits[..., invalid_slice] = -65504
-            
-        logits = top_k_logits_(logits, self.topk)
-        probs = F.softmax(logits.float(), dim=-1) # float is essetial, due to a bug in Pytorch
+
+        logits = top_k_logits(logits, self.topk, self.top_p)
+        probs = F.softmax(logits.float(), dim=-1)  # float is essential, due to a bug in PyTorch
         pred = torch.multinomial(probs, num_samples=1)
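+        # flag the strategy as done once an end token is sampled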
+        if pred.item() in self.end_tokens:
+            self._is_done = True
         tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
         return tokens, mems
+
+    def finalize(self, tokens, mems):
+        return tokens, mems
diff --git a/generation/sampling_strategies/beam_search_strategy.py b/generation/sampling_strategies/beam_search_strategy.py
index 98ba2d62dc45f9ef528888510180d8f8f919f131..6ce4fd959a266562b64a8522cd59ff7ca1accc8f 100644
--- a/generation/sampling_strategies/beam_search_strategy.py
+++ b/generation/sampling_strategies/beam_search_strategy.py
@@ -14,34 +14,6 @@ from collections import UserDict
 from typing import Optional, Tuple, List, Iterable, Union
 
 
-def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
-    # This function has been mostly taken from huggingface conversational ai code at
-    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
-
-    if top_k > 0:
-        # Remove all tokens with a probability less than the last token of the top-k
-        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-        logits[indices_to_remove] = filter_value
-
-    if top_p > 0.0:
-        # convert to 1D
-        logits = logits.view(logits.size()[1]).contiguous()
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-
-        # Remove tokens with cumulative probability above the threshold
-        sorted_indices_to_remove = cumulative_probs > top_p
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = 0
-        indices_to_remove = sorted_indices[sorted_indices_to_remove]
-        logits[indices_to_remove] = filter_value
-        # going back to 2D
-        logits = logits.view(1, -1).contiguous()
-
-    return logits
-
-
 class BeamScorer(ABC):
     """
     Abstract base class for all beam scorers that are used for :meth:`~transformers.PretrainedModel.beam_search` and
diff --git a/inference_cogview.py b/inference_cogview.py
index dd4c90f8934ab576f8cbc4e783a7cbe1eea065c8..f5ecb78d9d8ee061c8eb757569381d904e4ab953 100644
--- a/inference_cogview.py
+++ b/inference_cogview.py
@@ -37,8 +37,8 @@ def main(args):
     # define function for each query
     query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024' if not args.full_query else '{}'
     invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
-    strategy = BaseStrategy(invalid_slices, 
-        temperature=args.temperature, topk=args.top_k)
+    strategy = BaseStrategy(invalid_slices,
+                            temperature=args.temperature, top_k=args.top_k)
     
     def process(raw_text):
         if args.with_id:
diff --git a/inference_cogview2.py b/inference_cogview2.py
index 75ad1c6c269d4bd0eb8eb03a65e1b05f45d78c3b..c69e589d71009b40d3620f5052086dbb9509992b 100644
--- a/inference_cogview2.py
+++ b/inference_cogview2.py
@@ -42,8 +42,8 @@ def main(args):
     # define function for each query
     query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024 [EOI1]' if not args.full_query else '{}'
     invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
-    strategy0 = BaseStrategy(invalid_slices, 
-        temperature=args.temperature, topk=args.top_k)
+    strategy0 = BaseStrategy(invalid_slices,
+                             temperature=args.temperature, top_k=args.top_k)
     strategy1 = IterativeEntfilterStrategy(invalid_slices,
         temperature=args.temperature, topk=10) # temperature not used
     tr = transforms.Compose([
diff --git a/inference_glm.py b/inference_glm.py
index cfba93f1fc1715be50b8280ce4ae96ac9b472df4..26d247104888abc71da184169371e8fbd5942bec 100644
--- a/inference_glm.py
+++ b/inference_glm.py
@@ -20,6 +20,7 @@ from arguments import get_args
 from model.glm_model import GLMModel
 from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
 from generation.glm_sampling import filling_sequence_glm
+from generation.sampling_strategies import BeamSearchStrategy, BaseStrategy
 
 
 def read_context(tokenizer, args, output=None):
@@ -130,8 +131,15 @@ def generate_samples(model, tokenizer, args):
                         position = position_ids[0, mask_position].item()
                     else:
                         position = mask_position
-                    new_tokens, mems = filling_sequence_glm(model, tokenizer, position, args, mems=mems,
-                                                   end_tokens=end_tokens)
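+                    # use beam search when more than one beam is requested; otherwise fall back to sampling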
+                    if args.num_beams > 1:
+                        strategy = BeamSearchStrategy(num_beams=args.num_beams, max_length=args.out_seq_length,
+                                                      length_penalty=args.length_penalty, end_tokens=end_tokens)
+                    else:
+                        strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
+                                                end_tokens=end_tokens)
+                    new_tokens, mems = filling_sequence_glm(model, tokenizer, position, strategy, args, mems=mems,
+                                                            end_tokens=end_tokens)
                     tokens = torch.cat((tokens, new_tokens), dim=1)
             output_tokens_list = tokens.view(-1).contiguous()
             if mpu.get_model_parallel_rank() == 0:
diff --git a/scripts/generate_glm.sh b/scripts/generate_glm.sh
index 30e6d9eb1a367b61acb782012cd90889fb1ef2f6..c3c99c99f01d8aaa10b973e0795ff8d1ce53867c 100644
--- a/scripts/generate_glm.sh
+++ b/scripts/generate_glm.sh
@@ -22,6 +22,8 @@ python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTE
        --mode inference \
        --model-parallel-size $MPSIZE \
        $MODEL_ARGS \
+       --num-beams 4 \
+       --length-penalty 0.7 \
        --fp16 \
        --out-seq-length $MAXSEQLEN \
        --temperature $TEMP \