diff --git a/generation/autoregressive_sampling.py b/generation/autoregressive_sampling.py
index cccc7c79955c5231f47e6b9ace7e556b83ddb511..541d9091d565e27685c64b68775c271842252438 100644
--- a/generation/autoregressive_sampling.py
+++ b/generation/autoregressive_sampling.py
@@ -14,7 +14,7 @@ import random
 import torch
 from .sampling_strategies import BaseStrategy
 
-def get_masks_and_position_ids(seq):
+def get_masks_and_position_ids_default(seq):
     tokens = seq.unsqueeze(0)
 
     attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
@@ -36,7 +36,6 @@ def update_mems(hiddens, mems, max_memory_length):
     memory_length = mems.shape[2] if mems is not None else 0
     query_length = hiddens.shape[2]
     new_memory_length = min(max_memory_length, memory_length + query_length)
-    new_mems = []
     with torch.no_grad():
         if new_memory_length <= query_length:
             return hiddens[:, :, -new_memory_length:]
@@ -55,10 +54,16 @@ def filling_sequence(
         batch_size,
         strategy=BaseStrategy(),
         max_memory_length=100000,
-        log_attention_weights=None
+        log_attention_weights=None,
+        get_masks_and_position_ids=get_masks_and_position_ids_default,
+        mems=None
         ):
     '''
         seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
+        mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
+            cache, should be the first mems.shape[2] tokens of context_tokens.
+            mems are first-class citizens here, but we don't assume what is memorized.
+            input mems are used for multi-phase generation.
     '''
     assert len(seq.shape) == 1
 
@@ -72,9 +77,7 @@ def filling_sequence(
     attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
     # initialize generation
     counter = context_length - 1 # Last fixed index is ``counter'' 
-    index = 0 # Next forward starting index, also the length of cache.
-    mems = None # mems are the first-level citizens here, but we don't assume what is memorized.
-        
+    index = 0 if mems is None else mems.shape[2] # Next forward starting index, also the length of cache.
     # step-by-step generation
     while counter < len(seq) - 1:
         # Now, we want to generate seq[counter + 1],
@@ -83,7 +86,7 @@ def filling_sequence(
         if seq[counter + 1] >= 0: # provided
             tokens = torch.cat(
                 (
                     tokens, 
                     seq[counter+1: counter+2].expand(tokens.shape[0], 1)
                 ), dim=1
             )
@@ -92,13 +95,16 @@ def filling_sequence(
 
         # forward
         if log_attention_weights is not None:
-            model.log_attention_weights = log_attention_weights[..., index: counter+1, :counter+1] # TODO memlen
-        kw_tensors = {'mems': mems} if mems is not None else {}
+            log_attention_weights_part = log_attention_weights[..., index: counter+1, :counter+1] # TODO memlen
+        else:
+            log_attention_weights_part = None
+
         logits, *mem_kv = model(
             tokens[:, index:], 
             position_ids[..., index: counter+1],
             attention_mask[..., index: counter+1, :counter+1], # TODO memlen
-            **kw_tensors # if no mems, cannot pass
+            mems=mems,
+            log_attention_weights=log_attention_weights_part
         )
         mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
         counter += 1
@@ -107,6 +113,6 @@ def filling_sequence(
         logits = logits[:, -1].expand(batch_size, -1) # [batch size, vocab size]
         tokens = tokens.expand(batch_size, -1)
         tokens, mems = strategy.forward(logits, tokens, mems)
-        
-    model.log_attention_weights = None
-    return tokens
\ No newline at end of file
+        if strategy.is_done:
+            break
+    return strategy.finalize(tokens, mems)
\ No newline at end of file
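
Usage note for the hooks above (a sketch, not part of the patch): `filling_sequence` now takes an optional input `mems` cache and a pluggable `get_masks_and_position_ids`, so a second call can reuse the key/value cache from an earlier one. The snippet assumes a cached-autoregressive `model`, two prepared index tensors `seq_phase1`/`seq_phase2` (with -1 at positions to be generated), and an `eos_id`; these names are illustrative.

```python
from generation.autoregressive_sampling import (
    filling_sequence, get_masks_and_position_ids_default)
from generation.sampling_strategies import BaseStrategy

strategy = BaseStrategy(top_k=40, end_tokens=[eos_id])  # eos_id: assumed token id

# Phase 1: run the prompt, keep the returned cache.
tokens, mems = filling_sequence(
    model, seq_phase1, batch_size=1, strategy=strategy,
    get_masks_and_position_ids=get_masks_and_position_ids_default)

# Phase 2: continue from the cache; the forward index starts at mems.shape[2],
# so seq_phase2 must begin with the tokens already covered by mems.
tokens, mems = filling_sequence(
    model, seq_phase2, batch_size=1, strategy=strategy, mems=mems)
```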
diff --git a/generation/sampling_strategies/base_strategy.py b/generation/sampling_strategies/base_strategy.py
index 20339941e16802cb60c88613a5bafee76db75b74..a5268e5ab4db13e6c8ede944f81291505560e4c6 100644
--- a/generation/sampling_strategies/base_strategy.py
+++ b/generation/sampling_strategies/base_strategy.py
@@ -15,7 +15,7 @@ import torch
 import torch.nn.functional as F
 
 
-def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-65504):
     # This function has been mostly taken from huggingface conversational ai code at
     # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
 
@@ -69,10 +69,11 @@ class BaseStrategy:
         logits = top_k_logits(logits, self.topk, self.top_p)
         probs = F.softmax(logits.float(), dim=-1)  # float is essetial, due to a bug in Pytorch
         pred = torch.multinomial(probs, num_samples=1)
-        if pred.item() in self.end_tokens:
+        if pred.numel() == 1 and pred.item() in self.end_tokens:
             self._is_done = True
         tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
         return tokens, mems
 
     def finalize(self, tokens, mems):
+        self._is_done = False
         return tokens, mems
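
A short note on the new `filter_value=-65504` (illustrative check, not part of the patch): it is the smallest finite float16 value, so filtered logits stay representable when the model runs in fp16, whereas `-inf` turns a fully masked row into NaN after softmax.

```python
import torch

assert torch.finfo(torch.float16).min == -65504.0

masked = torch.full((1, 4), -65504.0, dtype=torch.float16)
print(torch.softmax(masked.float(), dim=-1))                      # finite, uniform row
print(torch.softmax(torch.full((1, 4), float('-inf')), dim=-1))   # all NaN
```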
diff --git a/generation/sampling_strategies/beam_search_strategy.py b/generation/sampling_strategies/beam_search_strategy.py
index ec205e071590c1683d2e96bbac80ff57db31d2df..b30926294ce26e4a890b51b89896b82de88b285a 100644
--- a/generation/sampling_strategies/beam_search_strategy.py
+++ b/generation/sampling_strategies/beam_search_strategy.py
@@ -9,459 +9,99 @@
 # here put the import lib
 import torch
 import torch.nn.functional as F
-from abc import ABC, abstractmethod
-from collections import UserDict
-from typing import Optional, Tuple, List, Iterable, Union
-
-
-class BeamScorer(ABC):
-    """
-    Abstract base class for all beam scorers that are used for :meth:`~transformers.PretrainedModel.beam_search` and
-    :meth:`~transformers.PretrainedModel.beam_sample`.
-    """
-
-    @abstractmethod
-    def process(
-            self,
-            input_ids: torch.LongTensor,
-            next_scores: torch.FloatTensor,
-            next_tokens: torch.LongTensor,
-            next_indices: torch.LongTensor,
-            **kwargs
-    ) -> Tuple[torch.Tensor]:
-        raise NotImplementedError("This is an abstract method.")
-
-    @abstractmethod
-    def finalize(
-            self,
-            input_ids: torch.LongTensor,
-            next_scores: torch.FloatTensor,
-            next_tokens: torch.LongTensor,
-            next_indices: torch.LongTensor,
-            **kwargs
-    ) -> torch.LongTensor:
-        raise NotImplementedError("This is an abstract method.")
-
-
-class BeamSearchScorer(BeamScorer):
-    r"""
-    :class:`transformers.BeamScorer` implementing standard beam search decoding.
-
-    Adapted in part from `Facebook's XLM beam search code
-    <https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__.
-
-    Args:
-        batch_size (:obj:`int`):
-            Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel.
-        max_length (:obj:`int`):
-            The maximum length of the sequence to be generated.
-        num_beams (:obj:`int`):
-            Number of beams for beam search.
-        device (:obj:`torch.device`):
-            Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of
-            :obj:`BeamSearchScorer` will be allocated.
-        length_penalty (:obj:`float`, `optional`, defaults to 1.0):
-            Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
-            model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
-            sequences.
-        do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
-        num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1):
-            The number of beam hypotheses that shall be returned upon calling
-            :meth:`~transformer.BeamSearchScorer.finalize`.
-    """
-
-    def __init__(
-            self,
-            batch_size: int,
-            max_length: int,
-            num_beams: int,
-            device: Union[torch.device, str],
-            length_penalty: Optional[float] = 1.0,
-            do_early_stopping: Optional[bool] = False,
-            num_beam_hyps_to_keep: Optional[int] = 1,
-    ):
-        self.max_length = max_length
-        self.num_beams = num_beams
-        self.device = device
-        self.length_penalty = length_penalty
-        self.do_early_stopping = do_early_stopping
-        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
-
-        self._is_init = False
-        self._beam_hyps = [
-            BeamHypotheses(
-                num_beams=self.num_beams,
-                max_length=self.max_length,
-                length_penalty=self.length_penalty,
-                early_stopping=self.do_early_stopping,
-            )
-            for _ in range(batch_size)
-        ]
-        self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
-
-        # if not isinstance(num_beams, int) or num_beams <= 1:
-        #     raise ValueError(
-        #         f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead."
-        #     )
-
-    @property
-    def is_done(self) -> bool:
-        return self._done.all()
-
-    def process(
-            self,
-            input_ids: torch.LongTensor,
-            next_scores: torch.FloatTensor,
-            next_tokens: torch.LongTensor,
-            next_indices: torch.LongTensor,
-            pad_token_id: Optional[int] = None,
-            eos_token_id: Optional[int] = None,
-            mems=None
-    ) -> Tuple[torch.Tensor]:
-        cur_len = input_ids.shape[-1]
-        batch_size = len(self._beam_hyps)
-        assert batch_size == (input_ids.shape[0] // self.num_beams)
-        if isinstance(eos_token_id, int):
-            eos_token_id = [eos_token_id]
-        device = next_scores.device
-        next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device)
-        next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device)
-        next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device)
-
-        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
-            if self._done[batch_idx]:
-                assert (
-                        len(beam_hyp) >= self.num_beams
-                ), "Batch can only be done if at least {} beams have been generated".format(self.num_beams)
-                assert (
-                        eos_token_id is not None and pad_token_id is not None
-                ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
-                # pad the batch
-                next_beam_scores[batch_idx, :] = 0
-                next_beam_tokens[batch_idx, :] = pad_token_id
-                next_beam_indices[batch_idx, :] = 0
-                continue
-
-            # next tokens for this sentence
-            beam_idx = 0
-            for beam_token_rank, (next_token, next_score, next_index) in enumerate(
-                    zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
-            ):
-                batch_beam_idx = batch_idx * self.num_beams + next_index
-                # add to generated hypotheses if end of sentence
-                if (eos_token_id is not None) and (next_token.item() in eos_token_id):
-                    # if beam_token does not belong to top num_beams tokens, it should not be added
-                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams
-                    if is_beam_token_worse_than_top_num_beams:
-                        continue
-                    beam_hyp.add(
-                        input_ids[batch_beam_idx].clone(),
-                        next_score.item(),
-                        mems=[mem[[next_index.item()]] for mem in mems] if mems else None
-                    )
-                else:
-                    # add next predicted token since it is not eos_token
-                    next_beam_scores[batch_idx, beam_idx] = next_score
-                    next_beam_tokens[batch_idx, beam_idx] = next_token
-                    next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
-                    beam_idx += 1
-
-                # once the beam for next step is full, don't add more tokens to it.
-                if beam_idx == self.num_beams:
-                    break
-
-            if beam_idx < self.num_beams:
-                raise ValueError(
-                    f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
-                )
-
-            # Check if we are done so that we can save a pad step if all(done)
-            self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
-                next_scores[batch_idx].max().item(), cur_len
-            )
-
-        return UserDict(
-            {
-                "next_beam_scores": next_beam_scores.view(-1),
-                "next_beam_tokens": next_beam_tokens.view(-1),
-                "next_beam_indices": next_beam_indices.view(-1),
-            }
-        )
-
-    def finalize(
-            self,
-            input_ids: torch.LongTensor,
-            final_beam_scores: torch.FloatTensor,
-            pad_token_id: Optional[int] = None,
-            eos_token_id: Optional[int] = None,
-            mems=None
-    ) -> Tuple[torch.LongTensor, List[torch.Tensor], torch.FloatTensor]:
-        batch_size = len(self._beam_hyps)
-        # finalize all open beam hypotheses and add to generated hypotheses
-        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
-            if self._done[batch_idx]:
-                continue
-
-            # need to add best num_beams hypotheses to generated hyps
-            for beam_id in range(self.num_beams):
-                batch_beam_idx = batch_idx * self.num_beams + beam_id
-                final_score = final_beam_scores[batch_beam_idx].item()
-                final_tokens = input_ids[batch_beam_idx]
-                beam_hyp.add(final_tokens, final_score, mems=[mem[[batch_beam_idx]] for mem in mems] if mems else None)
-
-        # select the best hypotheses
-        sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
-        best = []
-
-        # retrieve best hypotheses
-        for i, beam_hyp in enumerate(self._beam_hyps):
-            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
-            for j in range(self.num_beam_hyps_to_keep):
-                score, best_hyp, mems = sorted_hyps.pop()
-                sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
-                best.append((best_hyp, mems, score))
-
-        # prepare for adding eos
-        sent_max_len = min(sent_lengths.max().item(), self.max_length)
-        decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
-        scores = final_beam_scores.new(batch_size * self.num_beam_hyps_to_keep)
-        # shorter batches are padded if needed
-        if sent_lengths.min().item() != sent_lengths.max().item():
-            assert pad_token_id is not None, "`pad_token_id` has to be defined"
-            decoded.fill_(pad_token_id)
-
-        # fill with hypotheses and eos_token_id if the latter fits in
-        mems = []
-        for i, (hypo, mem, score) in enumerate(best):
-            scores[i] = score
-            decoded[i, : sent_lengths[i]] = hypo
-            if sent_lengths[i] < sent_max_len:
-                decoded[i, sent_lengths[i]] = eos_token_id
-            mems.append(mem)
-        mems = [torch.cat([mem[i] for mem in mems], dim=0) for i in range(len(mems[0]))] if mems and mems[0] else None
-        return decoded, mems, scores
-
-
-class BeamHypotheses:
-    def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool):
-        """
-        Initialize n-best list of hypotheses.
-        """
-        self.max_length = max_length - 1  # ignoring bos_token
-        self.length_penalty = length_penalty
-        self.early_stopping = early_stopping
-        self.num_beams = num_beams
-        self.beams = []
-        self.worst_score = 1e9
-
-    def __len__(self):
-        """
-        Number of hypotheses in the list.
-        """
-        return len(self.beams)
-
-    def add(self, hyp: torch.LongTensor, sum_logprobs: float, mems=None):
-        """
-        Add a new hypothesis to the list.
-        """
-        score = sum_logprobs / (max(hyp.shape[-1], 1) ** self.length_penalty)
-        if len(self) < self.num_beams or score > self.worst_score:
-            self.beams.append((score, hyp, mems))
-            if len(self) > self.num_beams:
-                sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
-                del self.beams[sorted_next_scores[0][1]]
-                self.worst_score = sorted_next_scores[1][0]
-            else:
-                self.worst_score = min(score, self.worst_score)
-
-    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
-        """
-        If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
-        one in the heap, then we are done with this sentence.
-        """
-
-        if len(self) < self.num_beams:
-            return False
-        elif self.early_stopping:
-            return True
-        else:
-            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
-            ret = self.worst_score >= cur_score
-            return ret
-
-
-class LogitsProcessor(ABC):
-    """Abstract base class for all logit processors that can be applied during generation."""
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        """Torch method for processing logits."""
-        raise NotImplementedError(
-            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
-        )
-
-
-class LogitsProcessorList(list):
-    """
-    This class can be used to create a list of :class:`~transformers.LogitsProcessor` or
-    :class:`~transformers.LogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits from
-    list and adds a specific `__call__` method to apply each :class:`~transformers.LogitsProcessor` or
-    :class:`~transformers.LogitsProcessor` to the inputs.
-    """
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        for processor in self:
-            scores = processor(input_ids, scores)
-        return scores
-
-
-class MinLengthLogitsProcessor(LogitsProcessor):
-    r"""
-    :class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0.
-
-    Args:
-        min_length (:obj:`int`):
-            The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`.
-        eos_token_id (:obj:`int`):
-            The id of the `end-of-sequence` token.
-    """
-
-    def __init__(self, min_length: int, eos_token_ids: Union[List[int], int]):
-        if not isinstance(min_length, int) or min_length < 0:
-            raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
-        if isinstance(eos_token_ids, int):
-            eos_token_ids = [eos_token_ids]
-        for eos_token_id in eos_token_ids:
-            if not isinstance(eos_token_id, int) or eos_token_id < 0:
-                raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
-
-        self.min_length = min_length
-        self.eos_token_ids = eos_token_ids
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        cur_len = input_ids.shape[-1]
-        if cur_len < self.min_length:
-            for eos_token_id in self.eos_token_ids:
-                scores[:, eos_token_id] = -float("inf")
-        return scores
-
-
-class NoRepeatNGramLogitsProcessor(LogitsProcessor):
-    r"""
-    :class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq
-    <https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345>`__.
-
-    Args:
-        ngram_size (:obj:`int`):
-            All ngrams of size :obj:`ngram_size` can only occur once.
-    """
-
-    def __init__(self, ngram_size: int):
-        if not isinstance(ngram_size, int) or ngram_size <= 0:
-            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
-        self.ngram_size = ngram_size
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        num_batch_hypotheses = scores.shape[0]
-        cur_len = input_ids.shape[-1]
-        banned_batch_tokens = self._calc_banned_ngram_tokens(input_ids, num_batch_hypotheses, cur_len)
-
-        for i, banned_tokens in enumerate(banned_batch_tokens):
-            scores[i, banned_tokens] = -float("inf")
-
-        return scores
-
-    def _calc_banned_ngram_tokens(
-            self, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
-    ) -> List[Iterable[int]]:
-        """Copied from fairseq for no_repeat_ngram in beam_search"""
-        if cur_len + 1 < self.ngram_size:
-            # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
-            return [[] for _ in range(num_hypos)]
-        generated_ngrams = [{} for _ in range(num_hypos)]
-        for idx in range(num_hypos):
-            gen_tokens = prev_input_ids[idx].tolist()
-            generated_ngram = generated_ngrams[idx]
-            for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]):
-                prev_ngram_tuple = tuple(ngram[:-1])
-                generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
-
-        def _get_generated_ngrams(hypo_idx):
-            # Before decoding the next token, prevent decoding of ngrams that have already appeared
-            start_idx = cur_len + 1 - self.ngram_size
-            ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
-            return generated_ngrams[hypo_idx].get(ngram_idx, [])
-
-        banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
-        return banned_tokens
-
 
 class BeamSearchStrategy:
-    def __init__(self, num_beams, max_length, length_penalty, end_tokens, device='cuda', no_repeat_ngram_size=0,
-                 min_tgt_length=0):
+    def __init__(self, num_beams, length_penalty=1., return_only_end=False,
+                end_tokens=[], invalid_slices=[], no_repeat_ngram_size=0, min_tgt_length=0):
         self.num_beams = num_beams
-        self.max_length = max_length
         self.length_penalty = length_penalty
         self.end_tokens = end_tokens
-        self.no_repeat_ngram_size = no_repeat_ngram_size
+        self.ngram = no_repeat_ngram_size
         self.min_tgt_length = min_tgt_length
-        self.processors = LogitsProcessorList()
-        if min_tgt_length > 0:
-            processor = MinLengthLogitsProcessor(min_tgt_length, self.end_tokens)
-            self.processors.append(processor)
-        if no_repeat_ngram_size > 0:
-            processor = NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)
-            self.processors.append(processor)
-        self.beam_scorer = BeamSearchScorer(
-            batch_size=1,
-            max_length=max_length,
-            num_beams=num_beams,
-            device=device,
-            length_penalty=length_penalty,
-            do_early_stopping=False,
-        )
-        self.beam_scores = torch.zeros(1, dtype=torch.float, device=device)
-
-    @property
-    def is_done(self) -> bool:
-        return self.beam_scorer.is_done
+        self.invalid_slices = invalid_slices
+        self.return_only_end = return_only_end
+        self._init_cache()
+
+    def _init_cache(self):
+        self.end_beams = [] # list of LongTensors
+        self.end_beams_penalized_scores = [] # list of floats (length-penalized scores)
+        self.cached_beam_scores = 0 # scalar 0 at first, later a [batch_size] tensor
+        self.cached_beam_ngram_bans = [{} for i in range(self.num_beams)]
+        self.is_done = False
+    
+    def _add_end_beams(self, score, beam):
+        score = score / ((5. + len(beam)) / 6) ** self.length_penalty # length normalization from GNMT / OpenNMT
+        for i in range(len(self.end_beams), -1, -1): # keep end_beams sorted by descending penalized score
+            if i == 0 or score < self.end_beams_penalized_scores[i-1]:
+                break
+        self.end_beams.insert(i, beam)
+        self.end_beams_penalized_scores.insert(i, score)
+
+        self.end_beams = self.end_beams[:self.num_beams]
+        self.end_beams_penalized_scores = self.end_beams_penalized_scores[:self.num_beams]
 
     def forward(self, logits, tokens, mems):
-        last_beam_num = tokens.size(0)
-        logits = self.processors(tokens, logits)
-        next_token_scores = F.log_softmax(logits, dim=-1)
-        next_token_scores = next_token_scores + self.beam_scores[:, None].expand_as(next_token_scores)
-        vocab_size = next_token_scores.shape[-1]
-        next_token_scores = next_token_scores.view(1, last_beam_num * vocab_size)
+        batch_size, vocab_size = logits.shape
+        seq_len = tokens.shape[-1]
+        logits = logits.float()
+        for invalid_slice in self.invalid_slices:
+            logits[..., invalid_slice] = -65504
+        if self.min_tgt_length > seq_len:
+            for end_token in self.end_tokens:
+                logits[..., end_token] = -65504
+        if self.ngram > 0 and seq_len > self.ngram:
+            for i in range(batch_size):
+                ngram_prefix = tokens[i, -(self.ngram-1):].tolist() # TODO ngram=1
+                for banned_index in self.cached_beam_ngram_bans[i].get(tuple(ngram_prefix), []):
+                    logits[i, banned_index] = -65504
+        
+        next_token_scores = F.log_softmax(logits, dim=-1) # [batch_size, vocab_size]
+        prev_scores = self.cached_beam_scores
+        if isinstance(prev_scores, torch.Tensor): # plain 0 on the first step
+            prev_scores = prev_scores[:, None].expand_as(next_token_scores)
+        next_token_scores = next_token_scores + prev_scores
+        
+        next_token_scores = next_token_scores.view(batch_size * vocab_size)
+
+        probs = F.softmax(next_token_scores, dim=0) # sample candidates over the flattened (beam, vocab) scores
+        next_tokens = torch.multinomial(probs, num_samples=2 * self.num_beams) # [2*nb]
+        next_token_scores = next_token_scores[next_tokens]
+        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=0)
+        next_tokens = next_tokens[_indices]
+
+        next_indices = next_tokens // vocab_size # batch idx
+        next_tokens = next_tokens % vocab_size
 
-        probs = F.softmax(next_token_scores, dim=-1)
-        next_tokens = torch.multinomial(probs, num_samples=2 * self.num_beams)
-        next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
-        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
-        next_tokens = torch.gather(next_tokens, -1, _indices)
+        # select out end beams or continue beams
+        beam_continue = []
+        scores_continue = []
+        bans_continue = []
+        mems_continue = []
+        for i in range(len(next_tokens)):
+            beam = torch.cat((tokens[next_indices[i]], next_tokens[i:i+1]))
+            if int(next_tokens[i]) in self.end_tokens:
+                self._add_end_beams(next_token_scores[i], beam)
+            elif len(beam_continue) < batch_size:
+                beam_continue.append(beam)
+                mems_continue.append(mems[:, next_indices[i]])
+                # update caches
+                scores_continue.append(next_token_scores[i])
+                if self.ngram > 0:
+                    bans = self.cached_beam_ngram_bans[next_indices[i]].copy()
+                    ngram_prefix = tuple(tokens[next_indices[i], -(self.ngram-1):].tolist()) # TODO ngram=1
+                    bans[ngram_prefix] = bans.get(ngram_prefix, tuple()) + (next_tokens[i],)
+                    bans_continue.append(bans)
+            else:
+                break
+        tokens = torch.stack(beam_continue)
+        mems = torch.stack(mems_continue, dim=1)
+        self.cached_beam_scores = torch.tensor(scores_continue, device=logits.device)
+        self.cached_beam_ngram_bans = bans_continue
 
-        next_indices = next_tokens // vocab_size
-        next_tokens = next_tokens % vocab_size
-        # stateless
-        tokens = tokens.expand((self.num_beams, -1))
-        beam_outputs = self.beam_scorer.process(
-            tokens,
-            next_token_scores,
-            next_tokens,
-            next_indices,
-            eos_token_id=self.end_tokens,
-            mems=mems
-        )
-        self.beam_scores = beam_outputs["next_beam_scores"]
-        beam_next_tokens = beam_outputs["next_beam_tokens"]
-        beam_idx = beam_outputs["next_beam_indices"]
-        beam_next_tokens = beam_next_tokens.unsqueeze(-1)
-        tokens = torch.cat([tokens[beam_idx, :], beam_next_tokens], dim=-1)
-        mems = [mem[beam_idx] for mem in mems] if mems else None
+        # TODO is_done
         return tokens, mems
 
-    def finalize(self, tokens, mems):
-        tokens, mems, scores = self.beam_scorer.finalize(tokens, self.beam_scores,
-                                                         eos_token_id=self.end_tokens[0],
-                                                         mems=mems)
-        return tokens, mems
+    def finalize(self, tokens, mems):
+        if not self.return_only_end:
+            for i in range(tokens.shape[0]):
+                self._add_end_beams(self.cached_beam_scores[i], tokens[i])
+        ret = self.end_beams
+        self._init_cache()
+        return ret, mems
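
For reference, the normalization in `_add_end_beams` above is the GNMT-style length penalty also used by OpenNMT: a finished beam with summed log-probability s and length L is ranked by s / ((5 + L) / 6)^alpha. A standalone illustration (the names below are ours, not from the patch):

```python
def penalized_score(sum_logprob: float, length: int, alpha: float = 1.0) -> float:
    # Same formula as _add_end_beams: divide by ((5 + L) / 6) ** alpha.
    return sum_logprob / ((5.0 + length) / 6.0) ** alpha

# Larger alpha favors longer finished hypotheses relative to raw log-probability.
print(penalized_score(-12.0, 10, alpha=0.0))   # -12.0: no normalization
print(penalized_score(-12.0, 10, alpha=1.0))   # -4.8: divided by 2.5
print(penalized_score(-24.0, 30, alpha=1.0))   # about -4.11: a longer beam can win
```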
diff --git a/generation/sampling_strategies/beam_search_strategy_old.py b/generation/sampling_strategies/beam_search_strategy_old.py
new file mode 100644
index 0000000000000000000000000000000000000000..aeab7989930250aa93b3231399152cd514b4162b
--- /dev/null
+++ b/generation/sampling_strategies/beam_search_strategy_old.py
@@ -0,0 +1,467 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   beam_search_strategy_old.py
+@Time    :   2021/10/08 22:22:42
+@Author  :   Ming Ding
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import torch
+import torch.nn.functional as F
+from abc import ABC, abstractmethod
+from collections import UserDict
+from typing import Optional, Tuple, List, Iterable, Union
+
+
+class BeamScorer(ABC):
+    """
+    Abstract base class for all beam scorers that are used for :meth:`~transformers.PretrainedModel.beam_search` and
+    :meth:`~transformers.PretrainedModel.beam_sample`.
+    """
+
+    @abstractmethod
+    def process(
+            self,
+            input_ids: torch.LongTensor,
+            next_scores: torch.FloatTensor,
+            next_tokens: torch.LongTensor,
+            next_indices: torch.LongTensor,
+            **kwargs
+    ) -> Tuple[torch.Tensor]:
+        raise NotImplementedError("This is an abstract method.")
+
+    @abstractmethod
+    def finalize(
+            self,
+            input_ids: torch.LongTensor,
+            next_scores: torch.FloatTensor,
+            next_tokens: torch.LongTensor,
+            next_indices: torch.LongTensor,
+            **kwargs
+    ) -> torch.LongTensor:
+        raise NotImplementedError("This is an abstract method.")
+
+
+class BeamSearchScorer(BeamScorer):
+    r"""
+    :class:`transformers.BeamScorer` implementing standard beam search decoding.
+
+    Adapted in part from `Facebook's XLM beam search code
+    <https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__.
+
+    Args:
+        batch_size (:obj:`int`):
+            Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel.
+        max_length (:obj:`int`):
+            The maximum length of the sequence to be generated.
+        num_beams (:obj:`int`):
+            Number of beams for beam search.
+        device (:obj:`torch.device`):
+            Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of
+            :obj:`BeamSearchScorer` will be allocated.
+        length_penalty (:obj:`float`, `optional`, defaults to 1.0):
+            Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
+            model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
+            sequences.
+        do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
+        num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1):
+            The number of beam hypotheses that shall be returned upon calling
+            :meth:`~transformer.BeamSearchScorer.finalize`.
+    """
+
+    def __init__(
+            self,
+            batch_size: int,
+            max_length: int,
+            num_beams: int,
+            device: Union[torch.device, str],
+            length_penalty: Optional[float] = 1.0,
+            do_early_stopping: Optional[bool] = False,
+            num_beam_hyps_to_keep: Optional[int] = 1,
+    ):
+        self.max_length = max_length
+        self.num_beams = num_beams
+        self.device = device
+        self.length_penalty = length_penalty
+        self.do_early_stopping = do_early_stopping
+        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
+
+        self._is_init = False
+        self._beam_hyps = [
+            BeamHypotheses(
+                num_beams=self.num_beams,
+                max_length=self.max_length,
+                length_penalty=self.length_penalty,
+                early_stopping=self.do_early_stopping,
+            )
+            for _ in range(batch_size)
+        ]
+        self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
+
+        # if not isinstance(num_beams, int) or num_beams <= 1:
+        #     raise ValueError(
+        #         f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead."
+        #     )
+
+    @property
+    def is_done(self) -> bool:
+        return self._done.all()
+
+    def process(
+            self,
+            input_ids: torch.LongTensor,
+            next_scores: torch.FloatTensor,
+            next_tokens: torch.LongTensor,
+            next_indices: torch.LongTensor,
+            pad_token_id: Optional[int] = None,
+            eos_token_id: Optional[int] = None,
+            mems=None
+    ) -> Tuple[torch.Tensor]:
+        cur_len = input_ids.shape[-1]
+        batch_size = len(self._beam_hyps)
+        assert batch_size == (input_ids.shape[0] // self.num_beams)
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        device = next_scores.device
+        next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device)
+        next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device)
+        next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device)
+
+        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
+            if self._done[batch_idx]:
+                assert (
+                        len(beam_hyp) >= self.num_beams
+                ), "Batch can only be done if at least {} beams have been generated".format(self.num_beams)
+                assert (
+                        eos_token_id is not None and pad_token_id is not None
+                ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
+                # pad the batch
+                next_beam_scores[batch_idx, :] = 0
+                next_beam_tokens[batch_idx, :] = pad_token_id
+                next_beam_indices[batch_idx, :] = 0
+                continue
+
+            # next tokens for this sentence
+            beam_idx = 0
+            for beam_token_rank, (next_token, next_score, next_index) in enumerate(
+                    zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
+            ):
+                batch_beam_idx = batch_idx * self.num_beams + next_index
+                # add to generated hypotheses if end of sentence
+                if (eos_token_id is not None) and (next_token.item() in eos_token_id):
+                    # if beam_token does not belong to top num_beams tokens, it should not be added
+                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams
+                    if is_beam_token_worse_than_top_num_beams:
+                        continue
+                    beam_hyp.add(
+                        input_ids[batch_beam_idx].clone(),
+                        next_score.item(),
+                        mems=[mem[[next_index.item()]] for mem in mems] if mems else None
+                    )
+                else:
+                    # add next predicted token since it is not eos_token
+                    next_beam_scores[batch_idx, beam_idx] = next_score
+                    next_beam_tokens[batch_idx, beam_idx] = next_token
+                    next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
+                    beam_idx += 1
+
+                # once the beam for next step is full, don't add more tokens to it.
+                if beam_idx == self.num_beams:
+                    break
+
+            if beam_idx < self.num_beams:
+                raise ValueError(
+                    f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
+                )
+
+            # Check if we are done so that we can save a pad step if all(done)
+            self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
+                next_scores[batch_idx].max().item(), cur_len
+            )
+
+        return UserDict(
+            {
+                "next_beam_scores": next_beam_scores.view(-1),
+                "next_beam_tokens": next_beam_tokens.view(-1),
+                "next_beam_indices": next_beam_indices.view(-1),
+            }
+        )
+
+    def finalize(
+            self,
+            input_ids: torch.LongTensor,
+            final_beam_scores: torch.FloatTensor,
+            pad_token_id: Optional[int] = None,
+            eos_token_id: Optional[int] = None,
+            mems=None
+    ) -> Tuple[torch.LongTensor, List[torch.Tensor], torch.FloatTensor]:
+        batch_size = len(self._beam_hyps)
+        # finalize all open beam hypotheses and add to generated hypotheses
+        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
+            if self._done[batch_idx]:
+                continue
+
+            # need to add best num_beams hypotheses to generated hyps
+            for beam_id in range(self.num_beams):
+                batch_beam_idx = batch_idx * self.num_beams + beam_id
+                final_score = final_beam_scores[batch_beam_idx].item()
+                final_tokens = input_ids[batch_beam_idx]
+                beam_hyp.add(final_tokens, final_score, mems=[mem[[batch_beam_idx]] for mem in mems] if mems else None)
+
+        # select the best hypotheses
+        sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
+        best = []
+
+        # retrieve best hypotheses
+        for i, beam_hyp in enumerate(self._beam_hyps):
+            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
+            for j in range(self.num_beam_hyps_to_keep):
+                score, best_hyp, mems = sorted_hyps.pop()
+                sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
+                best.append((best_hyp, mems, score))
+
+        # prepare for adding eos
+        sent_max_len = min(sent_lengths.max().item(), self.max_length)
+        decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
+        scores = final_beam_scores.new(batch_size * self.num_beam_hyps_to_keep)
+        # shorter batches are padded if needed
+        if sent_lengths.min().item() != sent_lengths.max().item():
+            assert pad_token_id is not None, "`pad_token_id` has to be defined"
+            decoded.fill_(pad_token_id)
+
+        # fill with hypotheses and eos_token_id if the latter fits in
+        mems = []
+        for i, (hypo, mem, score) in enumerate(best):
+            scores[i] = score
+            decoded[i, : sent_lengths[i]] = hypo
+            if sent_lengths[i] < sent_max_len:
+                decoded[i, sent_lengths[i]] = eos_token_id
+            mems.append(mem)
+        mems = [torch.cat([mem[i] for mem in mems], dim=0) for i in range(len(mems[0]))] if mems and mems[0] else None
+        return decoded, mems, scores
+
+
+class BeamHypotheses:
+    def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool):
+        """
+        Initialize n-best list of hypotheses.
+        """
+        self.max_length = max_length - 1  # ignoring bos_token
+        self.length_penalty = length_penalty
+        self.early_stopping = early_stopping
+        self.num_beams = num_beams
+        self.beams = []
+        self.worst_score = 1e9
+
+    def __len__(self):
+        """
+        Number of hypotheses in the list.
+        """
+        return len(self.beams)
+
+    def add(self, hyp: torch.LongTensor, sum_logprobs: float, mems=None):
+        """
+        Add a new hypothesis to the list.
+        """
+        score = sum_logprobs / (max(hyp.shape[-1], 1) ** self.length_penalty)
+        if len(self) < self.num_beams or score > self.worst_score:
+            self.beams.append((score, hyp, mems))
+            if len(self) > self.num_beams:
+                sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
+                del self.beams[sorted_next_scores[0][1]]
+                self.worst_score = sorted_next_scores[1][0]
+            else:
+                self.worst_score = min(score, self.worst_score)
+
+    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
+        """
+        If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
+        one in the heap, then we are done with this sentence.
+        """
+
+        if len(self) < self.num_beams:
+            return False
+        elif self.early_stopping:
+            return True
+        else:
+            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
+            ret = self.worst_score >= cur_score
+            return ret
+
+
+class LogitsProcessor(ABC):
+    """Abstract base class for all logit processors that can be applied during generation."""
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        """Torch method for processing logits."""
+        raise NotImplementedError(
+            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
+        )
+
+
+class LogitsProcessorList(list):
+    """
+    This class can be used to create a list of :class:`~transformers.LogitsProcessor` or
+    :class:`~transformers.LogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits from
+    list and adds a specific `__call__` method to apply each :class:`~transformers.LogitsProcessor` or
+    :class:`~transformers.LogitsProcessor` to the inputs.
+    """
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        for processor in self:
+            scores = processor(input_ids, scores)
+        return scores
+
+
+class MinLengthLogitsProcessor(LogitsProcessor):
+    r"""
+    :class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0.
+
+    Args:
+        min_length (:obj:`int`):
+            The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`.
+        eos_token_id (:obj:`int`):
+            The id of the `end-of-sequence` token.
+    """
+
+    def __init__(self, min_length: int, eos_token_ids: Union[List[int], int]):
+        if not isinstance(min_length, int) or min_length < 0:
+            raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
+        if isinstance(eos_token_ids, int):
+            eos_token_ids = [eos_token_ids]
+        for eos_token_id in eos_token_ids:
+            if not isinstance(eos_token_id, int) or eos_token_id < 0:
+                raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
+
+        self.min_length = min_length
+        self.eos_token_ids = eos_token_ids
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        cur_len = input_ids.shape[-1]
+        if cur_len < self.min_length:
+            for eos_token_id in self.eos_token_ids:
+                scores[:, eos_token_id] = -float("inf")
+        return scores
+
+
+class NoRepeatNGramLogitsProcessor(LogitsProcessor):
+    r"""
+    :class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq
+    <https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345>`__.
+
+    Args:
+        ngram_size (:obj:`int`):
+            All ngrams of size :obj:`ngram_size` can only occur once.
+    """
+
+    def __init__(self, ngram_size: int):
+        if not isinstance(ngram_size, int) or ngram_size <= 0:
+            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
+        self.ngram_size = ngram_size
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        num_batch_hypotheses = scores.shape[0]
+        cur_len = input_ids.shape[-1]
+        banned_batch_tokens = self._calc_banned_ngram_tokens(input_ids, num_batch_hypotheses, cur_len)
+
+        for i, banned_tokens in enumerate(banned_batch_tokens):
+            scores[i, banned_tokens] = -float("inf")
+
+        return scores
+
+    def _calc_banned_ngram_tokens(
+            self, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
+    ) -> List[Iterable[int]]:
+        """Copied from fairseq for no_repeat_ngram in beam_search"""
+        if cur_len + 1 < self.ngram_size:
+            # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
+            return [[] for _ in range(num_hypos)]
+        generated_ngrams = [{} for _ in range(num_hypos)]
+        for idx in range(num_hypos):
+            gen_tokens = prev_input_ids[idx].tolist()
+            generated_ngram = generated_ngrams[idx]
+            for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]):
+                prev_ngram_tuple = tuple(ngram[:-1])
+                generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
+
+        def _get_generated_ngrams(hypo_idx):
+            # Before decoding the next token, prevent decoding of ngrams that have already appeared
+            start_idx = cur_len + 1 - self.ngram_size
+            ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
+            return generated_ngrams[hypo_idx].get(ngram_idx, [])
+
+        banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
+        return banned_tokens
+
+
+class BeamSearchStrategy:
+    def __init__(self, num_beams, max_length, length_penalty, end_tokens, device='cuda', no_repeat_ngram_size=0,
+                 min_tgt_length=0):
+        self.num_beams = num_beams
+        self.max_length = max_length
+        self.length_penalty = length_penalty
+        self.end_tokens = end_tokens
+        self.no_repeat_ngram_size = no_repeat_ngram_size
+        self.min_tgt_length = min_tgt_length
+        self.processors = LogitsProcessorList()
+        if min_tgt_length > 0:
+            processor = MinLengthLogitsProcessor(min_tgt_length, self.end_tokens)
+            self.processors.append(processor)
+        if no_repeat_ngram_size > 0:
+            processor = NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)
+            self.processors.append(processor)
+        self.beam_scorer = BeamSearchScorer(
+            batch_size=1,
+            max_length=max_length,
+            num_beams=num_beams,
+            device=device,
+            length_penalty=length_penalty,
+            do_early_stopping=False,
+        )
+        self.beam_scores = torch.zeros(1, dtype=torch.float, device=device)
+
+    @property
+    def is_done(self) -> bool:
+        return self.beam_scorer.is_done
+
+    def forward(self, logits, tokens, mems):
+        last_beam_num = tokens.size(0)
+        logits = self.processors(tokens, logits.float())
+        next_token_scores = F.log_softmax(logits, dim=-1)
+        next_token_scores = next_token_scores + self.beam_scores[:, None].expand_as(next_token_scores)
+        vocab_size = next_token_scores.shape[-1]
+        next_token_scores = next_token_scores.view(1, last_beam_num * vocab_size)
+
+        probs = F.softmax(next_token_scores, dim=-1)
+        next_tokens = torch.multinomial(probs, num_samples=2 * self.num_beams)
+        next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
+        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
+        next_tokens = torch.gather(next_tokens, -1, _indices)
+
+        next_indices = next_tokens // vocab_size
+        next_tokens = next_tokens % vocab_size
+        # stateless
+        tokens = tokens.expand((self.num_beams, -1))
+        beam_outputs = self.beam_scorer.process(
+            tokens,
+            next_token_scores,
+            next_tokens,
+            next_indices,
+            eos_token_id=self.end_tokens,
+            mems=mems
+        )
+        self.beam_scores = beam_outputs["next_beam_scores"]
+        beam_next_tokens = beam_outputs["next_beam_tokens"]
+        beam_idx = beam_outputs["next_beam_indices"]
+        beam_next_tokens = beam_next_tokens.unsqueeze(-1)
+        tokens = torch.cat([tokens[beam_idx, :], beam_next_tokens], dim=-1)
+        mems = [mem[beam_idx] for mem in mems] if mems else None
+        return tokens, mems
+
+    def finalize(self, tokens, mems):
+        tokens, mems, scores = self.beam_scorer.finalize(tokens, self.beam_scores,
+                                                         eos_token_id=self.end_tokens[0],
+                                                         mems=mems)
+        return tokens, mems
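
Call-site migration (a sketch; both constructor signatures come from this patch, the argument values and token ids are illustrative): the strategy preserved in `beam_search_strategy_old.py` needs `max_length` and a `device`, while the rewritten `BeamSearchStrategy` drops both and instead collects finished beams itself until `finalize` is called.

```python
end_tokens = [eop_id, eos_id]  # assumed token ids

# old API (beam_search_strategy_old.py)
old_strategy = BeamSearchStrategy(
    num_beams=4, max_length=256, length_penalty=1.0, end_tokens=end_tokens,
    device='cuda', no_repeat_ngram_size=3, min_tgt_length=5)

# new API (beam_search_strategy.py): no max_length/device; finished beams are
# kept internally and returned by finalize()
new_strategy = BeamSearchStrategy(
    num_beams=4, length_penalty=1.0, end_tokens=end_tokens,
    no_repeat_ngram_size=3, min_tgt_length=5)
```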
diff --git a/inference_glm.py b/inference_glm.py
index 792efd5e030fa1564a7f52310b261844eb113c03..0e5e44d4a58a54a373140028994d1e5856d94bb9 100644
--- a/inference_glm.py
+++ b/inference_glm.py
@@ -1,12 +1,13 @@
 # -*- encoding: utf-8 -*-
 '''
-@File    :   inference_cogview.py
-@Time    :   2021/10/09 19:41:58
+@File    :   inference_glm.py
+@Time    :   2021/10/22 19:41:58
 @Author  :   Ming Ding
 @Contact :   dm18@mail.tsinghua.edu.cn
 '''
 
 # here put the import lib
+from functools import partial
 import os
 import sys
 import random
@@ -14,163 +15,138 @@ import time
 from datetime import datetime
 import torch
 import torch.nn.functional as F
-
+import argparse
+import stat
 import mpu
 from arguments import get_args
 from model.glm_model import GLMModel
+from model.cached_autoregressive_model import CachedAutoregressiveMixin
 from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
-from generation.glm_sampling import filling_sequence_glm
+from generation.autoregressive_sampling import filling_sequence
 from generation.sampling_strategies import BeamSearchStrategy, BaseStrategy
+from generation.utils import timed_name, generate_continually
 
+def get_masks_and_position_ids_glm(seq, mask_position, context_length):
+    tokens = seq.unsqueeze(0)
 
-def read_context(tokenizer, args, output=None):
-    terminate_runs, skip_run = 0, 0
-    if mpu.get_model_parallel_rank() == 0:
-        while True:
-            raw_text = input("\nContext prompt (stop to exit) >>> ")
-            if not raw_text:
-                print('Prompt should not be empty!')
-                continue
-            if raw_text == "stop":
-                terminate_runs = 1
-                break
-            generation_mask = '[gMASK]' if args.task_mask else '[MASK]'
-            if args.block_lm and 'MASK]' not in raw_text:
-                raw_text += ' ' + generation_mask
-            if output is not None:
-                output.write(raw_text)
-            context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
-            if args.block_lm:
-                context_tokens = [tokenizer.get_command('ENC').Id] + context_tokens
-                if not raw_text.endswith('MASK]'):
-                    context_tokens = context_tokens + [tokenizer.get_command('eos').Id]
-            context_length = len(context_tokens)
-
-            if context_length >= args.max_sequence_length:
-                print("\nContext length", context_length,
-                      "\nPlease give smaller context than the window length!")
-                continue
-            break
-    else:
-        context_length = 0
-
-    terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
-    torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
-    terminate_runs = terminate_runs_tensor[0].item()
-
-    if terminate_runs == 1:
-        return terminate_runs, None, None, None
-
-    context_length_tensor = torch.cuda.LongTensor([context_length])
-
-    torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
-    context_length = context_length_tensor[0].item()
-    if mpu.get_model_parallel_rank() == 0:
-        context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
-    else:
-        context_tokens_tensor = torch.cuda.LongTensor([0] * context_length)
-    torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
-    if mpu.get_model_parallel_rank() != 0:
-        raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist())
-    return terminate_runs, raw_text, context_tokens_tensor, context_length
+    attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
+    attention_mask.tril_()
+    attention_mask.unsqueeze_(1)
 
+    position_ids = torch.zeros(2, len(seq), device=tokens.device, dtype=torch.long)
+    torch.arange(0, context_length, out=position_ids[0, :context_length])
+    position_ids[0, context_length:] = mask_position
+    torch.arange(1, len(seq) - context_length + 1, out=position_ids[1, context_length:])
 
-def get_batch(context_tokens, args):
-    tokens = context_tokens
-    tokens = tokens.view(1, -1).contiguous()
-    tokens = tokens.to('cuda')
-
-    # Get the masks and postition ids.
-    if args.block_lm:
-        attention_mask = torch.ones(tokens.size(1), tokens.size(1), device='cuda', dtype=torch.long)
-        if args.fp16:
-            attention_mask = attention_mask.half()
-        position_ids = torch.arange(tokens.size(1), device='cuda', dtype=torch.long)
-        if not args.no_block_position:
-            block_position_ids = torch.zeros(tokens.size(1), device='cuda', dtype=torch.long)
-            position_ids = torch.stack((position_ids, block_position_ids), dim=0)
-        position_ids = position_ids.unsqueeze(0)
-    else:
-        raise NotImplementedError
-
+    position_ids = position_ids.unsqueeze(0)
     return tokens, attention_mask, position_ids
 
 
-def generate_samples(model, tokenizer, args):
-    model.eval()
-    output_path = "./samples"
-    if not os.path.exists(output_path):
-        os.makedirs(output_path)
-    output_path = os.path.join(output_path, f"sample-{datetime.now().strftime('%m-%d-%H-%M')}.txt")
-    with torch.no_grad(), open(output_path, "w") as output:
-        while True:
-            torch.distributed.barrier(group=mpu.get_model_parallel_group())
-            terminate_runs, raw_text, context_tokens_tensor, context_length = read_context(tokenizer, args, output)
-            if terminate_runs == 1:
-                return
-            start_time = time.time()
-            if args.block_lm:
-                mems = []
-                tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args)
-                mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
-                mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
-                end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id]
-                mask_positions = []
-                for token in mask_tokens:
-                    mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist()
-                mask_positions.sort()
-                if args.no_block_position:
-                    for mask_position in mask_positions:
-                        position_ids[0, mask_position + 1:] += args.out_seq_length
-                _, *mems = model(tokens, position_ids, attention_mask, *mems)
-                for mask_position in mask_positions:
-                    if args.no_block_position:
-                        position = position_ids[0, mask_position].item()
-                    else:
-                        position = mask_position
-                    if args.num_beams > 1:
-                        strategy = BeamSearchStrategy(num_beams=args.num_beams, max_length=args.out_seq_length,
-                                                      length_penalty=args.length_penalty, end_tokens=end_tokens,
-                                                      no_repeat_ngram_size=args.no_repeat_ngram_size,
-                                                      min_tgt_length=args.min_tgt_length)
-                    else:
-                        strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
-                                                end_tokens=end_tokens)
-                    new_tokens, mems = filling_sequence_glm(model, tokenizer, position, strategy, args, mems=mems,
-                                                            end_tokens=end_tokens)
-                    tokens = torch.cat((tokens, new_tokens), dim=1)
-            output_tokens_list = tokens.view(-1).contiguous()
-            if mpu.get_model_parallel_rank() == 0:
-                os.system('clear')
-                print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
-                print("\nContext:", raw_text, flush=True)
-                decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
-                trim_decode_tokens = decode_tokens
-                print("\nGLM:", trim_decode_tokens, flush=True)
-                output.write(trim_decode_tokens + "\n")
-
-            torch.distributed.barrier(group=mpu.get_model_parallel_group())
-
-
 def main(args):
     initialize_distributed(args)
     tokenizer = prepare_tokenizer(args)
-    # build model
+    # build model 
     model = GLMModel(args)
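+    # the mixin supplies a cached attention_forward hook, so filling_sequence can
+    # reuse key/value mems instead of re-encoding the whole context at every step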
+    model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
     if args.fp16:
         model = model.half()
     model = model.to(args.device)
     load_checkpoint(model, args)
     set_random_seed(args.seed)
     model.eval()
-    generate_samples(model, tokenizer, args)
 
+    end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id]
+    # sampling strategy and the generation routine applied to each query
+    strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, end_tokens=end_tokens)
+    
+    def process(raw_text):
+        if args.with_id:
+            query_id, raw_text = raw_text.split('\t')
+        # add MASK
+        generation_mask = '[gMASK]' if args.task_mask else '[MASK]'
+        if 'MASK]' not in raw_text:
+            raw_text += ' ' + generation_mask
+        seq = tokenizer.EncodeAsIds(raw_text).tokenization
+        seq = [tokenizer.get_command('ENC').Id] + seq
+        if not raw_text.endswith('MASK]'):
+            seq = seq + [tokenizer.get_command('eos').Id]
+        print('raw text: ', raw_text)
+        if len(seq) > args.max_sequence_length:
+            raise ValueError(f'Input length {len(seq)} exceeds --max-sequence-length ({args.max_sequence_length}).')
+        
+        # find mask tokens positions
+        mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
+        mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
+        mask_positions = []
+        context_tokens_tensor = torch.tensor(seq, dtype=torch.long, device=args.device)
+        for token in mask_tokens:
+            mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist()
+        mask_positions.sort()
+        
+        # generation
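+        # the requested batch is generated in chunks of at most max_inference_batch_size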
+        mbz = args.max_inference_batch_size
+        assert args.batch_size < mbz or args.batch_size % mbz == 0
+        output_list = []
+        # call for each position
+        for mp_idx, mask_position in enumerate(mask_positions):
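+            # bind the current mask position so filling_sequence rebuilds the GLM-style
+            # attention mask and 2D position ids for this blank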
+            get_func = partial(get_masks_and_position_ids_glm, mask_position=mask_position, context_length=len(seq))
+            for tim in range(max(args.batch_size // mbz, 1)):
+                input_seq = torch.cuda.LongTensor(seq + [tokenizer.get_command('sop').Id] + [-1] * (args.out_seq_length-len(seq)-1), device=args.device)
+                output, _mems = filling_sequence(model, input_seq,
+                        batch_size=min(args.batch_size, mbz),
+                        strategy=strategy,
+                        log_attention_weights=None,
+                        get_masks_and_position_ids=get_func
+                        ) # mems are not reused here; generated tokens are filled back into seq below
+                if isinstance(output, torch.Tensor): # different strategies
+                    output = list(output)
+                
+                output_list.extend(output)
+
+            # strip trailing -1 padding and splice the generated span back in place of the mask
+            for i in range(len(output_list)):
+                output = output_list[i].tolist()
+                try:
+                    unfinished = output.index(-1)
+                except ValueError:
+                    unfinished = len(output)
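+                # 'sop' marks where the generated span starts inside the output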
+                bog = output.index(tokenizer.get_command('sop').Id)
+                output_list[i] = output[:mask_position] + output[bog+1:unfinished] + output[mask_position+1:bog]
+            
+            # prepare the next auto-regressive generation
+            if mp_idx < len(mask_positions) - 1: 
+                # TODO: select the best candidate for this blank (e.g. via inverse prompting)
+                seq = output_list[0]
+                output_list = []
+
+        # decoding
+        txts = []
+        for seq in output_list:
+            decode_tokens = tokenizer.DecodeIds(seq)
+            txts.append(decode_tokens)
+
+        # save
+        if args.with_id:
+            full_path = os.path.join(args.output_path, query_id + '.txt')
+        else:
+            prefix = raw_text.replace('/', '')[:20]
+            full_path = timed_name(prefix, '.txt', args.output_path)
+            print(txts[0]) # print the first.
+        with open(full_path, 'w') as fout:
+            for txt in txts:
+                fout.write(txt + '\n')
+        os.chmod(full_path, stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU)
+
+    os.makedirs(args.output_path, exist_ok=True)
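+    # run process() on every query read from args.input_source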
+    generate_continually(process, args.input_source)
 
 if __name__ == "__main__":
-    args = get_args()
+    py_parser = argparse.ArgumentParser(add_help=False)
 
+    known, args_list = py_parser.parse_known_args()
+    args = get_args(args_list)
+    args = argparse.Namespace(**vars(args), **vars(known))
+    
     with torch.no_grad():
-        main(args)
+        main(args)
\ No newline at end of file
diff --git a/inference_glm_old.py b/inference_glm_old.py
new file mode 100644
index 0000000000000000000000000000000000000000..792efd5e030fa1564a7f52310b261844eb113c03
--- /dev/null
+++ b/inference_glm_old.py
@@ -0,0 +1,176 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   inference_glm_old.py
+@Time    :   2021/10/09 19:41:58
+@Author  :   Ming Ding
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import random
+import time
+from datetime import datetime
+import torch
+import torch.nn.functional as F
+
+import mpu
+from arguments import get_args
+from model.glm_model import GLMModel
+from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
+from generation.glm_sampling import filling_sequence_glm
+from generation.sampling_strategies import BeamSearchStrategy, BaseStrategy
+
+
+def read_context(tokenizer, args, output=None):
+    terminate_runs, skip_run = 0, 0
+    if mpu.get_model_parallel_rank() == 0:
+        while True:
+            raw_text = input("\nContext prompt (stop to exit) >>> ")
+            if not raw_text:
+                print('Prompt should not be empty!')
+                continue
+            if raw_text == "stop":
+                terminate_runs = 1
+                break
+            generation_mask = '[gMASK]' if args.task_mask else '[MASK]'
+            if args.block_lm and 'MASK]' not in raw_text:
+                raw_text += ' ' + generation_mask
+            if output is not None:
+                output.write(raw_text)
+            context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
+            if args.block_lm:
+                context_tokens = [tokenizer.get_command('ENC').Id] + context_tokens
+                if not raw_text.endswith('MASK]'):
+                    context_tokens = context_tokens + [tokenizer.get_command('eos').Id]
+            context_length = len(context_tokens)
+
+            if context_length >= args.max_sequence_length:
+                print("\nContext length", context_length,
+                      "\nPlease give a smaller context than the window length!")
+                continue
+            break
+    else:
+        context_length = 0
+
+    terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
+    torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(),
+                                group=mpu.get_model_parallel_group())
+    terminate_runs = terminate_runs_tensor[0].item()
+
+    if terminate_runs == 1:
+        return terminate_runs, None, None, None
+
+    context_length_tensor = torch.cuda.LongTensor([context_length])
+
+    torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(),
+                                group=mpu.get_model_parallel_group())
+    context_length = context_length_tensor[0].item()
+    if mpu.get_model_parallel_rank() == 0:
+        context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
+    else:
+        context_tokens_tensor = torch.cuda.LongTensor([0] * context_length)
+    torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(),
+                                group=mpu.get_model_parallel_group())
+    if mpu.get_model_parallel_rank() != 0:
+        raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist())
+    return terminate_runs, raw_text, context_tokens_tensor, context_length
+
+
+def get_batch(context_tokens, args):
+    tokens = context_tokens
+    tokens = tokens.view(1, -1).contiguous()
+    tokens = tokens.to('cuda')
+
+    # Get the masks and position ids.
+    if args.block_lm:
+        attention_mask = torch.ones(tokens.size(1), tokens.size(1), device='cuda', dtype=torch.long)
+        if args.fp16:
+            attention_mask = attention_mask.half()
+        position_ids = torch.arange(tokens.size(1), device='cuda', dtype=torch.long)
+        if not args.no_block_position:
+            block_position_ids = torch.zeros(tokens.size(1), device='cuda', dtype=torch.long)
+            position_ids = torch.stack((position_ids, block_position_ids), dim=0)
+        position_ids = position_ids.unsqueeze(0)
+    else:
+        raise NotImplementedError
+
+    return tokens, attention_mask, position_ids
+
+
+def generate_samples(model, tokenizer, args):
+    model.eval()
+    output_path = "./samples"
+    if not os.path.exists(output_path):
+        os.makedirs(output_path)
+    output_path = os.path.join(output_path, f"sample-{datetime.now().strftime('%m-%d-%H-%M')}.txt")
+    with torch.no_grad(), open(output_path, "w") as output:
+        while True:
+            torch.distributed.barrier(group=mpu.get_model_parallel_group())
+            terminate_runs, raw_text, context_tokens_tensor, context_length = read_context(tokenizer, args, output)
+            if terminate_runs == 1:
+                return
+            start_time = time.time()
+            if args.block_lm:
+                mems = []
+                tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args)
+                mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
+                mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
+                end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id]
+                mask_positions = []
+                for token in mask_tokens:
+                    mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist()
+                mask_positions.sort()
+                if args.no_block_position:
+                    for mask_position in mask_positions:
+                        position_ids[0, mask_position + 1:] += args.out_seq_length
+                _, *mems = model(tokens, position_ids, attention_mask, *mems)
+                for mask_position in mask_positions:
+                    if args.no_block_position:
+                        position = position_ids[0, mask_position].item()
+                    else:
+                        position = mask_position
+                    if args.num_beams > 1:
+                        strategy = BeamSearchStrategy(num_beams=args.num_beams, max_length=args.out_seq_length,
+                                                      length_penalty=args.length_penalty, end_tokens=end_tokens,
+                                                      no_repeat_ngram_size=args.no_repeat_ngram_size,
+                                                      min_tgt_length=args.min_tgt_length)
+                    else:
+                        strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
+                                                end_tokens=end_tokens)
+                    new_tokens, mems = filling_sequence_glm(model, tokenizer, position, strategy, args, mems=mems,
+                                                            end_tokens=end_tokens)
+                    tokens = torch.cat((tokens, new_tokens), dim=1)
+            output_tokens_list = tokens.view(-1).contiguous()
+            if mpu.get_model_parallel_rank() == 0:
+                os.system('clear')
+                print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
+                print("\nContext:", raw_text, flush=True)
+                decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
+                trim_decode_tokens = decode_tokens
+                print("\nGLM:", trim_decode_tokens, flush=True)
+                output.write(trim_decode_tokens + "\n")
+
+            torch.distributed.barrier(group=mpu.get_model_parallel_group())
+
+
+def main(args):
+    initialize_distributed(args)
+    tokenizer = prepare_tokenizer(args)
+    # build model
+    model = GLMModel(args)
+    if args.fp16:
+        model = model.half()
+    model = model.to(args.device)
+    load_checkpoint(model, args)
+    set_random_seed(args.seed)
+    model.eval()
+    generate_samples(model, tokenizer, args)
+
+
+if __name__ == "__main__":
+    args = get_args()
+
+    with torch.no_grad():
+        main(args)
diff --git a/model/base_model.py b/model/base_model.py
index 0e4b5b3dcb2d0fa8f69823d79fe70f016d0c54fd..b3495496e581783616ff151958c06035310c5546 100644
--- a/model/base_model.py
+++ b/model/base_model.py
@@ -14,11 +14,13 @@ import random
 import torch
 
 from mpu import BaseTransformer
+from .mixins import BaseMixin
 
 class BaseModel(torch.nn.Module):
     def __init__(self, args, transformer=None):
         super(BaseModel, self).__init__()
-        self.hooks = self.collect_hooks()
+        self.mixins = torch.nn.ModuleDict()
+        self.collect_hooks_()
         if transformer is not None:
             self.transformer = transformer
         else:
@@ -37,12 +39,25 @@ class BaseModel(torch.nn.Module):
                 parallel_output=True,
                 hooks=self.hooks
             )
-        self.mixins = torch.nn.ModuleList()
         
-    def reinit(self):
+    def reinit(self): # will be called when loading model
         # if some mixins are loaded, overrides this function
-        for m in self.mixins: 
+        for m in self.mixins.values(): 
             m.reinit(self.transformer)
+            
+    def add_mixin(self, name, new_mixin, reinit=False):
+        assert name not in self.mixins
+        assert isinstance(new_mixin, BaseMixin)
+        
+        self.mixins[name] = new_mixin # will auto-register parameters
+        object.__setattr__(new_mixin, 'transformer', self.transformer) # cannot use pytorch set_attr
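+        # bypassing nn.Module.__setattr__ keeps the shared transformer from being
+        # registered as a submodule of the mixin (no duplicated parameters)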
+        
+        if reinit:
+            new_mixin.reinit(*self.mixins.values()) # pass current mixins positionally, matching reinit(self, *pre_mixins)
+        self.collect_hooks_()
+        
+    def get_mixin(self, name):
+        return self.mixins[name]
     
     def forward(self, *args, **kwargs):
         # update hooks as the current model (overrided forwards)
@@ -51,16 +66,28 @@ class BaseModel(torch.nn.Module):
         self.transformer.hooks.update(self.hooks)
         return self.transformer(*args, **kwargs)
         
-    def collect_hooks(self):
+    def collect_hooks_(self):
         names = ['word_embedding_forward', 'position_embedding_forward',
                 'attention_forward', 'mlp_forward', 'final_forward', 'layer_forward',
                 'branch_embedding_forward', 'branch_final_forward'
                 ]
         hooks = {}
+        hook_origins = {}
         for name in names:
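+            # mixin-provided hooks are gathered first; the same hook in two mixins
+            # is a conflict, while a hook defined on the model itself overrides mixins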
+            for mixin_name, m in self.mixins.items():
+                if hasattr(m, name):
+                    if name in hooks: # conflict
+                        raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.')
+                    hooks[name] = getattr(m, name)
+                    hook_origins[name] = mixin_name
             if hasattr(self, name):
+                # if name in hooks: # defined in mixins, can override
+                #     print(f'Override {name} in {hook_origins[name]}...')
                 hooks[name] = getattr(self, name)
+                hook_origins[name] = 'model'
+        self.hooks = hooks
+        self.hook_origins = hook_origins
         return hooks
-
+    
     def disable_untrainable_params(self):
         pass
\ No newline at end of file
diff --git a/model/cached_autoregressive_model.py b/model/cached_autoregressive_model.py
index 34702b0cb6f1135406bae1e780c4973dfa814731..97b58b2c464fe02a90752752e8440af2e5852a1c 100755
--- a/model/cached_autoregressive_model.py
+++ b/model/cached_autoregressive_model.py
@@ -13,15 +13,15 @@ import math
 import random
 import torch
 
+from .mixins import BaseMixin
 from .base_model import BaseModel
 from mpu.transformer import standard_attention, split_tensor_along_last_dim
 
-class CachedAutoregressiveModel(BaseModel):
-    def __init__(self, args, transformer=None):
-        super().__init__(args, transformer=transformer)
-        self.log_attention_weights = None
+class CachedAutoregressiveMixin(BaseMixin):
+    def __init__(self):
+        super().__init__()
         
-    def attention_forward(self, hidden_states, mask, mems=None, layer_id=None, **kwargs):
+    def attention_forward(self, hidden_states, mask, mems=None, layer_id=None, log_attention_weights=None, **kwargs):
         attn_module = self.transformer.layers[layer_id].attention
         mem = mems[layer_id] if mems is not None else None
         
@@ -40,7 +40,7 @@ class CachedAutoregressiveModel(BaseModel):
         query_layer = attn_module._transpose_for_scores(mixed_query_layer)
         key_layer = attn_module._transpose_for_scores(mixed_key_layer)
         value_layer = attn_module._transpose_for_scores(mixed_value_layer)
-        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, None, log_attention_weights=self.log_attention_weights)
+        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, None, log_attention_weights=log_attention_weights)
         
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
@@ -51,3 +51,8 @@ class CachedAutoregressiveModel(BaseModel):
         new_mem = mixed_raw_layer.detach()[..., -(mixed_raw_layer.shape[-1] // 3 * 2):].contiguous()
             
         return output, new_mem
+
+class CachedAutoregressiveModel(BaseModel):
+    def __init__(self, args, transformer=None):
+        super().__init__(args, transformer=transformer)
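+        # this subclass just bundles CachedAutoregressiveMixin with a plain BaseModel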
+        self.add_mixin('auto-regressive', CachedAutoregressiveMixin())
diff --git a/model/cuda2d_model.py b/model/cuda2d_model.py
index c8699e7a66089e6d0dc210a04a2d0c84b27cbf12..cf9867a7a41122f7f37f465631fdad8e55cc8830 100644
--- a/model/cuda2d_model.py
+++ b/model/cuda2d_model.py
@@ -28,10 +28,10 @@ class Cuda2dModel(BaseModel):
     def __init__(self, args, transformer=None):
         super().__init__(args, transformer=transformer)
         additional_seqlen = args.new_sequence_length - args.max_sequence_length
-        self.mixins.append(PositionEmbeddingMixin(
+        self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
             additional_seqlen, args.hidden_size
         ))
-        self.mixins.append(AttentionMixin(
+        self.add_mixin('attention_plus', AttentionMixin(
             num_layers=args.num_layers,
             hidden_size=args.hidden_size
         ))
@@ -41,23 +41,24 @@ class Cuda2dModel(BaseModel):
         self.kernel_size2 = args.kernel_size2
         self.log_attention_weights = None
     
-    def position_embedding_forward(self, position_ids, **kw_tensors):
+    def position_embedding_forward(self, position_ids, **kw_args):
         position = position_ids[..., :self.layout[1]]
         position_plus = position_ids[..., self.layout[1]:]
         position_embeddings = torch.cat(
                 (
                     self.transformer.position_embeddings(position),
-                    self.mixins[0].position_embeddings(position_plus)
+                    self.get_mixin('extra_position_embedding').position_embeddings(position_plus)
                 ),
                 dim=-2
             )
         return position_embeddings
         
-    def attention_forward(self, hidden_states, mask, layer_id=None, **kw_tensors):
+    def attention_forward(self, hidden_states, mask, 
+                        layer_id=None, log_attention_weights=None, **kw_args):
         attn_module = self.transformer.layers[layer_id].attention
         # attention_plus on all layers
-        query_key_value_plus = self.mixins[1].query_key_value[layer_id] 
-        dense_plus = self.mixins[1].dense[layer_id]
+        query_key_value_plus = self.get_mixin('attention_plus').query_key_value[layer_id] 
+        dense_plus = self.get_mixin('attention_plus').dense[layer_id]
         
         # split two parts
         hidden_states_plus = hidden_states[:, self.layout[1]:]
@@ -81,7 +82,7 @@ class Cuda2dModel(BaseModel):
                 kernel_size=self.kernel_size,
                 kernel_size2=self.kernel_size2,
                 attention_dropout=dropout_fn,
-                log_attention_weights=self.log_attention_weights
+                log_attention_weights=log_attention_weights
             )
 
         output_0 = attn_module.dense(context_layer0)
diff --git a/model/glm_model.py b/model/glm_model.py
index 96502d1f7751228b752806f0b80376ece9410170..30593c4866195ef4601972c72871cfabe927924e 100644
--- a/model/glm_model.py
+++ b/model/glm_model.py
@@ -3,16 +3,26 @@ import torch.nn as nn
 
 from .base_model import BaseModel
 from .cached_autoregressive_model import CachedAutoregressiveModel
+from .mixins import BaseMixin
 
-
-class GLMModel(CachedAutoregressiveModel):
-    def __init__(self, args, transformer=None):
-        super().__init__(args, transformer=transformer)
-        self.transformer.block_position_embeddings = torch.nn.Embedding(args.max_sequence_length, args.hidden_size)
-        torch.nn.init.normal_(self.transformer.block_position_embeddings.weight, mean=0.0, std=0.02)
-
-    def position_embedding_forward(self, position_ids, *other_tensors):
+class BlockPositionEmbeddingMixin(BaseMixin):
+    def __init__(self, max_sequence_length, hidden_size, init_method_std=0.02):
+        super(BlockPositionEmbeddingMixin, self).__init__()
+        self.max_sequence_length = max_sequence_length
+        self.hidden_size = hidden_size
+        self.block_position_embeddings = torch.nn.Embedding(max_sequence_length, hidden_size)
+        torch.nn.init.normal_(self.block_position_embeddings.weight, mean=0.0, std=init_method_std)
+    
+    def position_embedding_forward(self, position_ids, **kwargs):
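+        # position_ids: [batch, 2, seq_len]; row 0 are absolute positions,
+        # row 1 are intra-block positions; the two embeddings are summed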
         position_ids, block_position_ids = position_ids[:, 0], position_ids[:, 1]
         position_embeddings = self.transformer.position_embeddings(position_ids)
-        block_position_embeddings = self.transformer.block_position_embeddings(block_position_ids)
+        block_position_embeddings = self.block_position_embeddings(block_position_ids)
         return position_embeddings + block_position_embeddings
+
+class GLMModel(BaseModel):
+    def __init__(self, args, transformer=None):
+        super().__init__(args, transformer=transformer)
+        self.add_mixin('block_position_embedding', 
+            BlockPositionEmbeddingMixin(args.max_sequence_length, args.hidden_size)
+        )
+    
diff --git a/model/mixins.py b/model/mixins.py
index 125f8b0e2c7ccda41c4125892637046a041f7186..f42b0f270d88d01a33e18d8df11b44d5fd3fbba5 100644
--- a/model/mixins.py
+++ b/model/mixins.py
@@ -20,9 +20,11 @@ class BaseMixin(torch.nn.Module):
     def __init__(self):
         super(BaseMixin, self).__init__()
         # define new params
-    def reinit(self, transformer, *pre_mixins):
+    def reinit(self, *pre_mixins):
         # reload the initial params from previous trained modules
         pass
+    # can also define hook-functions here
+    # ...
 
 class PositionEmbeddingMixin(BaseMixin):
     def __init__(self, additional_sequence_length, hidden_size, 
@@ -32,8 +34,8 @@ class PositionEmbeddingMixin(BaseMixin):
         self.reinit_slice = reinit_slice
         self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
         torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
-    def reinit(self, transformer, *pre_mixins):
-        old_weights = transformer.position_embeddings.weight.data[self.reinit_slice]
+    def reinit(self, *pre_mixins):
+        old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
         old_len, hidden_size = old_weights.shape
         assert hidden_size == self.position_embeddings.weight.shape[-1]
         self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
@@ -58,11 +60,11 @@ class AttentionMixin(BaseMixin):
                 init_method=output_layer_init_method)
                 for layer_id in range(num_layers)
             ])
-    def reinit(self, transformer, *pre_mixins):
-        start_layer = len(transformer.layers) - self.num_layers
+    def reinit(self, *pre_mixins):
+        start_layer = len(self.transformer.layers) - self.num_layers
         assert start_layer >= 0
         for layer_id in range(self.num_layers):
-            old_attention = transformer.layers[start_layer + layer_id].attention
+            old_attention = self.transformer.layers[start_layer + layer_id].attention
             self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
             self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
             self.dense[layer_id].weight.data.copy_(old_attention.dense.weight.data)
diff --git a/mpu/transformer.py b/mpu/transformer.py
index cdc34a84fc0681178f7795cf9d3a43b1d92b257b..2cf38521f49b7f8ee207863fb90926c06e622ba2 100755
--- a/mpu/transformer.py
+++ b/mpu/transformer.py
@@ -111,9 +111,9 @@ class SelfAttention(torch.nn.Module):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)
 
-    def forward(self, hidden_states, mask, **kw_tensors):
+    def forward(self, hidden_states, mask, **kw_args):
         if 'attention_forward' in self.hooks:
-            return self.hooks['attention_forward'](hidden_states, mask, **kw_tensors, layer_id=self.layer_id)
+            return self.hooks['attention_forward'](hidden_states, mask, **kw_args, layer_id=self.layer_id)
         else:
             mixed_raw_layer = self.query_key_value(hidden_states)
             (mixed_query_layer,
@@ -162,9 +162,9 @@ class MLP(torch.nn.Module):
         )
         self.dropout = torch.nn.Dropout(output_dropout_prob)
 
-    def forward(self, hidden_states, **kw_tensors):
+    def forward(self, hidden_states, **kw_args):
         if 'mlp_forward' in self.hooks:
-            output = self.hooks['mlp_forward'](hidden_states, **kw_tensors, layer_id=self.layer_id)
+            output = self.hooks['mlp_forward'](hidden_states, **kw_args, layer_id=self.layer_id)
         else:
             intermediate_parallel = self.dense_h_to_4h(hidden_states)
             intermediate_parallel = gelu(intermediate_parallel)
@@ -227,7 +227,7 @@ class BaseTransformerLayer(torch.nn.Module):
             hooks=hooks
         )
     
-    def forward(self, hidden_states, mask, **kw_tensors):
+    def forward(self, hidden_states, mask, **kw_args):
         '''
             hidden_states: [batch, seq_len, hidden_size]
             mask: [(1, 1), seq_len, seq_len]
@@ -236,7 +236,7 @@ class BaseTransformerLayer(torch.nn.Module):
         # Layer norm at the begining of the transformer layer.
         layernorm_output1 = self.input_layernorm(hidden_states)
         # Self attention.
-        attention_output, output_this_layer = self.attention(layernorm_output1, mask, **kw_tensors)
+        attention_output, output_this_layer = self.attention(layernorm_output1, mask, **kw_args)
 
         # Third LayerNorm
         if self.sandwich_ln:
@@ -247,7 +247,7 @@ class BaseTransformerLayer(torch.nn.Module):
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)
         # MLP.
-        mlp_output = self.mlp(layernorm_output, **kw_tensors)
+        mlp_output = self.mlp(layernorm_output, **kw_args)
 
         # Fourth LayerNorm
         if self.sandwich_ln:
@@ -316,27 +316,25 @@ class BaseTransformer(torch.nn.Module):
         # Final layer norm before output.
         self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
 
-    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, **kw_tensors):
+    def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, **kw_args):
         # sanity check 
         assert len(input_ids.shape) == 2 
         batch_size, query_length = input_ids.shape
         assert len(attention_mask.shape) == 2 or \
             len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
         assert branch_input is None or 'layer_forward' in self.hooks and isinstance(branch_input, torch.Tensor)
-        for k, v in kw_tensors.items():
-            assert isinstance(v, torch.Tensor)
         # branch_input is a new part of input need layer-by-layer update,
         #   but with different hidden_dim and computational routine.
         #   In most cases, you can just ignore it.
 
         # embedding part
         if 'word_embedding_forward' in self.hooks:
-            hidden_states = self.hooks['word_embedding_forward'](input_ids, **kw_tensors)
+            hidden_states = self.hooks['word_embedding_forward'](input_ids, **kw_args)
         else: # default
             hidden_states = self.word_embeddings(input_ids)
             
         if 'position_embedding_forward' in self.hooks:
-            position_embeddings = self.hooks['position_embedding_forward'](position_ids, **kw_tensors)
+            position_embeddings = self.hooks['position_embedding_forward'](position_ids, **kw_args)
         else:
             assert len(position_ids.shape) <= 2
             assert position_ids.shape[-1] == query_length
@@ -346,7 +344,7 @@ class BaseTransformer(torch.nn.Module):
         
         # branch related embedding
         if branch_input is None and 'branch_embedding_forward' in self.hooks:
-            branch_input = self.hooks['branch_embedding_forward'](branch_input, **kw_tensors)
+            branch_input = self.hooks['branch_embedding_forward'](branch_input, **kw_args)
 
         # define custom_forward for checkpointing
         output_per_layers = []
@@ -361,14 +359,14 @@ class BaseTransformer(torch.nn.Module):
                     for i, layer in enumerate(layers_):
                         if len(inputs) > 2:
                             x_, branch_, output_this_layer = self.hooks['layer_forward'](
-                                x_, mask, layer_id=layer.layer_id, branch_input=branch_, **kw_tensors
+                                x_, mask, layer_id=layer.layer_id, branch_input=branch_, **kw_args
                             )
                         elif 'layer_forward' in self.hooks:
                             x_, output_this_layer = self.hooks['layer_forward'](
-                                x_, mask, layer_id=layer.layer_id, **kw_tensors
+                                x_, mask, layer_id=layer.layer_id, **kw_args
                             )
                         else:
-                            x_, output_this_layer = layer(x_, mask, **kw_tensors)
+                            x_, output_this_layer = layer(x_, mask, **kw_args)
                         output_per_layers_part.append(output_this_layer)
                     return x_, output_per_layers_part
                 return custom_forward
@@ -387,25 +385,25 @@ class BaseTransformer(torch.nn.Module):
             for i, layer in enumerate(self.layers):
                 args = [hidden_states, attention_mask]
                 if branch_input is not None: # customized layer_forward with branch_input
-                    hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), branch_input=branch_input, **kw_tensors)
+                    hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), branch_input=branch_input, **kw_args)
                 elif 'layer_forward' in self.hooks: # customized layer_forward
-                    hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), **kw_tensors)
+                    hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), **kw_args)
                 else:
-                    hidden_states, output_this_layer = layer(*args, **kw_tensors)
+                    hidden_states, output_this_layer = layer(*args, **kw_args)
                 output_per_layers.append(output_this_layer) 
 
         # Final layer norm.
         logits = self.final_layernorm(hidden_states)
         
         if 'final_forward' in self.hooks:
-            logits_parallel = self.hooks['final_forward'](logits, **kw_tensors)
+            logits_parallel = self.hooks['final_forward'](logits, **kw_args)
         else:
             logits_parallel = copy_to_model_parallel_region(logits)
             logits_parallel = F.linear(logits_parallel, self.word_embeddings.weight)
             
         # branch related embedding
         if branch_input is None and 'branch_final_forward' in self.hooks:
-            branch_input = self.hooks['branch_final_forward'](branch_input, **kw_tensors)
+            branch_input = self.hooks['branch_final_forward'](branch_input, **kw_args)
 
         if self.parallel_output:
             logits_parallel = gather_from_model_parallel_region(logits_parallel)
diff --git a/scripts/finetune_into_cogview2.sh b/scripts/finetune_into_cogview2.sh
index 3bdf5ea4b78e71165750b9760d2ce2a58f447ae5..1be6d8ff24309d4258cd66269f21f1e02ac92f7c 100755
--- a/scripts/finetune_into_cogview2.sh
+++ b/scripts/finetune_into_cogview2.sh
@@ -50,7 +50,7 @@ gpt_options="${gpt_options}
        --deepspeed \
        --deepspeed_config ${config_json} \
 "
-              
+
 
 run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_cogview2.py $@ ${gpt_options}"
 echo ${run_cmd}
diff --git a/scripts/generate_glm.sh b/scripts/generate_glm.sh
old mode 100644
new mode 100755
index 143d7b5baa08029075b6c360f1b252a445b5a686..e007900381dd94deb395f94b401ab46a81aa789e
--- a/scripts/generate_glm.sh
+++ b/scripts/generate_glm.sh
@@ -1,7 +1,28 @@
 #!/bin/bash
-CHECKPOINT_PATH=/dataset/fd5061f6/english_data/checkpoints
+CHECKPOINT_PATH=pretrained/glm
 
-source $1
+# MODEL_ARGS="--block-lm \
+#             --cloze-eval \
+#             --num-layers 24 \
+#             --hidden-size 1024 \
+#             --num-attention-heads 16 \
+#             --max-sequence-length 513 \
+#             --tokenizer-model-type roberta \
+#             --tokenizer-type glm_GPT2BPETokenizer \
+#             --load ${CHECKPOINT_PATH}/glm-roberta-large-blank"
+
+MODEL_TYPE="blocklm-10B"
+MODEL_ARGS="--block-lm \
+            --cloze-eval \
+            --task-mask \
+            --num-layers 48 \
+            --hidden-size 4096 \
+            --num-attention-heads 64 \
+            --max-sequence-length 1025 \
+            --tokenizer-model-type gpt2 \
+            --tokenizer-type glm_GPT2BPETokenizer \
+            --old-checkpoint \
+            --load ${CHECKPOINT_PATH}/glm-en-10b"
 
 MPSIZE=1
 MAXSEQLEN=512
@@ -29,4 +50,7 @@ python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTE
        --out-seq-length $MAXSEQLEN \
        --temperature $TEMP \
        --top_k $TOPK \
-       --top_p $TOPP
+       --output-path glm_text \
+       --batch-size 1 \
+       --out-seq-length 100 \
+       --mode inference
diff --git a/scripts/pretrain_multiple_nodes.sh b/scripts/pretrain_multiple_nodes.sh
index 5a09f7712528974023124dd510d3762623e7c8d9..c1040a3bc54649742d75be71943a0763438f6f19 100755
--- a/scripts/pretrain_multiple_nodes.sh
+++ b/scripts/pretrain_multiple_nodes.sh
@@ -19,17 +19,17 @@ small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/ziji
 
 config_json="$script_dir/ds_config_zero.json"
 gpt_options=" \
-       --experiment-name pretrain-gpt2-cogview-test \
+       --experiment-name pretrain-gpt2-cogview-small \
        --tokenizer-type cogview \
        --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \
        --model-parallel-size ${MP_SIZE} \
        --mode pretrain \
-       --num-layers 12 \
-       --hidden-size 1024 \
-       --num-attention-heads 16 \
+       --num-layers 40 \
+       --hidden-size 2048 \
+       --num-attention-heads 32 \
        --train-iters 200000 \
        --resume-dataloader \
-       --train-data ${small_data} \
+       --train-data ${full_data} \
        --split 949,50,1 \
        --distributed-backend nccl \
        --lr-decay-style cosine \
@@ -38,9 +38,9 @@ gpt_options=" \
        --max-sequence-length 1089 \
        --sandwich-ln \
        --fp16 \
-       --save-interval 2000 \
+       --save-interval 5000 \
        --eval-interval 1000 \
-       --save $main_dir/checkpoints \
+       --save /root/checkpoints \
 "
        # --load pretrained/cogview/cogview-base
 
diff --git a/scripts/text2image_cogview.sh b/scripts/text2image_cogview.sh
index f6bea9668c57024c5f34974dd03d6e8ad76336a1..bcb1ecd9847a2e2e39c6f00c34c89518edee61ee 100755
--- a/scripts/text2image_cogview.sh
+++ b/scripts/text2image_cogview.sh
@@ -33,7 +33,7 @@ MASTER_PORT=${MASTER_PORT} python inference_cogview.py \
        --sandwich-ln \
        --input-source ./input.txt \
        --output-path samples_text2image \
-       --batch-size 8 \
+       --batch-size 4 \
        --max-inference-batch-size 8 \
        $@
 
diff --git a/tokenization/__init__.py b/tokenization/__init__.py
index f465ec4713f52132a32f198731a01825228613fc..427d98eb5e1dc4494b00f51fd9913b960874bd82 100644
--- a/tokenization/__init__.py
+++ b/tokenization/__init__.py
@@ -15,8 +15,6 @@ import torch
 
 
 def get_tokenizer(args=None):
-    kwargs = {"add_block_symbols": args.block_lm, "add_task_mask": args.task_mask,
-              "add_decoder_mask": args.block_mask_prob > 0.0 or args.context_mask_ratio > 0.0}
     if not hasattr(get_tokenizer, 'tokenizer'):
         # the first time to load the tokenizer
         if args.tokenizer_type == 'cogview':
@@ -25,15 +23,18 @@ def get_tokenizer(args=None):
                 args.img_tokenizer_path,
                 device=torch.cuda.current_device()
             )
-        elif args.tokenizer_type == "BertWordPieceTokenizer":
-            from .text import BertWordPieceTokenizer
-            get_tokenizer.tokenizer = BertWordPieceTokenizer(args.tokenizer_model_type, **kwargs)
-        elif args.tokenizer_type == "GPT2BPETokenizer":
-            from .text import GPT2BPETokenizer
-            get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs)
-        elif args.tokenizer_type == "ChineseSPTokenizer":
-            from .text import ChineseSPTokenizer
-            get_tokenizer.tokenizer = ChineseSPTokenizer(**kwargs)
+        elif args.tokenizer_type.startswith('glm_'):
+            kwargs = {"add_block_symbols": args.block_lm, "add_task_mask": args.task_mask,
+              "add_decoder_mask": False} # was: args.block_mask_prob > 0.0 or args.context_mask_ratio > 0.0
+            if args.tokenizer_type == "glm_BertWordPieceTokenizer":
+                from .text import BertWordPieceTokenizer
+                get_tokenizer.tokenizer = BertWordPieceTokenizer(args.tokenizer_model_type, **kwargs)
+            elif args.tokenizer_type == "glm_GPT2BPETokenizer":
+                from .text import GPT2BPETokenizer
+                get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs)
+            elif args.tokenizer_type == "glm_ChineseSPTokenizer":
+                from .text import ChineseSPTokenizer
+                get_tokenizer.tokenizer = ChineseSPTokenizer(**kwargs)
         else:
             assert args.vocab_size > 0
             get_tokenizer.tokenizer = FakeTokenizer(args.vocab_size)
diff --git a/tokenization/text/tokenization_gpt2.py b/tokenization/text/tokenization_gpt2.py
index 318d9209990b74b9aaadc13d1f482923c77bc3ab..8782fe9cb74fc95dc27e3e15cce2339a655ecd80 100644
--- a/tokenization/text/tokenization_gpt2.py
+++ b/tokenization/text/tokenization_gpt2.py
@@ -36,12 +36,12 @@ from ..file_utils import cached_path
 logger = logging.getLogger(__name__)
 
 PRETRAINED_VOCAB_ARCHIVE_MAP = {
-    'gpt2': "pretrained/pytorch_pretrained_bert/gpt2-vocab.json",
-    "roberta": "pretrained/pytorch_pretrained_bert/roberta-vocab.json"
+    'gpt2': "pretrained/english_tokenizer/gpt2-vocab.json",
+    "roberta": "pretrained/english_tokenizer/roberta-vocab.json"
 }
 PRETRAINED_MERGES_ARCHIVE_MAP = {
-    'gpt2': "pretrained/pytorch_pretrained_bert/gpt2-merges.txt",
-    "roberta": "pretrained/pytorch_pretrained_bert/roberta-merges.txt"
+    'gpt2': "pretrained/english_tokenizer/gpt2-merges.txt",
+    "roberta": "pretrained/english_tokenizer/roberta-merges.txt"
 }
 PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
     'gpt2': 1024,
diff --git a/training/model_io.py b/training/model_io.py
index 1655ed20999230346b0dd20c7c547c52a38fb246..18151863ae66a0d6a5e640f71c875d8a61cea92b 100644
--- a/training/model_io.py
+++ b/training/model_io.py
@@ -131,11 +131,6 @@ def load_checkpoint(model, args):
     else: # inference without deepspeed
         module = model
 
-    # Process the checkpoint for GLM
-    if args.block_lm and args.old_checkpoint:
-        sd['module']['transformer.word_embeddings.weight'] = sd['module']['word_embeddings.weight']
-        del sd['module']['word_embeddings.weight']
-
     # only load module, other hyperparameters are just for recording.
     missing_keys, unexpected_keys = module.load_state_dict(sd['module'], strict=False)
     if len(unexpected_keys) > 0: