diff --git a/generation/glm_sampling.py b/generation/glm_sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..704b54d6ed696aca4f58868a994bfccf0bd52d6a
--- /dev/null
+++ b/generation/glm_sampling.py
@@ -0,0 +1,83 @@
+import torch
+import torch.nn.functional as F
+import mpu
+from .autoregressive_sampling import update_mems
+from .sampling_strategies.beam_search_strategy import BeamSearchScorer
+
+
+def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+    # This function has been mostly taken from huggingface conversational ai code at
+    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
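+    # NOTE: the top-p branch below assumes a single sequence, i.e. logits of shape (1, vocab_size)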
+
+    if top_k > 0:
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p > 0.0:
+        # convert to 1D
+        logits = logits.view(logits.size()[1]).contiguous()
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+        # Remove tokens with cumulative probability above the threshold
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Shift the indices to the right to keep also the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+        indices_to_remove = sorted_indices[sorted_indices_to_remove]
+        logits[indices_to_remove] = filter_value
+        # going back to 2D
+        logits = logits.view(1, -1).contiguous()
+
+    return logits
+
+
+def filling_sequence_glm(model, tokenizer, mask_position, strategy, args, mems=None, end_tokens=None, device='cuda'):
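+    """Generate the tokens behind a single masked position, one step at a time.
+
+    Decoding starts from the 'sop' token; `strategy` (e.g. `BeamSearchStrategy`)
+    picks the next token(s) from the logits and signals completion via `is_done`.
+    """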
+    tokens = torch.full((1, 1), tokenizer.get_command('sop').Id, device=device, dtype=torch.long)
+    counter = 0
+    if mems is None:
+        mems = []
+    # if end_tokens is None:
+    #     end_tokens = [tokenizer.get_command('eos').Id]
+    while counter < args.out_seq_length:
+        last_beam_num = tokens.size(0)
+        if args.block_lm:
+            if args.no_block_position:
+                position_ids = torch.full((last_beam_num, 1), mask_position + counter, device=device, dtype=torch.long)
+            else:
+                position_ids = torch.ones(last_beam_num, 2, 1, device=device, dtype=torch.long)
+                position_ids[:, 0] = mask_position
+                position_ids[:, 1] = counter + 1
+            attention_mask = torch.ones(1, 1, device=device, dtype=torch.float)
+        else:
+            position_ids = torch.full((last_beam_num, 1), mask_position + counter - 1, device=device, dtype=torch.long)
+            attention_mask = torch.ones(last_beam_num, 1, 1, args.mem_length + 1, device=device, dtype=torch.float)
+        if args.fp16:
+            attention_mask = attention_mask.half()
+        last_token = tokens[:, -1:]
+        logits, *mem_kvs = model(last_token, position_ids, attention_mask, *mems)
+        mems = update_mems(mem_kvs, mems, max_memory_length=1000000)
+        next_token_logits = logits[:, -1]
+        tokens, mems = strategy.forward(next_token_logits, tokens, mems)
+        if strategy.is_done:
+            break
+        # else:
+        #     next_token_logits /= args.temperature
+        #     next_token_logits = top_k_logits(next_token_logits, top_k=args.top_k, top_p=args.top_p)
+        #     log_probs = F.softmax(next_token_logits, dim=-1)
+        #     prev = torch.multinomial(log_probs, num_samples=1)[0]
+        #     is_end = prev.item() in end_tokens
+        #     if is_end:
+        #         break
+        #     prev = prev.view(1, 1)
+        #     tokens = prev if tokens is None else torch.cat((tokens, prev), dim=1)
+        counter += 1
+    tokens, mems = strategy.finalize(tokens, mems)
+    return tokens, mems
diff --git a/generation/sampling_strategies/beam_search_strategy.py b/generation/sampling_strategies/beam_search_strategy.py
new file mode 100644
index 0000000000000000000000000000000000000000..98ba2d62dc45f9ef528888510180d8f8f919f131
--- /dev/null
+++ b/generation/sampling_strategies/beam_search_strategy.py
@@ -0,0 +1,501 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   beam_search_strategy.py
+@Time    :   2021/10/08 22:22:42
+@Author  :   Ming Ding
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import torch
+import torch.nn.functional as F
+from abc import ABC, abstractmethod
+from collections import UserDict
+from typing import Optional, Tuple, List, Iterable, Union
+
+
+def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+    # This function has been mostly taken from huggingface conversational ai code at
+    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
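+    # NOTE: the top-p branch below assumes a single sequence, i.e. logits of shape (1, vocab_size)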
+
+    if top_k > 0:
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p > 0.0:
+        # convert to 1D
+        logits = logits.view(logits.size()[1]).contiguous()
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+        # Remove tokens with cumulative probability above the threshold
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Shift the indices to the right to keep also the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+        indices_to_remove = sorted_indices[sorted_indices_to_remove]
+        logits[indices_to_remove] = filter_value
+        # going back to 2D
+        logits = logits.view(1, -1).contiguous()
+
+    return logits
+
+
+class BeamScorer(ABC):
+    """
+    Abstract base class for all beam scorers that are used for :meth:`~transformers.PretrainedModel.beam_search` and
+    :meth:`~transformers.PretrainedModel.beam_sample`.
+    """
+
+    @abstractmethod
+    def process(
+            self,
+            input_ids: torch.LongTensor,
+            next_scores: torch.FloatTensor,
+            next_tokens: torch.LongTensor,
+            next_indices: torch.LongTensor,
+            **kwargs
+    ) -> Tuple[torch.Tensor]:
+        raise NotImplementedError("This is an abstract method.")
+
+    @abstractmethod
+    def finalize(
+            self,
+            input_ids: torch.LongTensor,
+            next_scores: torch.FloatTensor,
+            next_tokens: torch.LongTensor,
+            next_indices: torch.LongTensor,
+            **kwargs
+    ) -> torch.LongTensor:
+        raise NotImplementedError("This is an abstract method.")
+
+
+class BeamSearchScorer(BeamScorer):
+    r"""
+    :class:`transformers.BeamScorer` implementing standard beam search decoding.
+
+    Adapted in part from `Facebook's XLM beam search code
+    <https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__.
+
+    Args:
+        batch_size (:obj:`int`):
+            Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel.
+        max_length (:obj:`int`):
+            The maximum length of the sequence to be generated.
+        num_beams (:obj:`int`):
+            Number of beams for beam search.
+        device (:obj:`torch.device`):
+            Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of
+            :obj:`BeamSearchScorer` will be allocated.
+        length_penalty (:obj:`float`, `optional`, defaults to 1.0):
+            Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
+            model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
+            sequences.
+        do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
+        num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1):
+            The number of beam hypotheses that shall be returned upon calling
+            :meth:`~transformer.BeamSearchScorer.finalize`.
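+
+    Example (sketch of the single-sentence loop used in this repo)::
+
+        scorer = BeamSearchScorer(batch_size=1, max_length=256, num_beams=4, device='cuda')
+        while not scorer.is_done:
+            ...  # score 2 * num_beams candidate tokens, then call scorer.process(...)
+        tokens, mems, scores = scorer.finalize(tokens, beam_scores, eos_token_id=eos_id, mems=mems)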
+    """
+
+    def __init__(
+            self,
+            batch_size: int,
+            max_length: int,
+            num_beams: int,
+            device: Union[torch.device, str],
+            length_penalty: Optional[float] = 1.0,
+            do_early_stopping: Optional[bool] = False,
+            num_beam_hyps_to_keep: Optional[int] = 1,
+    ):
+        self.max_length = max_length
+        self.num_beams = num_beams
+        self.device = device
+        self.length_penalty = length_penalty
+        self.do_early_stopping = do_early_stopping
+        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
+
+        self._is_init = False
+        self._beam_hyps = [
+            BeamHypotheses(
+                num_beams=self.num_beams,
+                max_length=self.max_length,
+                length_penalty=self.length_penalty,
+                early_stopping=self.do_early_stopping,
+            )
+            for _ in range(batch_size)
+        ]
+        self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
+
+        # if not isinstance(num_beams, int) or num_beams <= 1:
+        #     raise ValueError(
+        #         f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead."
+        #     )
+
+    @property
+    def is_done(self) -> bool:
+        return bool(self._done.all())
+
+    def process(
+            self,
+            input_ids: torch.LongTensor,
+            next_scores: torch.FloatTensor,
+            next_tokens: torch.LongTensor,
+            next_indices: torch.LongTensor,
+            pad_token_id: Optional[int] = None,
+            eos_token_id: Optional[int] = None,
+            mems=None
+    ) -> Tuple[torch.Tensor]:
+        cur_len = input_ids.shape[-1]
+        batch_size = len(self._beam_hyps)
+        assert batch_size == (input_ids.shape[0] // self.num_beams)
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        device = next_scores.device
+        next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device)
+        next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device)
+        next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device)
+
+        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
+            if self._done[batch_idx]:
+                assert (
+                        len(beam_hyp) >= self.num_beams
+                ), "Batch can only be done if at least {} beams have been generated".format(self.num_beams)
+                assert (
+                        eos_token_id is not None and pad_token_id is not None
+                ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
+                # pad the batch
+                next_beam_scores[batch_idx, :] = 0
+                next_beam_tokens[batch_idx, :] = pad_token_id
+                next_beam_indices[batch_idx, :] = 0
+                continue
+
+            # next tokens for this sentence
+            beam_idx = 0
+            for beam_token_rank, (next_token, next_score, next_index) in enumerate(
+                    zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
+            ):
+                batch_beam_idx = batch_idx * self.num_beams + next_index
+                # add to generated hypotheses if end of sentence
+                if (eos_token_id is not None) and (next_token.item() in eos_token_id):
+                    # if beam_token does not belong to top num_beams tokens, it should not be added
+                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams
+                    if is_beam_token_worse_than_top_num_beams:
+                        continue
+                    beam_hyp.add(
+                        input_ids[batch_beam_idx].clone(),
+                        next_score.item(),
+                        mems=[mem[[next_index.item()]] for mem in mems] if mems else None
+                    )
+                else:
+                    # add next predicted token since it is not eos_token
+                    next_beam_scores[batch_idx, beam_idx] = next_score
+                    next_beam_tokens[batch_idx, beam_idx] = next_token
+                    next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
+                    beam_idx += 1
+
+                # once the beam for next step is full, don't add more tokens to it.
+                if beam_idx == self.num_beams:
+                    break
+
+            if beam_idx < self.num_beams:
+                raise ValueError(
+                    f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
+                )
+
+            # Check if we are done so that we can save a pad step if all(done)
+            self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
+                next_scores[batch_idx].max().item(), cur_len
+            )
+
+        return UserDict(
+            {
+                "next_beam_scores": next_beam_scores.view(-1),
+                "next_beam_tokens": next_beam_tokens.view(-1),
+                "next_beam_indices": next_beam_indices.view(-1),
+            }
+        )
+
+    def finalize(
+            self,
+            input_ids: torch.LongTensor,
+            final_beam_scores: torch.FloatTensor,
+            pad_token_id: Optional[int] = None,
+            eos_token_id: Optional[int] = None,
+            mems=None
+    ) -> Tuple[torch.LongTensor, List[torch.Tensor], torch.FloatTensor]:
+        batch_size = len(self._beam_hyps)
+
+        # finalize all open beam hypotheses and add to generated hypotheses
+        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
+            if self._done[batch_idx]:
+                continue
+
+            # need to add best num_beams hypotheses to generated hyps
+            for beam_id in range(self.num_beams):
+                batch_beam_idx = batch_idx * self.num_beams + beam_id
+                final_score = final_beam_scores[batch_beam_idx].item()
+                final_tokens = input_ids[batch_beam_idx]
+                beam_hyp.add(final_tokens, final_score, mems=[mem[[batch_beam_idx]] for mem in mems] if mems else None)
+
+        # select the best hypotheses
+        sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
+        best = []
+
+        # retrieve best hypotheses
+        for i, beam_hyp in enumerate(self._beam_hyps):
+            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
+            for j in range(self.num_beam_hyps_to_keep):
+                score, best_hyp, mems = sorted_hyps.pop()
+                sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
+                best.append((best_hyp, mems, score))
+
+        # prepare for adding eos
+        sent_max_len = min(sent_lengths.max().item(), self.max_length)
+        decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
+        scores = final_beam_scores.new(batch_size * self.num_beam_hyps_to_keep)
+        # shorter batches are padded if needed
+        if sent_lengths.min().item() != sent_lengths.max().item():
+            assert pad_token_id is not None, "`pad_token_id` has to be defined"
+            decoded.fill_(pad_token_id)
+
+        # fill with hypotheses and eos_token_id if the latter fits in
+        mems = []
+        for i, (hypo, mem, score) in enumerate(best):
+            scores[i] = score
+            decoded[i, : sent_lengths[i]] = hypo
+            if sent_lengths[i] < sent_max_len:
+                decoded[i, sent_lengths[i]] = eos_token_id
+            mems.append(mem)
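+        # re-stack the per-hypothesis memories back into one tensor per layer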
+        mems = [torch.cat([mem[i] for mem in mems], dim=0) for i in range(len(mems[0]))] if mems and mems[0] else None
+        return decoded, mems, scores
+
+
+class BeamHypotheses:
+    def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool):
+        """
+        Initialize n-best list of hypotheses.
+        """
+        self.max_length = max_length - 1  # ignoring bos_token
+        self.length_penalty = length_penalty
+        self.early_stopping = early_stopping
+        self.num_beams = num_beams
+        self.beams = []
+        self.worst_score = 1e9
+
+    def __len__(self):
+        """
+        Number of hypotheses in the list.
+        """
+        return len(self.beams)
+
+    def add(self, hyp: torch.LongTensor, sum_logprobs: float, mems=None):
+        """
+        Add a new hypothesis to the list.
+        """
+        score = sum_logprobs / (max(hyp.shape[-1], 1) ** self.length_penalty)
+        if len(self) < self.num_beams or score > self.worst_score:
+            self.beams.append((score, hyp, mems))
+            if len(self) > self.num_beams:
+                sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
+                del self.beams[sorted_next_scores[0][1]]
+                self.worst_score = sorted_next_scores[1][0]
+            else:
+                self.worst_score = min(score, self.worst_score)
+
+    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
+        """
+        If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
+        one in the heap, then we are done with this sentence.
+        """
+
+        if len(self) < self.num_beams:
+            return False
+        elif self.early_stopping:
+            return True
+        else:
+            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
+            ret = self.worst_score >= cur_score
+            return ret
+
+
+class LogitsProcessor(ABC):
+    """Abstract base class for all logit processors that can be applied during generation."""
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        """Torch method for processing logits."""
+        raise NotImplementedError(
+            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
+        )
+
+
+class LogitsProcessorList(list):
+    """
+    This class can be used to create a list of :class:`~transformers.LogitsProcessor` or
+    :class:`~transformers.LogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits from
+    list and adds a specific `__call__` method to apply each :class:`~transformers.LogitsProcessor` or
+    :class:`~transformers.LogitsWarper` to the inputs.
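+
+    Example (sketch)::
+
+        processors = LogitsProcessorList([MinLengthLogitsProcessor(5, eos_token_id),
+                                          NoRepeatNGramLogitsProcessor(3)])
+        scores = processors(input_ids, scores)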
+    """
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        for processor in self:
+            scores = processor(input_ids, scores)
+        return scores
+
+
+class MinLengthLogitsProcessor(LogitsProcessor):
+    r"""
+    :class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0.
+
+    Args:
+        min_length (:obj:`int`):
+            The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`.
+        eos_token_id (:obj:`int`):
+            The id of the `end-of-sequence` token.
+    """
+
+    def __init__(self, min_length: int, eos_token_id: int):
+        if not isinstance(min_length, int) or min_length < 0:
+            raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
+
+        if not isinstance(eos_token_id, int) or eos_token_id < 0:
+            raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
+
+        self.min_length = min_length
+        self.eos_token_id = eos_token_id
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        cur_len = input_ids.shape[-1]
+        if cur_len < self.min_length:
+            scores[:, self.eos_token_id] = -float("inf")
+        return scores
+
+
+class NoRepeatNGramLogitsProcessor(LogitsProcessor):
+    r"""
+    :class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq
+    <https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345>`__.
+
+    Args:
+        ngram_size (:obj:`int`):
+            All ngrams of size :obj:`ngram_size` can only occur once.
+    """
+
+    def __init__(self, ngram_size: int):
+        if not isinstance(ngram_size, int) or ngram_size <= 0:
+            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
+        self.ngram_size = ngram_size
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        num_batch_hypotheses = scores.shape[0]
+        cur_len = input_ids.shape[-1]
+        banned_batch_tokens = self._calc_banned_ngram_tokens(input_ids, num_batch_hypotheses, cur_len)
+
+        for i, banned_tokens in enumerate(banned_batch_tokens):
+            scores[i, banned_tokens] = -float("inf")
+
+        return scores
+
+    def _calc_banned_ngram_tokens(
+            self, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
+    ) -> List[Iterable[int]]:
+        """Copied from fairseq for no_repeat_ngram in beam_search"""
+        if cur_len + 1 < self.ngram_size:
+            # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
+            return [[] for _ in range(num_hypos)]
+        generated_ngrams = [{} for _ in range(num_hypos)]
+        for idx in range(num_hypos):
+            gen_tokens = prev_input_ids[idx].tolist()
+            generated_ngram = generated_ngrams[idx]
+            for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]):
+                prev_ngram_tuple = tuple(ngram[:-1])
+                generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
+
+        def _get_generated_ngrams(hypo_idx):
+            # Before decoding the next token, prevent decoding of ngrams that have already appeared
+            start_idx = cur_len + 1 - self.ngram_size
+            ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
+            return generated_ngrams[hypo_idx].get(ngram_idx, [])
+
+        banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
+        return banned_tokens
+
+
+class BeamSearchStrategy:
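+    """Beam-search strategy (candidate tokens are sampled rather than taken top-k) driven
+    step by step by `filling_sequence_glm`: `forward` advances every beam by one token and
+    `finalize` returns the best finished hypothesis."""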
+    def __init__(self, num_beams, max_length, length_penalty, end_tokens, device='cuda'):
+        self.num_beams = num_beams
+        self.max_length = max_length
+        self.length_penalty = length_penalty
+        self.end_tokens = end_tokens
+        self.beam_scorer = BeamSearchScorer(
+            batch_size=1,
+            max_length=max_length,
+            num_beams=num_beams,
+            device=device,
+            length_penalty=length_penalty,
+            do_early_stopping=False,
+        )
+        self.beam_scores = torch.zeros(1, dtype=torch.float, device=device)
+
+    @property
+    def is_done(self) -> bool:
+        return self.beam_scorer.is_done
+
+    def forward(self, logits, tokens, mems):
+        last_beam_num = tokens.size(0)
+        next_token_scores = F.log_softmax(logits, dim=-1)
+        next_token_scores = next_token_scores + self.beam_scores[:, None].expand_as(next_token_scores)
+        vocab_size = next_token_scores.shape[-1]
+        next_token_scores = next_token_scores.view(1, last_beam_num * vocab_size)
+
+        probs = F.softmax(next_token_scores, dim=-1)
+        next_tokens = torch.multinomial(probs, num_samples=2 * self.num_beams)
+        next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
+        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
+        next_tokens = torch.gather(next_tokens, -1, _indices)
+
+        next_indices = next_tokens // vocab_size
+        next_tokens = next_tokens % vocab_size
+        # stateless
+        tokens = tokens.expand((self.num_beams, -1))
+        beam_outputs = self.beam_scorer.process(
+            tokens,
+            next_token_scores,
+            next_tokens,
+            next_indices,
+            eos_token_id=self.end_tokens,
+            mems=mems
+        )
+        self.beam_scores = beam_outputs["next_beam_scores"]
+        beam_next_tokens = beam_outputs["next_beam_tokens"]
+        beam_idx = beam_outputs["next_beam_indices"]
+        beam_next_tokens = beam_next_tokens.unsqueeze(-1)
+        tokens = torch.cat([tokens[beam_idx, :], beam_next_tokens], dim=-1)
+        mems = [mem[beam_idx] for mem in mems] if mems else None
+        return tokens, mems
+
+    def finalize(self, tokens, mems):
+        # TODO check the eos token here
+        tokens, mems, scores = self.beam_scorer.finalize(tokens, self.beam_scores,
+                                                         eos_token_id=self.end_tokens[0],
+                                                         mems=mems)
+        return tokens, mems
diff --git a/inference_glm.py b/inference_glm.py
index 0456cc550df2065cc954d9509fd3f370fa1aa26d..cfba93f1fc1715be50b8280ce4ae96ac9b472df4 100644
--- a/inference_glm.py
+++ b/inference_glm.py
@@ -19,10 +19,8 @@ import mpu
 from arguments import get_args
 from model.glm_model import GLMModel
 from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
-from tokenization import get_tokenizer
-from generation.sampling_strategies import BaseStrategy
-from generation.autoregressive_sampling import update_mems
-from generation.utils import timed_name, save_multiple_images, generate_continually
+from generation.glm_sampling import filling_sequence_glm
+from generation.sampling_strategies.beam_search_strategy import BeamSearchStrategy
 
 
 def read_context(tokenizer, args, output=None):
@@ -101,127 +98,6 @@ def get_batch(context_tokens, args):
     return tokens, attention_mask, position_ids
 
 
-def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
-    # This function has been mostly taken from huggingface conversational ai code at
-    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
-
-    if top_k > 0:
-        # Remove all tokens with a probability less than the last token of the top-k
-        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-        logits[indices_to_remove] = filter_value
-
-    if top_p > 0.0:
-        # convert to 1D
-        logits = logits.view(logits.size()[1]).contiguous()
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-
-        # Remove tokens with cumulative probability above the threshold
-        sorted_indices_to_remove = cumulative_probs > top_p
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = 0
-        indices_to_remove = sorted_indices[sorted_indices_to_remove]
-        logits[indices_to_remove] = filter_value
-        # going back to 2D
-        logits = logits.view(1, -1).contiguous()
-
-    return logits
-
-
-def sample_sequence(model, tokenizer, context_tokens, context_length, args, mems=None, end_tokens=None):
-    tokens = context_tokens.new_full((1, 1), tokenizer.get_command('sop').Id)
-    counter = 0
-    if mems is None:
-        mems = []
-    if end_tokens is None:
-        end_tokens = [args.eod_token]
-    if args.num_beams > 1:
-        beam_scorer = BeamSearchScorer(
-            batch_size=1,
-            max_length=args.out_seq_length,
-            num_beams=args.num_beams,
-            device=context_tokens.device,
-            length_penalty=args.length_penalty,
-            do_early_stopping=False,
-        )
-        beam_scores = torch.zeros(1, dtype=torch.float, device=context_tokens.device)
-    last_beam_num = 1
-    while counter < args.out_seq_length:
-        if args.block_lm:
-            if args.no_block_position:
-                position_ids = context_tokens.new_full((last_beam_num, 1), context_length + counter)
-            else:
-                position_ids = context_tokens.new_ones(last_beam_num, 2, 1)
-                position_ids[:, 0] = context_length
-                position_ids[:, 1] = counter + 1
-            attention_mask = context_tokens.new_ones(1, context_length + counter, device=context_tokens.device,
-                                                     dtype=torch.long)
-        else:
-            position_ids = context_tokens.new_ones((last_beam_num, 1)) * (context_length + counter - 1)
-            attention_mask = context_tokens.new_ones(last_beam_num, 1, 1, args.mem_length + 1,
-                                                     device=context_tokens.device, dtype=torch.float)
-        last_token = tokens[:, -1:]
-        next_token_logits, *mem_kvs = model(last_token, position_ids, attention_mask, *mems)
-        mems = update_mems(mem_kvs, mems, max_memory_length=1000000)
-        next_token_logits = next_token_logits[:, -1]
-        if args.num_beams > 1:
-            next_token_scores = F.log_softmax(next_token_logits, dim=-1)
-            next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
-            vocab_size = next_token_scores.shape[-1]
-            next_token_scores = next_token_scores.view(1, last_beam_num * vocab_size)
-
-            probs = F.softmax(next_token_scores, dim=-1)
-            next_tokens = torch.multinomial(probs, num_samples=2 * args.num_beams)
-            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
-            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
-            next_tokens = torch.gather(next_tokens, -1, _indices)
-
-            next_indices = next_tokens // vocab_size
-            next_tokens = next_tokens % vocab_size
-            # stateless
-            tokens = tokens.expand((args.num_beams, -1))
-            beam_outputs = beam_scorer.process(
-                tokens,
-                next_token_scores,
-                next_tokens,
-                next_indices,
-                eos_token_id=end_tokens,
-                mems=mems
-            )
-            beam_scores = beam_outputs["next_beam_scores"]
-            beam_next_tokens = beam_outputs["next_beam_tokens"]
-            beam_idx = beam_outputs["next_beam_indices"]
-            beam_next_tokens = beam_next_tokens.unsqueeze(-1)
-            tokens = torch.cat([tokens[beam_idx, :], beam_next_tokens], dim=-1)
-            mems = [mem[beam_idx] for mem in mems] if mems else None
-            if beam_scorer.is_done:
-                break
-            last_beam_num = args.num_beams
-        else:
-            next_token_logits /= args.temperature
-            next_token_logits = top_k_logits(next_token_logits, top_k=args.top_k, top_p=args.top_p)
-            log_probs = F.softmax(next_token_logits, dim=-1)
-            prev = torch.multinomial(log_probs, num_samples=1)[0]
-            is_end = prev.item() in end_tokens
-            if is_end:
-                break
-            prev = prev.view(1, 1)
-            tokens = prev if tokens is None else torch.cat((tokens, prev), dim=1)
-        counter += 1
-        if not args.block_lm and mpu.get_model_parallel_rank() == 0 and counter % 16 == 0:
-            output_tokens_list = tokens.view(-1).contiguous()
-            decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
-            if mpu.get_model_parallel_rank() == 0 and (counter % 128 == 0 or is_end):
-                os.system('clear')
-                trim_decode_tokens = decode_tokens
-                print(trim_decode_tokens, flush=True)
-    if args.num_beams > 1:
-        tokens, mems = beam_scorer.finalize(tokens, beam_scores, next_tokens, next_indices, eos_token_id=args.eod_token,
-                                            mems=mems)
-    return torch.cat((context_tokens, tokens), dim=1), mems
-
-
 def generate_samples(model, tokenizer, args):
     model.eval()
     output_path = "./samples"
@@ -240,7 +116,7 @@ def generate_samples(model, tokenizer, args):
                 tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args)
                 mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
                 mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
-                end_tokens = [tokenizer.get_command('eop').Id, args.eod_token]
+                end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id]
                 mask_positions = []
                 for token in mask_tokens:
                     mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist()
@@ -254,8 +130,12 @@
                         position = position_ids[0, mask_position].item()
                     else:
                         position = mask_position
-                    tokens, mems = sample_sequence(model, tokenizer, tokens, position, args, mems=mems,
-                                                   end_tokens=end_tokens)
+                    # beam-search decoding for this span; build a fresh strategy per span, since the scorer keeps state
+                    strategy = BeamSearchStrategy(args.num_beams, args.out_seq_length, args.length_penalty,
+                                                  end_tokens, device=tokens.device)
+                    new_tokens, mems = filling_sequence_glm(model, tokenizer, position, strategy, args,
+                                                            mems=mems, end_tokens=end_tokens)
+                    tokens = torch.cat((tokens, new_tokens), dim=1)
             output_tokens_list = tokens.view(-1).contiguous()
             if mpu.get_model_parallel_rank() == 0:
                 os.system('clear')
@@ -272,7 +149,6 @@ def generate_samples(model, tokenizer, args):
 def main(args):
     initialize_distributed(args)
     tokenizer = prepare_tokenizer(args)
-    args.eod_token = tokenizer.get_command('eos').Id
     # build model
     model = GLMModel(args)
     if args.fp16:
diff --git a/move_weights_glm.py b/move_weights_glm.py
deleted file mode 100644
index 6d3755d22f33f6137add7a4dcee26d31fd106e50..0000000000000000000000000000000000000000
--- a/move_weights_glm.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import sys
-import os
-import torch
-import copy
-
-checkpoint = sys.argv[1]
-target_path = sys.argv[2]
-
-assert os.path.isdir(checkpoint)
-iteration_file = os.path.join(checkpoint, 'latest_checkpointed_iteration.txt')
-if os.path.exists(iteration_file):
-    with open(iteration_file) as fin:
-        iteration = int(fin.read().strip())
-    checkpoint = os.path.join(checkpoint, str(iteration))
-else:
-    iteration = None
-
-os.makedirs(target_path, exist_ok=True)
-if iteration is not None:
-    with open(os.path.join(target_path, "latest"), "w") as output:
-        output.write(str(iteration))
-    target_path = os.path.join(target_path, str(iteration))
-    os.makedirs(target_path, exist_ok=True)
-
-
-filenames = os.listdir(checkpoint)
-filenames = [filename for filename in filenames if filename.startswith("mp_rank_")]
-filenames = sorted(filenames,
-                   key=lambda x: int(x.split('_')[2]))
-filenames = [os.path.join(checkpoint, x) for x in filenames]
-
-for filename in filenames:
-    data = torch.load(filename)
-    state_dict = data['module']
-    state_dict['transformer.word_embeddings.weight'] = state_dict['word_embeddings.weight']
-    del state_dict['word_embeddings.weight']
-    # print(f"Target path: {os.path.join(target_path, os.path.basename(filename))}")
-    torch.save(data, os.path.join(target_path, os.path.basename(filename)))