diff --git a/arguments.py b/arguments.py
index bccbd28237c22e0f8f0777824b03bc1cf3d47712..2152601e8e0512ef3e07fbe88b11a132df66fe3c 100755
--- a/arguments.py
+++ b/arguments.py
@@ -157,6 +157,10 @@ def add_text_generate_args(parser):
     group.add_argument("--temperature", type=float, default=1.0)
     group.add_argument("--top_p", type=float, default=0.0)
     group.add_argument("--top_k", type=int, default=0)
+    group.add_argument("--num-beams", type=int, default=1)
+    group.add_argument("--length-penalty", type=float, default=0.0)
+    group.add_argument("--no-repeat-ngram-size", type=int, default=0)
+    group.add_argument("--min-tgt-length", type=int, default=0)
     group.add_argument("--out-seq-length", type=int, default=256)
     group.add_argument('--input-source', type=str, default='interactive',
                        help='what input mode to use, interactive or path')
@@ -214,12 +218,45 @@ def add_tokenization_args(parser):
 
     group = parser.add_argument_group('Tokenization', 'tokenization configurations')
     group.add_argument('--tokenizer-type', type=str, default='fake', help='type name of tokenizer')
-
+    group.add_argument('--tokenizer-model-type', type=str,
+                       default=None,
+                       help="Model type to use for sentencepiece tokenization \
+                           (one of ['bpe', 'char', 'unigram', 'word']) or \
+                           bert vocab to use for BertWordPieceTokenizer (one of \
+                           ['bert-large-uncased', 'bert-large-cased', etc.])")
     group.add_argument('--img-tokenizer-path', type=str, default=None,
                        help='The checkpoint file path of image tokenizer.')
     return parser
 
 
+def add_glm_args(parser):
+    """Arguments for GLM"""
+    group = parser.add_argument_group('GLM', 'GLM Configurations')
+    group.add_argument('--block-lm', action='store_true', help="whether to use the BlockLM pre-training objective")
+    group.add_argument('--masked-lm', action='store_true', help='whether to use the MLM objective')
+    group.add_argument('--bert-prob', type=float, default=0.5)
+    group.add_argument('--gpt-infill-prob', type=float, default=0.5)
+    group.add_argument('--gpt-min-ratio', type=float, default=0.5)
+    group.add_argument('--gap-sentence-prob', type=float, default=0.0)
+    group.add_argument('--gap-sentence-ratio', type=float, default=0.15)
+    group.add_argument('--avg-block-length', type=int, default=3)
+    group.add_argument('--short-seq-prob', type=float, default=0.0)
+    group.add_argument('--single-span-prob', type=float, default=0.0)
+    group.add_argument('--task-mask', action='store_true', help="Use different mask tokens for generation and blank filling")
+    group.add_argument('--no-shuffle-block', action='store_true', help="Do not shuffle the blocks when filling in the blanks")
+    group.add_argument('--no-block-position', action='store_true',
+                       help='Use (rough) absolute positions instead of block positions')
+    group.add_argument('--sentinel-token', action='store_true',
+                       help="Use sentinel (mask) tokens instead of the 2D position encoding")
+    group.add_argument('--block-mask-prob', type=float, default=0.0)
+    group.add_argument('--context-mask-ratio', type=float, default=0.0)
+    group.add_argument('--random-position', action='store_true',
+                       help="Use a random start position so that all position embeddings are covered")
+    group.add_argument('--cloze-eval', action='store_true', help='Evaluate datasets with the cloze task')
+    group.add_argument('--old-checkpoint', action='store_true', help="Load the checkpoint from the old library")
+    return parser
+
+
     
 def get_args(args_list=None):
     """Parse all the args."""
@@ -232,6 +269,7 @@ def get_args(args_list=None):
     parser = add_tokenization_args(parser)
     parser = add_text_generate_args(parser)
     parser = add_generation_api_args(parser)
+    parser = add_glm_args(parser)
 
     # Include DeepSpeed configuration arguments
     parser = deepspeed.add_config_arguments(parser)
diff --git a/config/model_glm_10B.sh b/config/model_glm_10B.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5ca0df0ed1a753f49baeae6818b0872268b35e91
--- /dev/null
+++ b/config/model_glm_10B.sh
@@ -0,0 +1,12 @@
+MODEL_TYPE="blocklm-10B"
+MODEL_ARGS="--block-lm \
+            --cloze-eval \
+            --task-mask \
+            --num-layers 48 \
+            --hidden-size 4096 \
+            --num-attention-heads 64 \
+            --max-sequence-length 1025 \
+            --tokenizer-model-type gpt2 \
+            --tokenizer-type GPT2BPETokenizer \
+            --old-checkpoint \
+            --load ${CHECKPOINT_PATH}/blocklm-10b-1024"
\ No newline at end of file
diff --git a/config/model_glm_roberta_large.sh b/config/model_glm_roberta_large.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5c4eaf200fb2d6b91550db9112beae7f56e1ba98
--- /dev/null
+++ b/config/model_glm_roberta_large.sh
@@ -0,0 +1,11 @@
+MODEL_TYPE="blocklm-roberta-large"
+MODEL_ARGS="--block-lm \
+            --cloze-eval \
+            --num-layers 24 \
+            --hidden-size 1024 \
+            --num-attention-heads 16 \
+            --max-sequence-length 513 \
+            --tokenizer-model-type roberta \
+            --tokenizer-type GPT2BPETokenizer \
+            --old-checkpoint \
+            --load ${CHECKPOINT_PATH}/blocklm-roberta-large-blank"
\ No newline at end of file
diff --git a/generation/glm_sampling.py b/generation/glm_sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..f024ce326fa8481d9ff613a54be0f2b707365622
--- /dev/null
+++ b/generation/glm_sampling.py
@@ -0,0 +1,49 @@
+import torch
+import torch.nn.functional as F
+import mpu
+from .autoregressive_sampling import update_mems
+from .sampling_strategies.beam_search_strategy import BeamSearchScorer
+
+
+def filling_sequence_glm(model, tokenizer, mask_position, strategy, args, mems=None, end_tokens=None, device='cuda'):
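+    # Fill in the blank at `mask_position` autoregressively: decoding starts from the
+    # [sop] token and stops when the strategy reports completion or after
+    # args.out_seq_length - 1 steps.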
+    tokens = torch.full((1, 1), tokenizer.get_command('sop').Id, device=device, dtype=torch.long)
+    counter = 0
+    if mems is None:
+        mems = []
+    # if end_tokens is None:
+    #     end_tokens = [tokenizer.get_command('eos').Id]
+    while counter < args.out_seq_length - 1:
+        last_beam_num = tokens.size(0)
+        if args.block_lm:
+            if args.no_block_position:
+                position_ids = torch.full((last_beam_num, 1), mask_position + counter, device=device, dtype=torch.long)
+            else:
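+                # GLM 2D positions: row 0 repeats the absolute position of the mask
+                # token, row 1 counts the offset (starting from 1) inside the block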
+                position_ids = torch.ones(last_beam_num, 2, 1, device=device, dtype=torch.long)
+                position_ids[:, 0] = mask_position
+                position_ids[:, 1] = counter + 1
+            attention_mask = torch.ones(1, 1, device=device, dtype=torch.float)
+        else:
+            position_ids = torch.full((last_beam_num, 1), mask_position + counter - 1, device=device, dtype=torch.long)
+            attention_mask = torch.ones(last_beam_num, 1, 1, args.mem_length + 1, device=device, dtype=torch.float)
+        if args.fp16:
+            attention_mask = attention_mask.half()
+        last_token = tokens[:, -1:]
+        logits, *mem_kvs = model(last_token, position_ids, attention_mask, *mems)
+        mems = update_mems(mem_kvs, mems, max_memory_length=1000000)
+        next_token_logits = logits[:, -1]
+        tokens, mems = strategy.forward(next_token_logits, tokens, mems)
+        if strategy.is_done:
+            break
+        # else:
+        #     next_token_logits /= args.temperature
+        #     next_token_logits = top_k_logits(next_token_logits, top_k=args.top_k, top_p=args.top_p)
+        #     log_probs = F.softmax(next_token_logits, dim=-1)
+        #     prev = torch.multinomial(log_probs, num_samples=1)[0]
+        #     is_end = prev.item() in end_tokens
+        #     if is_end:
+        #         break
+        #     prev = prev.view(1, 1)
+        #     tokens = prev if tokens is None else torch.cat((tokens, prev), dim=1)
+        counter += 1
+    tokens, mems = strategy.finalize(tokens, mems)
+    return tokens, mems
diff --git a/generation/sampling_strategies/__init__.py b/generation/sampling_strategies/__init__.py
index 2f71e09703c38106088167808d3be758ae8c9b24..2e6b4f6f481d2dbe24aa6c899656173deeb0b163 100644
--- a/generation/sampling_strategies/__init__.py
+++ b/generation/sampling_strategies/__init__.py
@@ -1,2 +1,3 @@
 from .base_strategy import BaseStrategy
-from .iterative_entfilter_strategy import IterativeEntfilterStrategy
\ No newline at end of file
+from .iterative_entfilter_strategy import IterativeEntfilterStrategy
+from .beam_search_strategy import BeamSearchStrategy
\ No newline at end of file
diff --git a/generation/sampling_strategies/base_strategy.py b/generation/sampling_strategies/base_strategy.py
index e46a8ca4e505c2891bae2599ebe10746cc54d876..20339941e16802cb60c88613a5bafee76db75b74 100644
--- a/generation/sampling_strategies/base_strategy.py
+++ b/generation/sampling_strategies/base_strategy.py
@@ -14,26 +14,65 @@ import random
 import torch
 import torch.nn.functional as F
 
-def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')):
-    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-    logits[indices_to_remove] = filter_value     
+
+def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+    # This function has been mostly taken from huggingface conversational ai code at
+    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
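+    # Filters a next-token distribution with top-k and/or nucleus (top-p) filtering.
+    # Note that the top-p branch below assumes a batch size of 1, since the logits are
+    # flattened to a 1D view before sorting.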
+
+    if top_k > 0:
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p > 0.0:
+        # convert to 1D
+        logits = logits.view(logits.size()[1]).contiguous()
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+        # Remove tokens with cumulative probability above the threshold
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Shift the indices to the right to keep also the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+        indices_to_remove = sorted_indices[sorted_indices_to_remove]
+        logits[indices_to_remove] = filter_value
+        # going back to 2D
+        logits = logits.view(1, -1).contiguous()
+
     return logits
 
+
 class BaseStrategy:
-    def __init__(self, invalid_slices=[], temperature=1., topk=200, eps=1e-4):
+    def __init__(self, invalid_slices=[], temperature=1., top_k=200, eps=1e-4, top_p=0.0, end_tokens=None):
         self.invalid_slices = invalid_slices
         self.temperature = temperature
-        self.topk = topk
+        self.topk = top_k
+        self.top_p = top_p
         self.eps = eps
+        if end_tokens is None:
+            end_tokens = []
+        self.end_tokens = end_tokens
+        self._is_done = False
+
+    @property
+    def is_done(self) -> bool:
+        return self._is_done
+
     def forward(self, logits, tokens, mems, temperature=None):
         if temperature is None:
-            temperature = self.temperature 
+            temperature = self.temperature
         logits = logits / temperature
         for invalid_slice in self.invalid_slices:
             logits[..., invalid_slice] = -65504
-            
-        logits = top_k_logits_(logits, self.topk)
-        probs = F.softmax(logits.float(), dim=-1) # float is essetial, due to a bug in Pytorch
+
+        logits = top_k_logits(logits, self.topk, self.top_p)
+        probs = F.softmax(logits.float(), dim=-1)  # float is essential, due to a bug in PyTorch
         pred = torch.multinomial(probs, num_samples=1)
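+        # one token is sampled per step; the end-token check below assumes batch size 1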
+        if pred.item() in self.end_tokens:
+            self._is_done = True
         tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
         return tokens, mems
+
+    def finalize(self, tokens, mems):
+        return tokens, mems
diff --git a/generation/sampling_strategies/beam_search_strategy.py b/generation/sampling_strategies/beam_search_strategy.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec205e071590c1683d2e96bbac80ff57db31d2df
--- /dev/null
+++ b/generation/sampling_strategies/beam_search_strategy.py
@@ -0,0 +1,467 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   beam_search_strategy.py
+@Time    :   2021/10/08 22:22:42
+@Author  :   Ming Ding
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import torch
+import torch.nn.functional as F
+from abc import ABC, abstractmethod
+from collections import UserDict
+from typing import Optional, Tuple, List, Iterable, Union
+
+
+class BeamScorer(ABC):
+    """
+    Abstract base class for all beam scorers that are used for :meth:`~transformers.PretrainedModel.beam_search` and
+    :meth:`~transformers.PretrainedModel.beam_sample`.
+    """
+
+    @abstractmethod
+    def process(
+            self,
+            input_ids: torch.LongTensor,
+            next_scores: torch.FloatTensor,
+            next_tokens: torch.LongTensor,
+            next_indices: torch.LongTensor,
+            **kwargs
+    ) -> Tuple[torch.Tensor]:
+        raise NotImplementedError("This is an abstract method.")
+
+    @abstractmethod
+    def finalize(
+            self,
+            input_ids: torch.LongTensor,
+            next_scores: torch.FloatTensor,
+            next_tokens: torch.LongTensor,
+            next_indices: torch.LongTensor,
+            **kwargs
+    ) -> torch.LongTensor:
+        raise NotImplementedError("This is an abstract method.")
+
+
+class BeamSearchScorer(BeamScorer):
+    r"""
+    :class:`transformers.BeamScorer` implementing standard beam search decoding.
+
+    Adapted in part from `Facebook's XLM beam search code
+    <https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__.
+
+    Args:
+        batch_size (:obj:`int`):
+            Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel.
+        max_length (:obj:`int`):
+            The maximum length of the sequence to be generated.
+        num_beams (:obj:`int`):
+            Number of beams for beam search.
+        device (:obj:`torch.device`):
+            Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of
+            :obj:`BeamSearchScorer` will be allocated.
+        length_penalty (:obj:`float`, `optional`, defaults to 1.0):
+            Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
+            model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
+            sequences.
+        do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
+        num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1):
+            The number of beam hypotheses that shall be returned upon calling
+            :meth:`~transformer.BeamSearchScorer.finalize`.
+    """
+
+    def __init__(
+            self,
+            batch_size: int,
+            max_length: int,
+            num_beams: int,
+            device: Union[torch.device, str],
+            length_penalty: Optional[float] = 1.0,
+            do_early_stopping: Optional[bool] = False,
+            num_beam_hyps_to_keep: Optional[int] = 1,
+    ):
+        self.max_length = max_length
+        self.num_beams = num_beams
+        self.device = device
+        self.length_penalty = length_penalty
+        self.do_early_stopping = do_early_stopping
+        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
+
+        self._is_init = False
+        self._beam_hyps = [
+            BeamHypotheses(
+                num_beams=self.num_beams,
+                max_length=self.max_length,
+                length_penalty=self.length_penalty,
+                early_stopping=self.do_early_stopping,
+            )
+            for _ in range(batch_size)
+        ]
+        self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
+
+        # if not isinstance(num_beams, int) or num_beams <= 1:
+        #     raise ValueError(
+        #         f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, one should make use of `greedy_search` instead."
+        #     )
+
+    @property
+    def is_done(self) -> bool:
+        return self._done.all()
+
+    def process(
+            self,
+            input_ids: torch.LongTensor,
+            next_scores: torch.FloatTensor,
+            next_tokens: torch.LongTensor,
+            next_indices: torch.LongTensor,
+            pad_token_id: Optional[int] = None,
+            eos_token_id: Optional[int] = None,
+            mems=None
+    ) -> Tuple[torch.Tensor]:
+        cur_len = input_ids.shape[-1]
+        batch_size = len(self._beam_hyps)
+        assert batch_size == (input_ids.shape[0] // self.num_beams)
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        device = next_scores.device
+        next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device)
+        next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device)
+        next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device)
+
+        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
+            if self._done[batch_idx]:
+                assert (
+                        len(beam_hyp) >= self.num_beams
+                ), "Batch can only be done if at least {} beams have been generated".format(self.num_beams)
+                assert (
+                        eos_token_id is not None and pad_token_id is not None
+                ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
+                # pad the batch
+                next_beam_scores[batch_idx, :] = 0
+                next_beam_tokens[batch_idx, :] = pad_token_id
+                next_beam_indices[batch_idx, :] = 0
+                continue
+
+            # next tokens for this sentence
+            beam_idx = 0
+            for beam_token_rank, (next_token, next_score, next_index) in enumerate(
+                    zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
+            ):
+                batch_beam_idx = batch_idx * self.num_beams + next_index
+                # add to generated hypotheses if end of sentence
+                if (eos_token_id is not None) and (next_token.item() in eos_token_id):
+                    # if beam_token does not belong to top num_beams tokens, it should not be added
+                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams
+                    if is_beam_token_worse_than_top_num_beams:
+                        continue
+                    beam_hyp.add(
+                        input_ids[batch_beam_idx].clone(),
+                        next_score.item(),
+                        mems=[mem[[next_index.item()]] for mem in mems] if mems else None
+                    )
+                else:
+                    # add next predicted token since it is not eos_token
+                    next_beam_scores[batch_idx, beam_idx] = next_score
+                    next_beam_tokens[batch_idx, beam_idx] = next_token
+                    next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
+                    beam_idx += 1
+
+                # once the beam for next step is full, don't add more tokens to it.
+                if beam_idx == self.num_beams:
+                    break
+
+            if beam_idx < self.num_beams:
+                raise ValueError(
+                    f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
+                )
+
+            # Check if we are done so that we can save a pad step if all(done)
+            self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
+                next_scores[batch_idx].max().item(), cur_len
+            )
+
+        return UserDict(
+            {
+                "next_beam_scores": next_beam_scores.view(-1),
+                "next_beam_tokens": next_beam_tokens.view(-1),
+                "next_beam_indices": next_beam_indices.view(-1),
+            }
+        )
+
+    def finalize(
+            self,
+            input_ids: torch.LongTensor,
+            final_beam_scores: torch.FloatTensor,
+            pad_token_id: Optional[int] = None,
+            eos_token_id: Optional[int] = None,
+            mems=None
+    ) -> Tuple[torch.LongTensor, List[torch.Tensor], torch.FloatTensor]:
+        batch_size = len(self._beam_hyps)
+        # finalize all open beam hypotheses and add to generated hypotheses
+        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
+            if self._done[batch_idx]:
+                continue
+
+            # need to add best num_beams hypotheses to generated hyps
+            for beam_id in range(self.num_beams):
+                batch_beam_idx = batch_idx * self.num_beams + beam_id
+                final_score = final_beam_scores[batch_beam_idx].item()
+                final_tokens = input_ids[batch_beam_idx]
+                beam_hyp.add(final_tokens, final_score, mems=[mem[[batch_beam_idx]] for mem in mems] if mems else None)
+
+        # select the best hypotheses
+        sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
+        best = []
+
+        # retrieve best hypotheses
+        for i, beam_hyp in enumerate(self._beam_hyps):
+            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
+            for j in range(self.num_beam_hyps_to_keep):
+                score, best_hyp, mems = sorted_hyps.pop()
+                sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
+                best.append((best_hyp, mems, score))
+
+        # prepare for adding eos
+        sent_max_len = min(sent_lengths.max().item(), self.max_length)
+        decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
+        scores = final_beam_scores.new(batch_size * self.num_beam_hyps_to_keep)
+        # shorter batches are padded if needed
+        if sent_lengths.min().item() != sent_lengths.max().item():
+            assert pad_token_id is not None, "`pad_token_id` has to be defined"
+            decoded.fill_(pad_token_id)
+
+        # fill with hypotheses and eos_token_id if the latter fits in
+        mems = []
+        for i, (hypo, mem, score) in enumerate(best):
+            scores[i] = score
+            decoded[i, : sent_lengths[i]] = hypo
+            if sent_lengths[i] < sent_max_len:
+                decoded[i, sent_lengths[i]] = eos_token_id
+            mems.append(mem)
+        mems = [torch.cat([mem[i] for mem in mems], dim=0) for i in range(len(mems[0]))] if mems and mems[0] else None
+        return decoded, mems, scores
+
+
+class BeamHypotheses:
+    def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool):
+        """
+        Initialize n-best list of hypotheses.
+        """
+        self.max_length = max_length - 1  # ignoring bos_token
+        self.length_penalty = length_penalty
+        self.early_stopping = early_stopping
+        self.num_beams = num_beams
+        self.beams = []
+        self.worst_score = 1e9
+
+    def __len__(self):
+        """
+        Number of hypotheses in the list.
+        """
+        return len(self.beams)
+
+    def add(self, hyp: torch.LongTensor, sum_logprobs: float, mems=None):
+        """
+        Add a new hypothesis to the list.
+        """
+        score = sum_logprobs / (max(hyp.shape[-1], 1) ** self.length_penalty)
+        if len(self) < self.num_beams or score > self.worst_score:
+            self.beams.append((score, hyp, mems))
+            if len(self) > self.num_beams:
+                sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
+                del self.beams[sorted_next_scores[0][1]]
+                self.worst_score = sorted_next_scores[1][0]
+            else:
+                self.worst_score = min(score, self.worst_score)
+
+    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
+        """
+        If there are enough hypotheses and none of the hypotheses being generated can become better than the worst
+        one in the heap, then we are done with this sentence.
+        """
+
+        if len(self) < self.num_beams:
+            return False
+        elif self.early_stopping:
+            return True
+        else:
+            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
+            ret = self.worst_score >= cur_score
+            return ret
+
+
+class LogitsProcessor(ABC):
+    """Abstract base class for all logit processors that can be applied during generation."""
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        """Torch method for processing logits."""
+        raise NotImplementedError(
+            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
+        )
+
+
+class LogitsProcessorList(list):
+    """
+    This class can be used to create a list of :class:`~transformers.LogitsProcessor` or
+    :class:`~transformers.LogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits from
+    list and adds a specific `__call__` method to apply each :class:`~transformers.LogitsProcessor` or
+    :class:`~transformers.LogitsProcessor` to the inputs.
+    """
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        for processor in self:
+            scores = processor(input_ids, scores)
+        return scores
+
+
+class MinLengthLogitsProcessor(LogitsProcessor):
+    r"""
+    :class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0.
+
+    Args:
+        min_length (:obj:`int`):
+            The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`.
+        eos_token_id (:obj:`int`):
+            The id of the `end-of-sequence` token.
+    """
+
+    def __init__(self, min_length: int, eos_token_ids: Union[List[int], int]):
+        if not isinstance(min_length, int) or min_length < 0:
+            raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")
+        if isinstance(eos_token_ids, int):
+            eos_token_ids = [eos_token_ids]
+        for eos_token_id in eos_token_ids:
+            if not isinstance(eos_token_id, int) or eos_token_id < 0:
+                raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
+
+        self.min_length = min_length
+        self.eos_token_ids = eos_token_ids
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        cur_len = input_ids.shape[-1]
+        if cur_len < self.min_length:
+            for eos_token_id in self.eos_token_ids:
+                scores[:, eos_token_id] = -float("inf")
+        return scores
+
+
+class NoRepeatNGramLogitsProcessor(LogitsProcessor):
+    r"""
+    :class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq
+    <https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345>`__.
+
+    Args:
+        ngram_size (:obj:`int`):
+            All ngrams of size :obj:`ngram_size` can only occur once.
+    """
+
+    def __init__(self, ngram_size: int):
+        if not isinstance(ngram_size, int) or ngram_size <= 0:
+            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
+        self.ngram_size = ngram_size
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        num_batch_hypotheses = scores.shape[0]
+        cur_len = input_ids.shape[-1]
+        banned_batch_tokens = self._calc_banned_ngram_tokens(input_ids, num_batch_hypotheses, cur_len)
+
+        for i, banned_tokens in enumerate(banned_batch_tokens):
+            scores[i, banned_tokens] = -float("inf")
+
+        return scores
+
+    def _calc_banned_ngram_tokens(
+            self, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
+    ) -> List[Iterable[int]]:
+        """Copied from fairseq for no_repeat_ngram in beam_search"""
+        if cur_len + 1 < self.ngram_size:
+            # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
+            return [[] for _ in range(num_hypos)]
+        generated_ngrams = [{} for _ in range(num_hypos)]
+        for idx in range(num_hypos):
+            gen_tokens = prev_input_ids[idx].tolist()
+            generated_ngram = generated_ngrams[idx]
+            for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]):
+                prev_ngram_tuple = tuple(ngram[:-1])
+                generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
+
+        def _get_generated_ngrams(hypo_idx):
+            # Before decoding the next token, prevent decoding of ngrams that have already appeared
+            start_idx = cur_len + 1 - self.ngram_size
+            ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
+            return generated_ngrams[hypo_idx].get(ngram_idx, [])
+
+        banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
+        return banned_tokens
+
+
+class BeamSearchStrategy:
+    def __init__(self, num_beams, max_length, length_penalty, end_tokens, device='cuda', no_repeat_ngram_size=0,
+                 min_tgt_length=0):
+        self.num_beams = num_beams
+        self.max_length = max_length
+        self.length_penalty = length_penalty
+        self.end_tokens = end_tokens
+        self.no_repeat_ngram_size = no_repeat_ngram_size
+        self.min_tgt_length = min_tgt_length
+        self.processors = LogitsProcessorList()
+        if min_tgt_length > 0:
+            processor = MinLengthLogitsProcessor(min_tgt_length, self.end_tokens)
+            self.processors.append(processor)
+        if no_repeat_ngram_size > 0:
+            processor = NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)
+            self.processors.append(processor)
+        self.beam_scorer = BeamSearchScorer(
+            batch_size=1,
+            max_length=max_length,
+            num_beams=num_beams,
+            device=device,
+            length_penalty=length_penalty,
+            do_early_stopping=False,
+        )
+        self.beam_scores = torch.zeros(1, dtype=torch.float, device=device)
+
+    @property
+    def is_done(self) -> bool:
+        return self.beam_scorer.is_done
+
+    def forward(self, logits, tokens, mems):
+        last_beam_num = tokens.size(0)
+        logits = self.processors(tokens, logits)
+        next_token_scores = F.log_softmax(logits, dim=-1)
+        next_token_scores = next_token_scores + self.beam_scores[:, None].expand_as(next_token_scores)
+        vocab_size = next_token_scores.shape[-1]
+        next_token_scores = next_token_scores.view(1, last_beam_num * vocab_size)
+
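+        # sample 2 * num_beams candidates (stochastic beam search) so that up to
+        # num_beams hypotheses survive even if some candidates are end tokens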
+        probs = F.softmax(next_token_scores, dim=-1)
+        next_tokens = torch.multinomial(probs, num_samples=2 * self.num_beams)
+        next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
+        next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
+        next_tokens = torch.gather(next_tokens, -1, _indices)
+
+        next_indices = next_tokens // vocab_size
+        next_tokens = next_tokens % vocab_size
+        # stateless
+        tokens = tokens.expand((self.num_beams, -1))
+        beam_outputs = self.beam_scorer.process(
+            tokens,
+            next_token_scores,
+            next_tokens,
+            next_indices,
+            eos_token_id=self.end_tokens,
+            mems=mems
+        )
+        self.beam_scores = beam_outputs["next_beam_scores"]
+        beam_next_tokens = beam_outputs["next_beam_tokens"]
+        beam_idx = beam_outputs["next_beam_indices"]
+        beam_next_tokens = beam_next_tokens.unsqueeze(-1)
+        tokens = torch.cat([tokens[beam_idx, :], beam_next_tokens], dim=-1)
+        mems = [mem[beam_idx] for mem in mems] if mems else None
+        return tokens, mems
+
+    def finalize(self, tokens, mems):
+        tokens, mems, scores = self.beam_scorer.finalize(tokens, self.beam_scores,
+                                                         eos_token_id=self.end_tokens[0],
+                                                         mems=mems)
+        return tokens, mems
diff --git a/inference_cogview.py b/inference_cogview.py
index dd4c90f8934ab576f8cbc4e783a7cbe1eea065c8..f5ecb78d9d8ee061c8eb757569381d904e4ab953 100644
--- a/inference_cogview.py
+++ b/inference_cogview.py
@@ -37,8 +37,8 @@ def main(args):
     # define function for each query
     query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024' if not args.full_query else '{}'
     invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
-    strategy = BaseStrategy(invalid_slices, 
-        temperature=args.temperature, topk=args.top_k)
+    strategy = BaseStrategy(invalid_slices,
+                            temperature=args.temperature, top_k=args.top_k)
     
     def process(raw_text):
         if args.with_id:
diff --git a/inference_cogview2.py b/inference_cogview2.py
index 75ad1c6c269d4bd0eb8eb03a65e1b05f45d78c3b..c69e589d71009b40d3620f5052086dbb9509992b 100644
--- a/inference_cogview2.py
+++ b/inference_cogview2.py
@@ -42,8 +42,8 @@ def main(args):
     # define function for each query
     query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024 [EOI1]' if not args.full_query else '{}'
     invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
-    strategy0 = BaseStrategy(invalid_slices, 
-        temperature=args.temperature, topk=args.top_k)
+    strategy0 = BaseStrategy(invalid_slices,
+                             temperature=args.temperature, top_k=args.top_k)
     strategy1 = IterativeEntfilterStrategy(invalid_slices,
         temperature=args.temperature, topk=10) # temperature not used
     tr = transforms.Compose([
diff --git a/inference_glm.py b/inference_glm.py
new file mode 100644
index 0000000000000000000000000000000000000000..792efd5e030fa1564a7f52310b261844eb113c03
--- /dev/null
+++ b/inference_glm.py
@@ -0,0 +1,176 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   inference_glm.py
+@Time    :   2021/10/09 19:41:58
+@Author  :   Ming Ding
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import random
+import time
+from datetime import datetime
+import torch
+import torch.nn.functional as F
+
+import mpu
+from arguments import get_args
+from model.glm_model import GLMModel
+from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
+from generation.glm_sampling import filling_sequence_glm
+from generation.sampling_strategies import BeamSearchStrategy, BaseStrategy
+
+
+def read_context(tokenizer, args, output=None):
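+    # Rank 0 reads the prompt interactively, appends a [MASK]/[gMASK] token when the
+    # prompt has none, and broadcasts the tokenized context to the other ranks.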
+    terminate_runs, skip_run = 0, 0
+    if mpu.get_model_parallel_rank() == 0:
+        while True:
+            raw_text = input("\nContext prompt (stop to exit) >>> ")
+            if not raw_text:
+                print('Prompt should not be empty!')
+                continue
+            if raw_text == "stop":
+                terminate_runs = 1
+                break
+            generation_mask = '[gMASK]' if args.task_mask else '[MASK]'
+            if args.block_lm and 'MASK]' not in raw_text:
+                raw_text += ' ' + generation_mask
+            if output is not None:
+                output.write(raw_text)
+            context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
+            if args.block_lm:
+                context_tokens = [tokenizer.get_command('ENC').Id] + context_tokens
+                if not raw_text.endswith('MASK]'):
+                    context_tokens = context_tokens + [tokenizer.get_command('eos').Id]
+            context_length = len(context_tokens)
+
+            if context_length >= args.max_sequence_length:
+                print("\nContext length", context_length,
+                      "\nPlease give smaller context than the window length!")
+                continue
+            break
+    else:
+        context_length = 0
+
+    terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
+    torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(),
+                                group=mpu.get_model_parallel_group())
+    terminate_runs = terminate_runs_tensor[0].item()
+
+    if terminate_runs == 1:
+        return terminate_runs, None, None, None
+
+    context_length_tensor = torch.cuda.LongTensor([context_length])
+
+    torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(),
+                                group=mpu.get_model_parallel_group())
+    context_length = context_length_tensor[0].item()
+    if mpu.get_model_parallel_rank() == 0:
+        context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
+    else:
+        context_tokens_tensor = torch.cuda.LongTensor([0] * context_length)
+    torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(),
+                                group=mpu.get_model_parallel_group())
+    if mpu.get_model_parallel_rank() != 0:
+        raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist())
+    return terminate_runs, raw_text, context_tokens_tensor, context_length
+
+
+def get_batch(context_tokens, args):
+    tokens = context_tokens
+    tokens = tokens.view(1, -1).contiguous()
+    tokens = tokens.to('cuda')
+
+    # Get the masks and position ids.
+    if args.block_lm:
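+        # the context is encoded with full bidirectional attention; block position ids
+        # (the second row of the 2D position ids) are zero for context tokens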
+        attention_mask = torch.ones(tokens.size(1), tokens.size(1), device='cuda', dtype=torch.long)
+        if args.fp16:
+            attention_mask = attention_mask.half()
+        position_ids = torch.arange(tokens.size(1), device='cuda', dtype=torch.long)
+        if not args.no_block_position:
+            block_position_ids = torch.zeros(tokens.size(1), device='cuda', dtype=torch.long)
+            position_ids = torch.stack((position_ids, block_position_ids), dim=0)
+        position_ids = position_ids.unsqueeze(0)
+    else:
+        raise NotImplementedError
+
+    return tokens, attention_mask, position_ids
+
+
+def generate_samples(model, tokenizer, args):
+    model.eval()
+    output_path = "./samples"
+    if not os.path.exists(output_path):
+        os.makedirs(output_path)
+    output_path = os.path.join(output_path, f"sample-{datetime.now().strftime('%m-%d-%H-%M')}.txt")
+    with torch.no_grad(), open(output_path, "w") as output:
+        while True:
+            torch.distributed.barrier(group=mpu.get_model_parallel_group())
+            terminate_runs, raw_text, context_tokens_tensor, context_length = read_context(tokenizer, args, output)
+            if terminate_runs == 1:
+                return
+            start_time = time.time()
+            if args.block_lm:
+                mems = []
+                tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args)
+                mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
+                mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
+                end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id]
+                mask_positions = []
+                for token in mask_tokens:
+                    mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist()
+                mask_positions.sort()
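+                # encode the whole context once to build the memory cache, then fill
+                # each mask span in turn, appending the generated tokens to `tokens`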
+                if args.no_block_position:
+                    for mask_position in mask_positions:
+                        position_ids[0, mask_position + 1:] += args.out_seq_length
+                _, *mems = model(tokens, position_ids, attention_mask, *mems)
+                for mask_position in mask_positions:
+                    if args.no_block_position:
+                        position = position_ids[0, mask_position].item()
+                    else:
+                        position = mask_position
+                    if args.num_beams > 1:
+                        strategy = BeamSearchStrategy(num_beams=args.num_beams, max_length=args.out_seq_length,
+                                                      length_penalty=args.length_penalty, end_tokens=end_tokens,
+                                                      no_repeat_ngram_size=args.no_repeat_ngram_size,
+                                                      min_tgt_length=args.min_tgt_length)
+                    else:
+                        strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
+                                                end_tokens=end_tokens)
+                    new_tokens, mems = filling_sequence_glm(model, tokenizer, position, strategy, args, mems=mems,
+                                                            end_tokens=end_tokens)
+                    tokens = torch.cat((tokens, new_tokens), dim=1)
+            output_tokens_list = tokens.view(-1).contiguous()
+            if mpu.get_model_parallel_rank() == 0:
+                os.system('clear')
+                print("\nTime taken: {:.2f}s\n".format(time.time() - start_time), flush=True)
+                print("\nContext:", raw_text, flush=True)
+                decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
+                trim_decode_tokens = decode_tokens
+                print("\nGLM:", trim_decode_tokens, flush=True)
+                output.write(trim_decode_tokens + "\n")
+
+            torch.distributed.barrier(group=mpu.get_model_parallel_group())
+
+
+def main(args):
+    initialize_distributed(args)
+    tokenizer = prepare_tokenizer(args)
+    # build model
+    model = GLMModel(args)
+    if args.fp16:
+        model = model.half()
+    model = model.to(args.device)
+    load_checkpoint(model, args)
+    set_random_seed(args.seed)
+    model.eval()
+    generate_samples(model, tokenizer, args)
+
+
+if __name__ == "__main__":
+    args = get_args()
+
+    with torch.no_grad():
+        main(args)
diff --git a/model/glm_model.py b/model/glm_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..96502d1f7751228b752806f0b80376ece9410170
--- /dev/null
+++ b/model/glm_model.py
@@ -0,0 +1,18 @@
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .cached_autoregressive_model import CachedAutoregressiveModel
+
+
+class GLMModel(CachedAutoregressiveModel):
+    def __init__(self, args, transformer=None):
+        super().__init__(args, transformer=transformer)
+        self.transformer.block_position_embeddings = torch.nn.Embedding(args.max_sequence_length, args.hidden_size)
+        torch.nn.init.normal_(self.transformer.block_position_embeddings.weight, mean=0.0, std=0.02)
+
+    def position_embedding_forward(self, position_ids, *other_tensors):
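+        # position_ids has shape [batch, 2, seq_len]: the first row holds absolute
+        # positions, the second holds intra-block positions for GLM's 2D encoding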
+        position_ids, block_position_ids = position_ids[:, 0], position_ids[:, 1]
+        position_embeddings = self.transformer.position_embeddings(position_ids)
+        block_position_embeddings = self.transformer.block_position_embeddings(block_position_ids)
+        return position_embeddings + block_position_embeddings
diff --git a/mpu/transformer.py b/mpu/transformer.py
index 41963696fd24419dd350c1aa561383771e7d662a..cdc34a84fc0681178f7795cf9d3a43b1d92b257b 100755
--- a/mpu/transformer.py
+++ b/mpu/transformer.py
@@ -320,8 +320,6 @@ class BaseTransformer(torch.nn.Module):
         # sanity check 
         assert len(input_ids.shape) == 2 
         batch_size, query_length = input_ids.shape
-        assert len(position_ids.shape) <= 2
-        assert position_ids.shape[-1] == query_length
         assert len(attention_mask.shape) == 2 or \
             len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
         assert branch_input is None or 'layer_forward' in self.hooks and isinstance(branch_input, torch.Tensor)
@@ -340,6 +338,8 @@ class BaseTransformer(torch.nn.Module):
         if 'position_embedding_forward' in self.hooks:
             position_embeddings = self.hooks['position_embedding_forward'](position_ids, **kw_tensors)
         else:
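+            # these checks only apply to the default 1D position embeddings; a
+            # position_embedding_forward hook (e.g. GLM's 2D block positions) may
+            # receive position_ids with a different shape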
+            assert len(position_ids.shape) <= 2
+            assert position_ids.shape[-1] == query_length
             position_embeddings = self.position_embeddings(position_ids)    
         hidden_states = hidden_states + position_embeddings
         hidden_states = self.embedding_dropout(hidden_states)
diff --git a/scripts/generate_glm.sh b/scripts/generate_glm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..143d7b5baa08029075b6c360f1b252a445b5a686
--- /dev/null
+++ b/scripts/generate_glm.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+CHECKPOINT_PATH=/dataset/fd5061f6/english_data/checkpoints
+
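+# the first argument is a model config script that defines MODEL_TYPE and MODEL_ARGS,
+# e.g. config/model_glm_10B.sh or config/model_glm_roberta_large.sh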
+source $1
+
+MPSIZE=1
+MAXSEQLEN=512
+MASTER_PORT=$(shuf -n 1 -i 10000-65535)
+
+#SAMPLING ARGS
+TEMP=0.9
+# A TOPK/TOPP of 0 disables top-k/top-p filtering; when both are set, top-k is applied before top-p
+TOPK=40
+TOPP=0
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+
+config_json="$script_dir/ds_config.json"
+
+python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTER_PORT inference_glm.py \
+       --mode inference \
+       --model-parallel-size $MPSIZE \
+       $MODEL_ARGS \
+       --num-beams 4 \
+       --no-repeat-ngram-size 3 \
+       --length-penalty 0.7 \
+       --fp16 \
+       --out-seq-length $MAXSEQLEN \
+       --temperature $TEMP \
+       --top_k $TOPK \
+       --top_p $TOPP
diff --git a/tokenization/__init__.py b/tokenization/__init__.py
index a6b481dc7ec60f9aed1acae5963726187cce81bb..f465ec4713f52132a32f198731a01825228613fc 100644
--- a/tokenization/__init__.py
+++ b/tokenization/__init__.py
@@ -13,24 +13,36 @@ import math
 import random
 import torch
 
+
 def get_tokenizer(args=None):
+    kwargs = {"add_block_symbols": args.block_lm, "add_task_mask": args.task_mask,
+              "add_decoder_mask": args.block_mask_prob > 0.0 or args.context_mask_ratio > 0.0}
     if not hasattr(get_tokenizer, 'tokenizer'):
         # the first time to load the tokenizer
         if args.tokenizer_type == 'cogview':
             from .cogview import UnifiedTokenizer
             get_tokenizer.tokenizer = UnifiedTokenizer(
-                args.img_tokenizer_path, 
+                args.img_tokenizer_path,
                 device=torch.cuda.current_device()
-                )
-        elif args.tokenizer_type == 'glm':
-            pass
+            )
+        elif args.tokenizer_type == "BertWordPieceTokenizer":
+            from .text import BertWordPieceTokenizer
+            get_tokenizer.tokenizer = BertWordPieceTokenizer(args.tokenizer_model_type, **kwargs)
+        elif args.tokenizer_type == "GPT2BPETokenizer":
+            from .text import GPT2BPETokenizer
+            get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs)
+        elif args.tokenizer_type == "ChineseSPTokenizer":
+            from .text import ChineseSPTokenizer
+            get_tokenizer.tokenizer = ChineseSPTokenizer(**kwargs)
         else:
             assert args.vocab_size > 0
             get_tokenizer.tokenizer = FakeTokenizer(args.vocab_size)
     return get_tokenizer.tokenizer
 
+
 class FakeTokenizer(object):
     def __init__(self, num_tokens):
         self.num_tokens = num_tokens
+
     def __len__(self):
-        return self.num_tokens
\ No newline at end of file
+        return self.num_tokens
diff --git a/tokenization/file_utils.py b/tokenization/file_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4be142054687d5a88c3faac7499a469c6216715
--- /dev/null
+++ b/tokenization/file_utils.py
@@ -0,0 +1,250 @@
+# This file is provided as is from:
+#   https://github.com/huggingface/pytorch-pretrained-BERT
+# Please refer to their repository for copyright.
+
+"""
+Utilities for working with the local dataset cache.
+This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
+Copyright by the AllenNLP authors.
+"""
+from __future__ import (absolute_import, division, print_function, unicode_literals)
+
+import json
+import logging
+import os
+import shutil
+import tempfile
+from functools import wraps
+from hashlib import sha256
+import sys
+from io import open
+
+import boto3
+import requests
+from botocore.exceptions import ClientError
+from tqdm import tqdm
+
+from urllib.parse import urlparse
+
+try:
+    from pathlib import Path
+    PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
+                                                   Path.home() / '.pytorch_pretrained_bert'))
+except (AttributeError, ImportError):
+    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
+                                              os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+
+def url_to_filename(url, etag=None):
+    """
+    Convert `url` into a hashed filename in a repeatable way.
+    If `etag` is specified, append its hash to the url's, delimited
+    by a period.
+    """
+    url_bytes = url.encode('utf-8')
+    url_hash = sha256(url_bytes)
+    filename = url_hash.hexdigest()
+
+    if etag:
+        etag_bytes = etag.encode('utf-8')
+        etag_hash = sha256(etag_bytes)
+        filename += '.' + etag_hash.hexdigest()
+
+    return filename
+
+
+def filename_to_url(filename, cache_dir=None):
+    """
+    Return the url and etag (which may be ``None``) stored for `filename`.
+    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
+    """
+    if cache_dir is None:
+        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
+        cache_dir = str(cache_dir)
+
+    cache_path = os.path.join(cache_dir, filename)
+    if not os.path.exists(cache_path):
+        raise EnvironmentError("file {} not found".format(cache_path))
+
+    meta_path = cache_path + '.json'
+    if not os.path.exists(meta_path):
+        raise EnvironmentError("file {} not found".format(meta_path))
+
+    with open(meta_path, encoding="utf-8") as meta_file:
+        metadata = json.load(meta_file)
+    url = metadata['url']
+    etag = metadata['etag']
+
+    return url, etag
+
+
+def cached_path(url_or_filename, cache_dir=None):
+    """
+    Given something that might be a URL (or might be a local path),
+    determine which. If it's a URL, download the file and cache it, and
+    return the path to the cached file. If it's already a local path,
+    make sure the file exists and then return the path.
+    """
+    if cache_dir is None:
+        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+    if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
+        url_or_filename = str(url_or_filename)
+    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
+        cache_dir = str(cache_dir)
+
+    parsed = urlparse(url_or_filename)
+
+    if parsed.scheme in ('http', 'https', 's3'):
+        # URL, so get it from the cache (downloading if necessary)
+        return get_from_cache(url_or_filename, cache_dir)
+    elif os.path.exists(url_or_filename):
+        # File, and it exists.
+        return url_or_filename
+    elif parsed.scheme == '':
+        # File, but it doesn't exist.
+        raise EnvironmentError("file {} not found".format(url_or_filename))
+    else:
+        # Something unknown
+        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
+
+
+def split_s3_path(url):
+    """Split a full s3 path into the bucket name and path."""
+    parsed = urlparse(url)
+    if not parsed.netloc or not parsed.path:
+        raise ValueError("bad s3 path {}".format(url))
+    bucket_name = parsed.netloc
+    s3_path = parsed.path
+    # Remove '/' at beginning of path.
+    if s3_path.startswith("/"):
+        s3_path = s3_path[1:]
+    return bucket_name, s3_path
+
+
+def s3_request(func):
+    """
+    Wrapper function for s3 requests in order to create more helpful error
+    messages.
+    """
+
+    @wraps(func)
+    def wrapper(url, *args, **kwargs):
+        try:
+            return func(url, *args, **kwargs)
+        except ClientError as exc:
+            if int(exc.response["Error"]["Code"]) == 404:
+                raise EnvironmentError("file {} not found".format(url))
+            else:
+                raise
+
+    return wrapper
+
+
+@s3_request
+def s3_etag(url):
+    """Check ETag on S3 object."""
+    s3_resource = boto3.resource("s3")
+    bucket_name, s3_path = split_s3_path(url)
+    s3_object = s3_resource.Object(bucket_name, s3_path)
+    return s3_object.e_tag
+
+
+@s3_request
+def s3_get(url, temp_file):
+    """Pull a file directly from S3."""
+    s3_resource = boto3.resource("s3")
+    bucket_name, s3_path = split_s3_path(url)
+    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
+
+
+def http_get(url, temp_file):
+    req = requests.get(url, stream=True)
+    content_length = req.headers.get('Content-Length')
+    total = int(content_length) if content_length is not None else None
+    progress = tqdm(unit="B", total=total)
+    for chunk in req.iter_content(chunk_size=1024):
+        if chunk: # filter out keep-alive new chunks
+            progress.update(len(chunk))
+            temp_file.write(chunk)
+    progress.close()
+
+
+def get_from_cache(url, cache_dir=None):
+    """
+    Given a URL, look for the corresponding dataset in the local cache.
+    If it's not there, download it. Then return the path to the cached file.
+    """
+    if cache_dir is None:
+        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
+        cache_dir = str(cache_dir)
+
+    if not os.path.exists(cache_dir):
+        os.makedirs(cache_dir)
+
+    # Get eTag to add to filename, if it exists.
+    if url.startswith("s3://"):
+        etag = s3_etag(url)
+    else:
+        response = requests.head(url, allow_redirects=True)
+        if response.status_code != 200:
+            raise IOError("HEAD request failed for url {} with status code {}"
+                          .format(url, response.status_code))
+        etag = response.headers.get("ETag")
+
+    filename = url_to_filename(url, etag)
+
+    # get cache path to put the file
+    cache_path = os.path.join(cache_dir, filename)
+
+    if not os.path.exists(cache_path):
+        # Download to temporary file, then copy to cache dir once finished.
+        # Otherwise you get corrupt cache entries if the download gets interrupted.
+        with tempfile.NamedTemporaryFile() as temp_file:
+            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
+
+            # GET file object
+            if url.startswith("s3://"):
+                s3_get(url, temp_file)
+            else:
+                http_get(url, temp_file)
+
+            # we are copying the file before closing it, so flush to avoid truncation
+            temp_file.flush()
+            # shutil.copyfileobj() starts at the current position, so go to the start
+            temp_file.seek(0)
+
+            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
+            with open(cache_path, 'wb') as cache_file:
+                shutil.copyfileobj(temp_file, cache_file)
+
+            logger.info("creating metadata file for %s", cache_path)
+            meta = {'url': url, 'etag': etag}
+            meta_path = cache_path + '.json'
+            with open(meta_path, 'w', encoding="utf-8") as meta_file:
+                json.dump(meta, meta_file)
+
+            logger.info("removing temp file %s", temp_file.name)
+
+    return cache_path
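+
+
+# Illustrative usage of `get_from_cache` (not part of the original code; the URL
+# and cache directory below are hypothetical and only show the call pattern):
+#
+#     path = get_from_cache("https://example.com/models/vocab.json",
+#                           cache_dir="/tmp/pretrained_cache")
+#     with open(path, encoding="utf-8") as f:
+#         vocab = json.load(f)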
+
+
+def read_set_from_file(filename):
+    '''
+    Extract a de-duped collection (set) of text from a file.
+    Expected file format is one item per line.
+    '''
+    collection = set()
+    with open(filename, 'r', encoding='utf-8') as file_:
+        for line in file_:
+            collection.add(line.rstrip())
+    return collection
+
+
+def get_file_extension(path, dot=True, lower=True):
+    """Return the extension of `path`, with the leading dot unless `dot` is False, lower-cased unless `lower` is False."""
+    ext = os.path.splitext(path)[1]
+    ext = ext if dot else ext[1:]
+    return ext.lower() if lower else ext
diff --git a/tokenization/text/__init__.py b/tokenization/text/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d191be6f3dd1515cc157a6bcb287e4bfd9078f57
--- /dev/null
+++ b/tokenization/text/__init__.py
@@ -0,0 +1 @@
+from .tokenization import BertWordPieceTokenizer, GPT2BPETokenizer, ChineseSPTokenizer
diff --git a/tokenization/text/sp_tokenizer.py b/tokenization/text/sp_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b6430eff576d487c68588ce6c94148904c1520d
--- /dev/null
+++ b/tokenization/text/sp_tokenizer.py
@@ -0,0 +1,150 @@
+"""
+from https://github.com/openai/gpt-2/, changed for chinese
+"""
+import json
+import os
+import sentencepiece as spm
+
+"""
+SentencePiece is an unsupervised text tokenizer and detokenizer mainly for Neural Network-based text generation 
+systems where the vocabulary size is predetermined prior to the neural model training. SentencePiece implements 
+subword units (e.g., byte-pair-encoding (BPE) [Sennrich et al.]) and unigram language model [Kudo.]) with the 
+extension of direct training from raw sentences. SentencePiece allows us to make a purely end-to-end 
+system that does not depend on language-specific pre/postprocessing.
+https://github.com/google/sentencepiece
+
+pip install sentencepiece
+
+or  git clone https://github.com/google/sentencepiece.git
+python setup.py install
+
+"""
+PRETRAINED_MODEL_FILE = "pretrained/chinese_sentencepiece/cog-pretrain.model"
+
+
+def get_pairs(word):
+    """Return the set of adjacent symbol pairs in `word` (a tuple of symbols)."""
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
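+
+
+# Illustrative example (not part of the original code): for word = ("l", "o", "w"),
+# get_pairs(word) returns {("l", "o"), ("o", "w")}.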
+
+
+class Encoder:
+    """Byte-pair-encoding (BPE) encoder adapted from the GPT-2 reference implementation."""
+
+    def __init__(self, encoder, bpe_merges):
+        self.encoder = encoder
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        self.cache = {}
+        self.max_len = 0
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+        if not pairs:
+            return token
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        return [self.encoder.get(token, 1) for token in self.tokenize(text)]
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        return text
+
+    def tokenize(self, text):
+        bpe_tokens = []
+        bpe_tokens.extend(bpe_token for bpe_token in self.bpe(text).split(' '))
+        return bpe_tokens
+
+    def convert_tokens_to_ids(self, tokens):
+        return [self.encoder.get(token, 1) for token in tokens]
+
+
+class Encoder_SP:
+    """Thin wrapper around a trained SentencePiece model."""
+
+    def __init__(self, model_path):
+        self.sp = spm.SentencePieceProcessor()
+        self.sp.Load(model_path)
+
+    def encode(self, text):
+        """Encode a text string into a list of SentencePiece ids."""
+        return self.sp.EncodeAsIds(text)
+
+    def decode(self, tokens):
+        """Decode a list of SentencePiece ids back into a text string."""
+        text = [int(token) for token in tokens]
+        return self.sp.DecodeIds(text)
+
+    def tokenize(self, text):
+        return self.sp.EncodeAsPieces(text)
+
+    def convert_tokens_to_ids(self, tokens):
+        return [self.sp.PieceToId(token) for token in tokens]
+
+    def convert_token_to_id(self, token):
+        return self.sp.PieceToId(token)
+
+    def convert_id_to_token(self, idx):
+        return self.sp.IdToPiece(idx)
+
+
+def get_encoder(encoder_file, bpe_file):
+    # the branching below lets the same entry point also load sentencepiece models
+    filepath, filename = os.path.split(encoder_file)
+    shortname, extension = os.path.splitext(filename)
+
+    if (".model" == extension) and (bpe_file == ""):
+        return Encoder_SP(encoder_file)
+    else:
+        with open(encoder_file, 'r', encoding="utf-8") as f:
+            encoder = json.load(f)
+        with open(bpe_file, 'r', encoding="utf-8") as f:
+            bpe_data = f.read()
+        bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
+        return Encoder(
+            encoder=encoder,
+            bpe_merges=bpe_merges,
+        )
+
+
+def from_pretrained():
+    return get_encoder(PRETRAINED_MODEL_FILE, "")
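+
+
+# Illustrative usage (not part of the original code; assumes the pretrained model
+# referenced by PRETRAINED_MODEL_FILE exists on disk):
+#
+#     sp = from_pretrained()
+#     ids = sp.encode("低碳生活从我做起")
+#     text = sp.decode(ids)   # recovers the input (up to SentencePiece normalization)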
diff --git a/tokenization/text/tokenization.py b/tokenization/text/tokenization.py
new file mode 100644
index 0000000000000000000000000000000000000000..51d2009152f5ec796e92fa47b8ef9a84f9496d08
--- /dev/null
+++ b/tokenization/text/tokenization.py
@@ -0,0 +1,1254 @@
+# coding=utf-8
+# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for using and training tokenizers (char, wordpiece, sentencepiece)"""
+from collections import namedtuple
+import random
+import os
+import csv
+import torch
+import itertools
+
+import nltk
+from nltk import tokenize as nltk_tokenize
+import sentencepiece as spm
+
+from .wordpiece import BertTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
+
+from .tokenization_gpt2 import GPT2Tokenizer
+from . import sp_tokenizer
+import regex as re
+
+
+def make_tokenizer(tokenizer_type, corpus, model_path=None, vocab_size=None, model_type=None, pad_token=0,
+                   character_coverage=1.0, command_tokens=None, type_tokens=None, **kwargs):
+    """
+    Helper function to instantiate a tokenizer given common combinations of options.
+    """
+    tokenizer_class = tokenizer_type
+    if isinstance(tokenizer_class, str):
+        tokenizer_class = eval(tokenizer_class)
+    if tokenizer_class is BertWordPieceTokenizer:
+        return BertWordPieceTokenizer(model_type, **kwargs)
+    elif tokenizer_class is GPT2BPETokenizer:
+        if model_type is None:
+            model_type = 'gpt2'
+        return GPT2BPETokenizer(model_type, **kwargs)
+    elif tokenizer_class is ChineseSPTokenizer:
+        return ChineseSPTokenizer(**kwargs)
+    text_tokenizer = tokenizer_class(corpus=corpus, vocab_size=vocab_size, model_path=model_path, model_type=model_type,
+                                     pad_token=pad_token, character_coverage=character_coverage)
+    return Tokenizer(text_tokenizer, command_tokens, type_tokens)
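+
+
+# Illustrative usage of `make_tokenizer` (not part of the original code; the model
+# type and keyword arguments below are assumptions, not values mandated here):
+#
+#     tokenizer = make_tokenizer("BertWordPieceTokenizer", corpus=None,
+#                                model_type='bert-large-uncased',
+#                                add_block_symbols=True)
+#     ids = tokenizer.EncodeAsIds("An example sentence.").tokenization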
+
+
+class Tokenization(object):
+    """
+    Tokenization object holding a tokenization, the (processed) text, and the
+    original text. The tokenization can be stored as Ids or as tokens.
+
+    It also holds command tokens (pad, unk, etc.) for the tokenization.
+    This allows functions to pad/operate on tokenizations without having
+    access to the full tokenizer, just the tokenization.
+
+    Several standard array operations are implemented (insert, append, extend).
+    """
+
+    def __init__(self, tokenization, text=None, original_text=None, command_tokens=None, asIds=True):
+        self.tokenization = tokenization
+        self.text = text
+        if self.text is None:
+            self.text = self.tokenization
+        self.original_text = original_text
+        if self.original_text is None:
+            self.original_text = self.text
+        self.command_tokens = command_tokens
+        self.asIds = asIds
+        self.parse_command_tokens()
+
+    def set_command_tokens(self, command_tokens):
+        self.command_tokens = command_tokens
+        return self.parse_command_tokens()
+
+    def parse_command_tokens(self):
+        if self.command_tokens is None:
+            return
+        for command_token in self.command_tokens:
+            if self.asIds:
+                setattr(self, command_token.name, command_token.Id)
+            else:
+                setattr(self, command_token.name, command_token.token)
+
+    def __getitem__(self, index):
+        return self.tokenization[index]
+
+    def __len__(self):
+        return len(self.tokenization)
+
+    def insert(self, idx, other):
+        if isinstance(other, (CommandToken, TypeToken)):
+            self.tokenization.insert(idx, other.Id)
+            if idx == 0:
+                self.text = other.token + self.text
+                self.original_text = other.token + self.original_text
+            elif idx == len(self.tokenization) - 1:
+                self.text += other.token
+                self.original_text += other.token
+        elif isinstance(other, Tokenization):
+            self.tokenization = self.tokenization[:idx] + other.tokenization + self.tokenization[idx:]
+        else:
+            self.tokenization = self.tokenization[:idx] + list(other) + self.tokenization[idx:]
+
+    def append(self, other):
+        if isinstance(other, (CommandToken, TypeToken)):
+            self.tokenization.append(other.Id)
+            self.text += other.token
+            self.original_text += other.token
+        elif isinstance(other, Tokenization):
+            self.tokenization.extend(other.tokenization)
+            self.text += other.text
+            self.original_text += other.original_text
+        else:
+            self.tokenization.append(other)
+        return self
+
+    def extend(self, other):
+        if isinstance(other, (CommandToken, TypeToken)):
+            self.tokenization.append(other.Id)
+            self.text += other.token
+            self.original_text += other.token
+        elif isinstance(other, list) and isinstance(other[0], (CommandToken, TypeToken)):
+            self.tokenization.extend([o.Id for o in other])
+            self.text += ''.join(o.token for o in other)
+            self.original_text += ''.join(o.token for o in other)
+        elif isinstance(other, Tokenization):
+            self.tokenization.extend(other.tokenization)
+            self.text += other.text
+            self.original_text += other.original_text
+        else:
+            self.tokenization.extend(other)
+        return self
+
+
+"""define some default command tokens for the tokenizer to use"""
+token_format = "<{0}>"
+
+COMMAND_TUPLE = namedtuple('CommandToken', ('name', 'token', 'Id'))
+
+
+def prep_command_tokens(tokenlist, token_format=token_format):
+    return [CommandToken(tok[0], token_format.format(tok[0]), tok[1]) for tok in tokenlist]
+
+
+class CommandToken(object):
+    def __init__(self, name, token, Id, lstrip=False, rstrip=False):
+        self.name = name
+        self.token = token
+        self.Id = Id
+        self.lstrip = lstrip
+        self.rstrip = rstrip
+
+    def __str__(self):
+        return str(COMMAND_TUPLE(self.name, self.token, self.Id))
+
+
+DEFAULT_COMMAND_TOKENS = [
+    ('pad', 0),
+    ('eos', 1),
+    ('bos', 2),
+    ('unk', 3),
+    ('sep', 4),
+    ('L2R', 5),
+    ('ENC', 6),
+    ('MASK', 7),
+]
+DEFAULT_COMMAND_TOKENS = prep_command_tokens(DEFAULT_COMMAND_TOKENS)
+
+"""define some default type tokens for bert training"""
+
+TYPE_TUPLE = namedtuple('TypeToken', ('name', 'token', 'Id'))
+
+
+def prep_type_tokens(tokenlist, token_format=token_format):
+    return [TypeToken(tok[0], token_format.format(tok[0]), tok[1]) for tok in tokenlist]
+
+
+class TypeToken(object):
+    def __init__(self, name, token, Id):
+        self.name = name
+        self.token = token
+        self.Id = Id
+
+    def __str__(self):
+        return str(TYPE_TUPLE(self.name, self.token, self.Id))
+
+
+DEFAULT_TYPE_TOKENS = [
+    ('function', 0),
+    ('command', 1),
+    ('str0', 2),
+    ('str1', 3),
+    ('str2', 4),
+    ('embedding0', 5),
+    ('embedding1', 6),
+    ('embedding2', 7),
+    ('arg0', 8),
+    ('arg1', 9),
+    ('arg2', 10),
+]
+DEFAULT_TYPE_TOKENS = prep_type_tokens(DEFAULT_TYPE_TOKENS)
+
+
+class Tokenizer(object):
+    """
+    Tokenizer object that handles text tokenization, command tokens, and type tokens.
+
+    Command tokens and text tokens are stored together in one mapping of size
+    `len(text_tokenizer) + len(command_tokens)`. Command tokens occupy the first
+    `len(command_tokens)` ids; a text token with raw id `idx` is stored at
+    `idx + len(command_tokens)`.
+
+    Token types are stored in a separate mapping of size `len(type_tokens)`.
+    """
+
+    def __init__(self, text_tokenizer, command_tokens=None, type_tokens=None):
+        # set text tokenizer
+        self.text_tokenizer = text_tokenizer
+        if not hasattr(self, 'num_text_tokens'):
+            self.num_text_tokens = len(self.text_tokenizer)
+
+        # set command tokens
+        if command_tokens is None:
+            command_tokens = DEFAULT_COMMAND_TOKENS
+        self._command_tokens = command_tokens
+        self.command_name_map = {tok.name: tok for tok in self._command_tokens}
+        self.command_token_map = {tok.token: tok for tok in self._command_tokens}
+        self.command_id_map = {tok.Id: tok for tok in self._command_tokens}
+        if not hasattr(self, 'num_command_tokens'):
+            self.num_command_tokens = len(self._command_tokens)
+        if not hasattr(self, 'num_tokens'):
+            self.num_tokens = self.num_command_tokens + self.num_text_tokens
+
+        # set type tokens
+        if type_tokens is None:
+            type_tokens = DEFAULT_TYPE_TOKENS
+        self.type_tokens = type_tokens
+        self.type_name_map = {tok.name: tok for tok in self.type_tokens}
+        self.type_token_map = {tok.token: tok for tok in self.type_tokens}
+        self.type_id_map = {tok.Id: tok for tok in self.type_tokens}
+        if not hasattr(self, 'num_type_tokens'):
+            self.num_type_tokens = len(self.type_tokens)
+
+        # parse tokens and vocabs from tokenizer
+        self._tokens = list(self.command_token_map.keys()) + list(self.text_tokenizer.tokens)
+        self._vocab = {t: Id for Id, t in self.command_id_map.items()}
+        self._vocab.update({t: Id + self.num_command_tokens for t, Id in self.text_tokenizer.vocab.items()})
+
+        self._text_tokens = list(self.text_tokenizer.tokens)
+        self._text_token_vocab = {t: Id + self.num_command_tokens for t, Id in self.text_tokenizer.vocab.items()}
+
+        self._command_token_tokens = list(self.command_token_map.keys())
+        self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}
+
+        self._token_types = list(self.type_token_map.keys())
+        self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}
+
+    def __call__(self, text, process_fn=None):
+        """run preprocessing and encode text as Ids"""
+        return self.EncodeAsIds(text, process_fn=process_fn)
+
+    def __len__(self):
+        """total number of tokens"""
+        return self.num_tokens
+
+    def get_command(self, name):
+        """get command token corresponding to `name`"""
+        return self.command_name_map[name]
+
+    def get_type(self, name):
+        """get type token corresponding to `name`"""
+        return self.type_name_map[name]
+
+    @property
+    def tokens(self):
+        """list (or iterable) of all tokens for tokenizer"""
+        return self._tokens
+
+    @property
+    def vocab(self):
+        """dictionary mapping tokens to ids for tokenizer"""
+        return self._vocab
+
+    @property
+    def token_types(self):
+        """list (or iterable) of all token types for tokenizer"""
+        return self._token_types
+
+    @property
+    def token_type_vocab(self):
+        """dictionary mapping token types to ids for tokenizer"""
+        return self._token_type_vocab
+
+    @property
+    def command_tokens(self):
+        """list (or iterable) of all command tokens for tokenizer"""
+        return self._command_token_tokens
+
+    @property
+    def command_token_vocab(self):
+        """dictionary mapping command tokens to ids for tokenizer"""
+        return self._command_token_vocab
+
+    @property
+    def text_tokens(self):
+        """list (or iterable) of text tokens for text tokenizer"""
+        return self._text_tokens
+
+    @property
+    def text_token_vocab(self):
+        """dictionary mapping text tokens to ids for text tokenizer"""
+        return self._text_token_vocab
+
+    def EncodeAsIds(self, text, process_fn=None):
+        """
+        encode text using text tokenizer and shift Id values for command tokens
+        """
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+
+        def split_on_token(tok_extended: CommandToken, text):
+            result = []
+            tok = tok_extended.token
+            split_text = text.split(tok)
+            for i, sub_text in enumerate(split_text):
+                # CommandToken can control whitespace stripping around them.
+                # We use them for GPT2 and Roberta to have different behavior depending on the special token
+                # Cf. https://github.com/huggingface/transformers/pull/2778
+                # and https://github.com/huggingface/transformers/issues/3788
+                # Strip white spaces on the right
+                if tok_extended.rstrip and i > 0:
+                    # A bit counter-intuitive but we strip the left of the string
+                    # since tok_extended.rstrip means the special token is eating all white spaces on its right
+                    sub_text = sub_text.lstrip()
+                # Strip white spaces on the left
+                if tok_extended.lstrip and i < len(split_text) - 1:
+                    sub_text = sub_text.rstrip()  # Opposite here
+
+                if i == 0 and not sub_text:
+                    result.append(tok)
+                elif i == len(split_text) - 1:
+                    if sub_text:
+                        result.append(sub_text)
+                    else:
+                        pass
+                else:
+                    if sub_text:
+                        result.append(sub_text)
+                    result.append(tok)
+            return result
+
+        def split_on_tokens(tok_list, text):
+            if not text.strip():
+                return []
+            if not tok_list:
+                return self.text_tokenizer.encode(text)
+
+            tokenized_text = []
+            text_list = [text]
+            for tok in tok_list:
+                tokenized_text = []
+                for sub_text in text_list:
+                    if sub_text not in self._command_token_tokens:
+                        tokenized_text.extend(split_on_token(tok, sub_text))
+                    else:
+                        tokenized_text.append(sub_text)
+                text_list = tokenized_text
+
+            return list(
+                itertools.chain.from_iterable(
+                    (
+                        self._encode(token) if token not in self._command_token_tokens else [
+                            self.command_token_map[token].Id] for token in tokenized_text
+                    )
+                )
+            )
+
+        no_split_tokens = self._command_tokens
+        Ids = split_on_tokens(no_split_tokens, processed_text)
+        tokenization = Tokenization(Ids, processed_text, text)
+        tokenization.set_command_tokens(self._command_tokens)
+        return tokenization
+
+    def _encode(self, text):
+        raise NotImplementedError
+
+    def EncodeAsTokens(self, text, process_fn=None):
+        """
+        encode text as tokens using text tokenizer
+        """
+        tokenization = self.text_tokenizer.EncodeAsTokens(text, process_fn=process_fn)
+        tokenization.set_command_tokens(self._command_tokens)
+        return tokenization
+
+    def IdToToken(self, Id, type_token=False):
+        """convert Id to token accounting for command and type tokens"""
+        if isinstance(Id, (TypeToken, CommandToken)):
+            return Id.token
+        if type_token:
+            return self.type_id_map[Id].token
+        if Id < self.num_command_tokens:
+            return self.command_id_map[Id].token
+        return self.text_tokenizer.IdToToken(Id - self.num_command_tokens)
+
+    def TokenToId(self, token, type_token=False):
+        """convert token to Id accounting for command and type tokens"""
+        if isinstance(token, (TypeToken, CommandToken)):
+            return token.Id
+        if type_token:
+            return self.type_token_map[token].Id
+        if token in self.command_token_map:
+            return self.command_token_map[token].Id
+        return self.text_tokenizer.TokenToId(token) + self.num_command_tokens
+
+    def DecodeIds(self, Ids, type_token=False):
+        """
+        convert Ids to tokens accounting for command and type tokens, tokens
+        are joined and returned as a string.
+        """
+        if type_token:
+            return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
+        rtn_strs = []
+        current_str = []
+        if isinstance(Ids, Tokenization):
+            Ids = Ids.tokenization
+        for Id in Ids:
+            if isinstance(Id, CommandToken):
+                rtn_strs.append(self.text_tokenizer.DecodeIds(current_str))
+                current_str = []
+                rtn_strs.append(Id.token)
+            elif Id < self.num_command_tokens:
+                rtn_strs.append(self.text_tokenizer.DecodeIds(current_str))
+                current_str = []
+                rtn_strs.append(self.command_id_map[Id].token)
+            else:
+                current_str.append(Id - self.num_command_tokens)
+        if current_str != []:
+            rtn_strs.append(self.text_tokenizer.DecodeIds(current_str))
+        return ' '.join(rtn_strs)
+
+    def DecodeTokens(self, Tokens, type_token=False):
+        """
+        convert tokens to a string accounting for command and type tokens.
+        """
+        if type_token:
+            return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens)
+        rtn_strs = []
+        current_str = []
+        if isinstance(Tokens, Tokenization):
+            Tokens = Tokens.tokenization
+        for t in Tokens:
+            if isinstance(t, CommandToken):
+                rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str))
+                current_str = []
+                rtn_strs.append(t.token)
+            elif t in self.command_token_map:
+                rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str))
+                current_str = []
+                rtn_strs.append(t)
+            else:
+                current_str.append(t)
+        if current_str != []:
+            rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str))
+        return ' '.join(rtn_strs)
+
+
+class TextTokenizer(object):
+    """
+    Interface for text tokenizer
+    """
+
+    def __init__(self):
+        if not hasattr(self, 'num_text_tokens'):
+            self.num_text_tokens = 0
+        if not hasattr(self, 'num_tokens'):
+            self.num_tokens = self.num_text_tokens
+
+    def __call__(self, text, process_fn=None):
+        return self.EncodeAsIds(text, process_fn)
+
+    def __len__(self):
+        return self.num_text_tokens
+
+    @property
+    def tokens(self):
+        """list (or iterable) of text tokens for text tokenizer"""
+        raise NotImplementedError('TextTokenizer tokens property not implemented')
+
+    @property
+    def vocab(self):
+        """dictionary mapping tokens to ids"""
+        raise NotImplementedError('TextTokenizer vocab property not implemented')
+
+    @staticmethod
+    def exists(model_path):
+        """check if the filepath for a text tokenizer exists"""
+        raise NotImplementedError('TextTokenizer exists method not implemented')
+
+    def Train(self, corpus):
+        """train a tokenizer on a data corpus and save model for future use"""
+        raise NotImplementedError('TextTokenizer Train not implemented')
+
+    def EncodeAsIds(self, text, process_fn=None):
+        """
+        Preprocess text and encode as ids. Return a tokenization object with
+        original text, processed text, and id tokenization.
+        """
+        raise NotImplementedError('TextTokenizer EncodeAsIds not implemented')
+
+    def EncodeAsTokens(self, text, process_fn=None):
+        """
+        Preprocess text and encode as tokens. Return a tokenization object with
+        original text, processed text, and token tokenization.
+        """
+        raise NotImplementedError('TextTokenizer EncodeAsTokens not implemented')
+
+    def IdToToken(self, Id):
+        """Convert an Id to Token. Reverse lookup of self.vocab"""
+        raise NotImplementedError('TextTokenizer IdToToken not implemented')
+
+    def TokenToId(self, token):
+        """Convert a Token to Id. Lookup of self.vocab"""
+        raise NotImplementedError('TextTokenizer TokenToId not implemented')
+
+    def DecodeIds(self, Ids):
+        """Convert a list or tokenization object of Ids to a text string"""
+        raise NotImplementedError('TextTokenizer DecodeIds not implemented')
+
+    def DecodeTokens(self, Tokens):
+        """Convert a list or tokenization object of tokens to a text string"""
+        raise NotImplementedError('TextTokenizer DecodeTokens not implemented')
+
+
+class CharacterLevelTokenizer(TextTokenizer):
+    """
+    Text tokenizer for ASCII-256 Character Level Tokenization.
+    """
+
+    def __init__(self, **kwargs):
+        self.num_text_tokens = 256
+        super(CharacterLevelTokenizer, self).__init__()
+        self._tokens = [self.IdToToken(Id) for Id in range(self.num_text_tokens)]
+        self._vocab = {t: i for i, t in enumerate(self._tokens)}
+
+    def __len__(self):
+        return 256
+
+    @staticmethod
+    def exists(model_path):
+        return True
+
+    def Train(self, corpus):
+        pass
+
+    @property
+    def tokens(self):
+        return self._tokens
+
+    @property
+    def vocab(self):
+        return self._vocab
+
+    def EncodeAsIds(self, text, process_fn=None):
+        """convert text to ascii 256 Ids"""
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+            processed_text = str(processed_text)
+        tokens = [self.TokenToId(c) for c in processed_text]
+        return Tokenization(tokens, processed_text, text)
+
+    def EncodeAsTokens(self, text, process_fn=None):
+        """convert text to ascii 256 characters"""
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+        processed_text = str(processed_text)
+        tokens = [c for c in processed_text]
+        return Tokenization(tokens, processed_text, text, asIds=False)
+
+    def IdToToken(self, Id):
+        """ascii index to character"""
+        return chr(Id)
+
+    def TokenToId(self, token):
+        """ascii character to index"""
+        return ord(token)
+
+    def DecodeIds(self, Ids):
+        """converts ascii ids to tokens before joining them into text"""
+        if isinstance(Ids, Tokenization):
+            Ids = Ids.tokenization
+        return ''.join([self.IdToToken(tok) for tok in Ids])
+
+    def DecodeTokens(self, Tokens):
+        """just concatenates ascii tokens into text"""
+        if isinstance(Tokens, Tokenization):
+            Tokens = Tokens.tokenization
+        return ''.join(Tokens)
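+
+
+# Illustrative example (not part of the original code): CharacterLevelTokenizer()
+# encodes "hi" as the ordinals [104, 105], and DecodeIds([104, 105]) returns "hi".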
+
+
+MAX_SENTENCEPIECE_SENTENCES = 100000000
+
+
+def get_corpus_freq(dataset, filepath, filetype='tsv'):
+    """
+    Take corpus, split it into sentences, and extract word frequencies.
+    Write frequencies to `filepath` as a tsv. Only write the first
+    MAX_SENTENCEPIECE_SENTENCES most common words to the file.
+    """
+    nltk.download('punkt', download_dir="./nltk")
+    if filetype == 'tsv':
+        delimiter = '\t'
+    else:
+        delimiter = ','
+
+    print("compute corpus frequency\n", flush=True)
+
+    total_sentence_count = 0
+    maxlen = 0
+    freqs = {}
+    for entry in dataset:
+        if isinstance(entry, dict):
+            entry = entry['text']
+        lines = entry.strip().split('\n')
+        for line in lines:
+            sentences = nltk_tokenize.sent_tokenize(line)
+            total_sentence_count += len(sentences)
+            for sentence in sentences:
+                maxlen = max(len(line), maxlen)
+                for word in sentence.split():
+                    if word not in freqs:
+                        freqs[word] = 0
+                    freqs[word] += 1
+
+    print("length of freqs before truncating " + str(len(freqs)), flush=True)
+    print("file path for freq " + str(filepath), flush=True)
+
+    freqs_sorted = {}
+    counter = 0
+    for word, count in sorted(freqs.items(), key=lambda x: x[1], reverse=True):
+        if counter >= MAX_SENTENCEPIECE_SENTENCES:
+            break
+        counter += 1
+        freqs_sorted[word] = count
+
+    print("length of freqs after trancating " + str(len(freqs_sorted)), flush=True)
+
+    with open(filepath, 'w') as f:
+        writer = csv.writer(f, delimiter=delimiter)
+        for k, v in freqs_sorted.items():
+            writer.writerow([str(k), str(v)])
+
+    return total_sentence_count, maxlen
+
+
+class SentencePieceTokenizer(TextTokenizer):
+    """Trains and uses sentencepiece for text tokenization"""
+
+    def __init__(self, model_type='bpe', vocab_size=None, corpus=None, model_path=None, character_coverage=1.0,
+                 **kwargs):
+        self.character_coverage = character_coverage
+        self.model_type = model_type.lower()
+        self.spm_model = model_path
+        self.num_text_tokens = vocab_size
+        make_train = not SentencePieceTokenizer.exists(self.spm_model)
+        if make_train:
+            assert corpus is not None and self.num_text_tokens is not None
+            self.Train(corpus, self.num_text_tokens)
+        self._tokens = []
+        self._vocab = {}
+        self.load_spm_model()
+        super(SentencePieceTokenizer, self).__init__()
+
+    def __len__(self):
+        return self.num_text_tokens
+
+    @property
+    def tokens(self):
+        return self._tokens
+
+    @property
+    def vocab(self):
+        return self._vocab
+
+    @staticmethod
+    def exists(model_path):
+        if model_path is None:
+            return False
+        # check if path exists
+        dne = not os.path.exists(model_path)
+        # check if path.model exists
+        if dne and not model_path.endswith('.model'):
+            dne = not os.path.exists(model_path + '.model')
+        return not dne
+
+    def load_spm_model(self):
+        """load sentencepiece model and parse vocab"""
+        if not os.path.exists(self.spm_model) and not self.spm_model.endswith('.model'):
+            self.spm_model = self.spm_model + '.model'
+        self.sp = spm.SentencePieceProcessor()
+        self.sp.Load(self.spm_model)
+        self.vocab_size = self.num_text_tokens = len(self.sp)
+        self._tokens = [self.IdToToken(t) for t in range(self.vocab_size)]
+        self._vocab = {t: i for i, t in enumerate(self._tokens)}
+
+    def Train(self, corpus, num_text_tokens):
+        """train sentencepiece model on corpus using word frequencies"""
+        self.num_text_tokens = num_text_tokens
+        use_model_path = self.spm_model
+        random_hash = str(random.randint(0, 2147483647))
+        if use_model_path is None:
+            use_model_path = random_hash
+        if use_model_path.endswith('.model'):
+            use_model_path = use_model_path[:use_model_path.rfind('.model')]
+        input_path = use_model_path + '.tsv.' + random_hash
+        line_count, maxlenline = get_corpus_freq(corpus, input_path)
+        line_count = min(line_count, MAX_SENTENCEPIECE_SENTENCES)
+        print('line count used as input_sentence_size ', line_count, flush=True)
+        print('training sentencepiece model', flush=True)
+        train_string = '--input={file_path} --model_prefix={model_prefix} --vocab_size={vocab_size}' \
+                       + ' --model_type={model_type} --character_coverage={character_coverage} ' \
+                       + '--input_sentence_size={input_sentence_size} ' \
+                       + '--input_format=tsv'
+        train_string = train_string.format(file_path=input_path, model_prefix=use_model_path,
+                                           vocab_size=num_text_tokens,
+                                           model_type=self.model_type, character_coverage=self.character_coverage,
+                                           input_sentence_size=int(line_count))
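+        # For example (illustrative values only), the resulting flag string looks like:
+        #   --input=spm.tsv.12345 --model_prefix=spm --vocab_size=32000 --model_type=bpe
+        #   --character_coverage=1.0 --input_sentence_size=500000 --input_format=tsv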
+        print("calling spm.SentencePieceTrainer.Train(%s)" % (train_string), flush=True)
+        spm.SentencePieceTrainer.Train(train_string)
+        os.remove(input_path)
+        self.spm_model = use_model_path + '.model'
+        print('sentencepiece model written to ' + self.spm_model, flush=True)
+
+    def EncodeAsIds(self, text, process_fn=None):
+        """convert text to sentencepiece Ids"""
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+        tokens = self.sp.EncodeAsIds(processed_text)
+        return Tokenization(tokens, processed_text, text)
+
+    def EncodeAsTokens(self, text, process_fn=None):
+        """convert text to sentencepiece tokens"""
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+        tokens = self.sp.EncodeAsTokens(processed_text)
+        return Tokenization(tokens, processed_text, text, asIds=False)
+
+    def IdToToken(self, Id):
+        """convert Id to sentencpiece token"""
+        return self.sp.IdToPiece(Id)
+
+    def TokenToId(self, token):
+        """convert sentencpiece token to Id"""
+        return self.sp.PieceToId(token)
+
+    def DecodeIds(self, Ids):
+        """converts ids to a text string"""
+        if isinstance(Ids, Tokenization):
+            Ids = Ids.tokenization
+        return self.sp.DecodeIds(Ids)
+
+    def DecodeTokens(self, Tokens):
+        """converts sentencepiece tokens to a text string"""
+        if isinstance(Tokens, Tokenization):
+            Tokens = Tokens.tokenization
+        return self.sp.DecodeTokens(Tokens)
+
+
+class BertWordPieceTokenizer(Tokenizer):
+    """
+    Loads a pretrained WordPiece tokenizer from `cache_dir` for tokenization
+    in BERT training. Default to bert-large-uncased tokenizer.
+    """
+
+    def __init__(self, tokenizer_model_type=None, cache_dir=None, add_block_symbols=False, add_sentinel_token=0,
+                 add_task_mask=False, add_decoder_mask=False, **kwargs):
+        # default to bert-large-uncased tokenizer
+        if tokenizer_model_type not in PRETRAINED_VOCAB_ARCHIVE_MAP:
+            tokenizer_model_type = 'bert-large-uncased'
+        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+            print('loading BertWordPieceTokenizer (', tokenizer_model_type, ') from cache_dir ', cache_dir)
+        do_lower_case = not ('-cased' in tokenizer_model_type or 'chinese' in tokenizer_model_type)
+        self.text_tokenizer = BertTokenizer.from_pretrained(tokenizer_model_type, do_lower_case=do_lower_case,
+                                                            cache_dir=cache_dir)
+        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+            print('loaded', tokenizer_model_type)
+        # disable max len warnings by increasing max len
+        self.text_tokenizer.max_len = int(1e12)
+
+        # set command tokens from wordpiece tokenizer values
+        self.num_command_tokens = 6
+        self.num_tokens = len(self.text_tokenizer.vocab)
+        self.num_text_tokens = self.num_tokens - 5
+        self.num_type_tokens = 2
+
+        self._command_tokens = [
+            CommandToken('pad', '[PAD]', self.text_tokenizer.vocab['[PAD]']),
+            CommandToken('ENC', '[CLS]', self.text_tokenizer.vocab['[CLS]']),
+            CommandToken('MASK', '[MASK]', self.text_tokenizer.vocab['[MASK]']),
+            CommandToken('unk', '[UNK]', self.text_tokenizer.vocab['[UNK]']),
+            CommandToken('sep', '[SEP]', self.text_tokenizer.vocab['[SEP]']),
+            CommandToken('eos', '[PAD]', self.text_tokenizer.vocab['[PAD]']),
+        ]
+        if add_block_symbols:
+            self._command_tokens.extend([
+                CommandToken('sop', '<|startofpiece|>', self.num_tokens),
+                CommandToken('eop', '<|endofpiece|>', self.num_tokens + 1)
+            ])
+            self.num_tokens += 2
+            self.num_command_tokens += 2
+            if add_task_mask:
+                self._command_tokens.extend([
+                    CommandToken('gMASK', '[gMASK]', self.num_tokens),
+                    CommandToken('sMASK', '[sMASK]', self.num_tokens + 1)
+                ])
+                self.num_tokens += 2
+                self.num_command_tokens += 2
+            if add_decoder_mask:
+                self._command_tokens.extend([
+                    CommandToken('dBLOCK', '[dBLOCK]', self.num_tokens)
+                ])
+                self.num_tokens += 1
+                self.num_command_tokens += 1
+        if add_sentinel_token > 0:
+            for i in range(1, add_sentinel_token):
+                self._command_tokens.extend([CommandToken(f'MASK{i}', f'[MASK{i}]', self.num_tokens),
+                                             CommandToken(f'sop{i}', f'<|startofpiece{i}|>', self.num_tokens + 1)])
+                self.num_tokens += 2
+                self.num_command_tokens += 2
+        self.command_name_map = {tok.name: tok for tok in self._command_tokens}
+        self.command_token_map = {tok.token: tok for tok in self._command_tokens}
+        self.command_id_map = {tok.Id: tok for tok in self._command_tokens}
+
+        # set type tokens
+        self.type_tokens = [
+            TypeToken('str0', '<str0>', 0),
+            TypeToken('str1', '<str1>', 1),
+        ]
+        self.type_name_map = {tok.name: tok for tok in self.type_tokens}
+        self.type_token_map = {tok.token: tok for tok in self.type_tokens}
+        self.type_id_map = {tok.Id: tok for tok in self.type_tokens}
+
+        # parse tokens and vocabs from tokenizer
+
+        self._tokens = list(self.text_tokenizer.vocab.keys())
+        self._vocab = {k: v for k, v in self.text_tokenizer.vocab.items()}
+
+        self._text_tokens = list(self._tokens)
+        self._text_token_vocab = {k: v for k, v in self.text_tokenizer.vocab.items()}
+
+        self._command_token_tokens = list(self.command_token_map.keys())
+        self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}
+
+        self._token_types = list(self.type_token_map.keys())
+        self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}
+
+    def _encode(self, text):
+        tokens = self.text_tokenizer.tokenize(text)
+        ids = self.text_tokenizer.convert_tokens_to_ids(tokens)
+        return ids
+
+    def EncodeAsTokens(self, text, process_fn=None):
+        """convert wordpiece token to Id"""
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+        tokens = self.text_tokenizer.tokenize(processed_text)
+        return Tokenization(tokens, processed_text, text, asIds=False)
+
+    def IdToToken(self, Id, type_token=False):
+        """convert Id to sentencpiece token"""
+        if isinstance(Id, (TypeToken, CommandToken)):
+            return Id.token
+        if type_token:
+            return self.type_id_map[Id].token
+        if Id in self.command_id_map:
+            return self.command_id_map[Id].token
+        return self.text_tokenizer.ids_to_tokens[Id]
+
+    @staticmethod
+    def clean_up_tokenization(out_string: str) -> str:
+        """
+        Clean up a list of simple English tokenization artifacts like spaces before punctuations and abbreviated forms.
+
+        Args:
+            out_string (:obj:`str`): The text to clean up.
+
+        Returns:
+            :obj:`str`: The cleaned-up string.
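+
+        Example (illustrative)::
+
+            >>> BertWordPieceTokenizer.clean_up_tokenization("he 's here , isn ' t he ?")
+            "he's here, isn't he?"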
+        """
+        out_string = (
+            out_string.replace(" .", ".")
+                .replace(" ?", "?")
+                .replace(" !", "!")
+                .replace(" ,", ",")
+                .replace(" ' ", "'")
+                .replace(" n't", "n't")
+                .replace(" 'm", "'m")
+                .replace(" 's", "'s")
+                .replace(" 've", "'ve")
+                .replace(" 're", "'re")
+        )
+        return out_string
+
+    def TokenToId(self, token, type_token=False):
+        """convert sentencpiece token to Id"""
+        if isinstance(token, (TypeToken, CommandToken)):
+            return token.Id
+        if type_token:
+            return self.type_token_map[token].Id
+        return self.text_tokenizer.vocab[token]
+
+    def DecodeIds(self, Ids, type_token=False):
+        """converts ids to wordpiece tokens and joins them as a text string"""
+        if type_token:
+            return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
+        if isinstance(Ids, Tokenization):
+            Ids = Ids.tokenization
+        Tokens = []
+        for Id in Ids:
+            if Id in self.command_id_map:
+                Tokens.append(self.command_id_map[Id].token)
+            elif Id in self.text_tokenizer.ids_to_tokens:
+                Tokens.append(self.text_tokenizer.ids_to_tokens[Id])
+        new_tokens = []
+        for token in Tokens:
+            if token.startswith('##') and len(new_tokens) > 0:
+                new_tokens[-1] += token[2:]
+            else:
+                new_tokens.append(token)
+        output = ' '.join(new_tokens)
+        output = self.clean_up_tokenization(output)
+        return output
+
+    def DecodeTokens(self, Tokens, type_token=False):
+        """converts wordpiece tokens to a text string"""
+        if type_token:
+            return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens)
+        if isinstance(Tokens, Tokenization):
+            Tokens = Tokens.tokenization
+        return ' '.join(Tokens)
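+
+    # Illustrative usage (not part of the original code; assumes the
+    # 'bert-large-uncased' vocab referenced by PRETRAINED_VOCAB_ARCHIVE_MAP
+    # is available locally or downloadable):
+    #
+    #     tokenizer = BertWordPieceTokenizer('bert-large-uncased', add_block_symbols=True)
+    #     ids = tokenizer.EncodeAsIds("Hello, world!").tokenization
+    #     text = tokenizer.DecodeIds(ids)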
+
+
+class GPT2BPETokenizer(Tokenizer):
+    def __init__(self, model_type_or_path, cache_dir=None, add_block_symbols=False, add_task_mask=False,
+                 add_decoder_mask=False, **kwargs):
+        self.text_tokenizer = GPT2Tokenizer.from_pretrained(model_type_or_path,
+                                                            cache_dir=cache_dir)
+
+        # disable max len warnings by increasing max len
+        self.text_tokenizer.max_len = int(1e12)
+        self.num_tokens = len(self.text_tokenizer.encoder)
+        self.num_type_tokens = 2
+        if model_type_or_path.startswith('roberta'):
+            self.num_command_tokens = 6
+            self.num_text_tokens = self.num_tokens - 3
+            self._command_tokens = [
+                CommandToken('pad', '<|endoftext|>', self.text_tokenizer.encoder['</s>']),
+                CommandToken('eos', '<|endoftext|>', self.text_tokenizer.encoder['</s>']),
+                CommandToken('sep', '[SEP]', self.text_tokenizer.encoder['<pad>']),
+                CommandToken('ENC', '[CLS]', self.text_tokenizer.encoder['<s>']),
+                CommandToken('MASK', '[MASK]', self.text_tokenizer.encoder['<mask>'], lstrip=True),
+                CommandToken('unk', '[UNK]', self.text_tokenizer.encoder['<unk>'])
+            ]
+            if add_block_symbols:
+                self._command_tokens.extend([
+                    CommandToken('sop', '<|startofpiece|>', self.num_tokens),
+                    CommandToken('eop', '<|endofpiece|>', self.num_tokens + 1)
+                ])
+                self.num_tokens += 2
+                self.num_command_tokens += 2
+        else:
+            self.num_command_tokens = 2
+            self.num_text_tokens = self.num_tokens - 1
+            self._command_tokens = [
+                CommandToken('pad', '<|endoftext|>', self.text_tokenizer.encoder['<|endoftext|>']),
+                CommandToken('eos', '<|endoftext|>', self.text_tokenizer.encoder['<|endoftext|>'])
+            ]
+            if add_block_symbols:
+                self._command_tokens.extend([
+                    CommandToken('sop', '<|startofpiece|>', self.num_tokens),
+                    CommandToken('eop', '<|endofpiece|>', self.num_tokens + 1),
+                    CommandToken('ENC', '[CLS]', self.num_tokens + 2),
+                    CommandToken('MASK', '[MASK]', self.num_tokens + 3, lstrip=True),
+                    CommandToken('sep', '[SEP]', self.num_tokens + 4),
+                    CommandToken('unk', '[UNK]', self.num_tokens + 5)
+                ])
+                self.num_tokens += 6
+                self.num_command_tokens += 6
+        if add_block_symbols:
+            if add_task_mask:
+                self._command_tokens.extend([
+                    CommandToken('gMASK', '[gMASK]', self.num_tokens, lstrip=True),
+                    CommandToken('sMASK', '[sMASK]', self.num_tokens + 1, lstrip=True)
+                ])
+                self.num_tokens += 2
+                self.num_command_tokens += 2
+            if add_decoder_mask:
+                self._command_tokens.extend([
+                    CommandToken('dBLOCK', '[dBLOCK]', self.num_tokens)
+                ])
+                self.num_tokens += 1
+                self.num_command_tokens += 1
+        self.command_name_map = {tok.name: tok for tok in self._command_tokens}
+        self.command_token_map = {tok.token: tok for tok in self._command_tokens}
+        self.command_id_map = {tok.Id: tok for tok in self._command_tokens}
+
+        self.type_tokens = [
+            TypeToken('str0', '<str0>', 0),
+            TypeToken('str1', '<str1>', 1),
+        ]
+        self.type_name_map = {tok.name: tok for tok in self.type_tokens}
+        self.type_token_map = {tok.token: tok for tok in self.type_tokens}
+        self.type_id_map = {tok.Id: tok for tok in self.type_tokens}
+
+        self._tokens = list(self.text_tokenizer.encoder.keys())
+        self._vocab = {k: v for k, v in self.text_tokenizer.encoder.items()}
+
+        self._text_tokens = list(self._tokens)
+        self._text_token_vocab = {k: v for k, v in self.text_tokenizer.encoder.items()}
+
+        self._command_token_tokens = list(self.command_token_map.keys())
+        self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}
+
+        self._token_types = list(self.type_token_map.keys())
+        self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}
+
+        for idx, tok in self.command_id_map.items():
+            self.text_tokenizer.decoder[idx] = tok.token
+
+    def EncodeAsIds(self, text, process_fn=None):
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+
+        def split_on_token(tok_extended: CommandToken, text):
+            result = []
+            tok = tok_extended.token
+            split_text = text.split(tok)
+            for i, sub_text in enumerate(split_text):
+                # CommandToken can control whitespace stripping around them.
+                # We use them for GPT2 and Roberta to have different behavior depending on the special token
+                # Cf. https://github.com/huggingface/transformers/pull/2778
+                # and https://github.com/huggingface/transformers/issues/3788
+                # Strip white spaces on the right
+                if tok_extended.rstrip and i > 0:
+                    # A bit counter-intuitive but we strip the left of the string
+                    # since tok_extended.rstrip means the special token is eating all white spaces on its right
+                    sub_text = sub_text.lstrip()
+                # Strip white spaces on the left
+                if tok_extended.lstrip and i < len(split_text) - 1:
+                    sub_text = sub_text.rstrip()  # Opposite here
+
+                if i == 0 and not sub_text:
+                    result.append(tok)
+                elif i == len(split_text) - 1:
+                    if sub_text:
+                        result.append(sub_text)
+                    else:
+                        pass
+                else:
+                    if sub_text:
+                        result.append(sub_text)
+                    result.append(tok)
+            return result
+
+        def split_on_tokens(tok_list, text):
+            if not text.strip():
+                return []
+            if not tok_list:
+                return self.text_tokenizer.encode(text)
+
+            tokenized_text = []
+            text_list = [text]
+            for tok in tok_list:
+                tokenized_text = []
+                for sub_text in text_list:
+                    if sub_text not in self._command_token_tokens:
+                        tokenized_text.extend(split_on_token(tok, sub_text))
+                    else:
+                        tokenized_text.append(sub_text)
+                text_list = tokenized_text
+
+            return list(
+                itertools.chain.from_iterable(
+                    (
+                        self.text_tokenizer.encode(token) if token not in self._command_token_tokens else [
+                            self.command_token_map[token].Id] for token in tokenized_text
+                    )
+                )
+            )
+
+        no_split_tokens = self._command_tokens
+        Ids = split_on_tokens(no_split_tokens, processed_text)
+        tokenization = Tokenization(Ids, processed_text, text)
+        tokenization.set_command_tokens(self._command_tokens)
+        return tokenization
+
+    def _encode(self, text):
+        return self.text_tokenizer.encode(text)
+
+    def EncodeAsTokens(self, text, process_fn=None):
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+        tokens = []
+        for token in re.findall(self.text_tokenizer.pat, processed_text):
+            token = ''.join(self.text_tokenizer.byte_encoder[b] for b in token.encode('utf-8'))
+            tokens.extend(bpe_token for bpe_token in self.text_tokenizer.bpe(token).split(' '))
+        tokenization = Tokenization(tokens, processed_text, text, asIds=False)
+        tokenization.set_command_tokens(self._command_tokens)
+        return tokenization
+
+    def DecodeAsTokens(self, Ids):
+        return [self.IdToToken(x) for x in Ids]
+
+    def IdToToken(self, Id, type_token=False):
+        if isinstance(Id, (TypeToken, CommandToken)):
+            return Id.token
+        if type_token:
+            return self.type_id_map[Id].token
+        if Id in self.command_id_map:
+            return self.command_id_map[Id].token
+        return self.text_tokenizer.decoder[Id]
+
+    def TokenToId(self, token, type_token=False):
+        if isinstance(token, (TypeToken, CommandToken)):
+            return token.Id
+        if type_token:
+            return self.type_token_map[token].Id
+        return self.text_tokenizer.encoder[token]
+
+    def DecodeIds(self, Ids, type_token=False):
+        if type_token:
+            return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
+        if isinstance(Ids, Tokenization):
+            Ids = Ids.tokenization
+        return self.text_tokenizer.decode(Ids)
+
+    def DecodeTokens(self, Tokens, type_token=False):
+        if type_token:
+            return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens)
+        if isinstance(Tokens, Tokenization):
+            Tokens = Tokens.tokenization
+        return self.text_tokenizer.decode([self.TokenToId(tok) for tok in Tokens])
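+
+    # Illustrative usage (not part of the original code; assumes the 'gpt2'
+    # vocab/merges files referenced by tokenization_gpt2.PRETRAINED_VOCAB_ARCHIVE_MAP
+    # are available):
+    #
+    #     tokenizer = GPT2BPETokenizer('gpt2', add_block_symbols=True)
+    #     ids = tokenizer.EncodeAsIds("Hello, world!").tokenization
+    #     text = tokenizer.DecodeIds(ids)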
+
+
+class ChineseSPTokenizer(Tokenizer):
+    def __init__(self, add_block_symbols=False, **kwargs):
+        self.text_tokenizer = sp_tokenizer.from_pretrained()
+
+        self.num_command_tokens = 2
+        self.num_text_tokens = self.text_tokenizer.sp.vocab_size()
+        self.num_tokens = self.num_text_tokens + 1
+        self.num_type_tokens = 2
+
+        self._command_tokens = [
+            CommandToken('pad', '<|endoftext|>', self.num_text_tokens),
+            CommandToken('eos', '<|endoftext|>', self.num_text_tokens),
+        ]
+        if add_block_symbols:
+            self._command_tokens.extend([
+                CommandToken('sop', '<|startofpiece|>', self.num_text_tokens + 1),
+                CommandToken('eop', '<|endofpiece|>', self.num_text_tokens + 2)
+            ])
+            self.num_tokens += 2
+            self.num_command_tokens += 2
+        self.command_name_map = {tok.name: tok for tok in self._command_tokens}
+        self.command_token_map = {tok.token: tok for tok in self._command_tokens}
+        self.command_id_map = {tok.Id: tok for tok in self._command_tokens}
+
+        self.type_tokens = [
+            TypeToken('str0', '<str0>', 0),
+            TypeToken('str1', '<str1>', 1),
+        ]
+        self.type_name_map = {tok.name: tok for tok in self.type_tokens}
+        self.type_token_map = {tok.token: tok for tok in self.type_tokens}
+        self.type_id_map = {tok.Id: tok for tok in self.type_tokens}
+
+        # self._tokens = list(self.text_tokenizer.encoder.keys())
+        # self._vocab = {k:v for k,v in self.text_tokenizer.encoder.items()}
+        #
+        # self._text_tokens = list(self._tokens)
+        # self._text_token_vocab = {k:v for k,v in self.text_tokenizer.encoder.items()}
+
+        self._command_token_tokens = list(self.command_token_map.keys())
+        self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}
+
+        self._token_types = list(self.type_token_map.keys())
+        self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}
+
+    def _encode(self, text):
+        ids = self.text_tokenizer.encode(text)
+        return ids
+
+    def EncodeAsTokens(self, text, process_fn=None):
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+        tokens = self.text_tokenizer.tokenize(processed_text)
+        tokenization = Tokenization(tokens, processed_text, text, asIds=False)
+        tokenization.set_command_tokens(self._command_tokens)
+        return tokenization
+
+    def IdToToken(self, Id, type_token=False):
+        if isinstance(Id, (TypeToken, CommandToken)):
+            return Id.token
+        if type_token:
+            return self.type_id_map[Id].token
+        if Id in self.command_id_map:
+            return self.command_id_map[Id].token
+        elif Id in self.type_id_map:
+            return self.type_id_map[Id].token
+        else:
+            return self.text_tokenizer.convert_id_to_token(Id)
+
+    def TokenToId(self, token, type_token=False):
+        if isinstance(token, (TypeToken, CommandToken)):
+            return token.Id
+        if type_token:
+            return self.type_token_map[token].Id
+        return self.text_tokenizer.convert_token_to_id(token)
+
+    def DecodeIds(self, Ids, type_token=False):
+        if type_token:
+            return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
+        if isinstance(Ids, Tokenization):
+            Ids = Ids.tokenization
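+        # Decode only up to the first eos; any trailing eos ids are re-appended
+        # below as literal '<|endoftext|>' strings.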
+        try:
+            first_eos = Ids.index(self.get_command('eos').Id)
+            eos_count = len(Ids) - first_eos
+            Ids = Ids[:first_eos]
+        except ValueError:
+            eos_count = 0
+        return " ".join((self.text_tokenizer.decode(Ids), *(['<|endoftext|>'] * eos_count)))
+
+    def DecodeTokens(self, Tokens, type_token=False):
+        if type_token:
+            return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens)
+        if isinstance(Tokens, Tokenization):
+            Tokens = Tokens.tokenization
+        return self.text_tokenizer.decode([self.TokenToId(tok) for tok in Tokens])
diff --git a/tokenization/text/tokenization_gpt2.py b/tokenization/text/tokenization_gpt2.py
new file mode 100644
index 0000000000000000000000000000000000000000..318d9209990b74b9aaadc13d1f482923c77bc3ab
--- /dev/null
+++ b/tokenization/text/tokenization_gpt2.py
@@ -0,0 +1,310 @@
+# coding=utf-8
+# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for OpenAI GPT."""
+from __future__ import (absolute_import, division, print_function,
+                        unicode_literals)
+
+import sys
+import json
+import logging
+import os
+import regex as re
+from io import open
+
+try:
+    from functools import lru_cache
+except ImportError:
+    # Just a dummy decorator to get the checks to run on python2
+    # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
+    def lru_cache():
+        return lambda func: func
+
+from ..file_utils import cached_path
+
+logger = logging.getLogger(__name__)
+
+PRETRAINED_VOCAB_ARCHIVE_MAP = {
+    'gpt2': "pretrained/pytorch_pretrained_bert/gpt2-vocab.json",
+    "roberta": "pretrained/pytorch_pretrained_bert/roberta-vocab.json"
+}
+PRETRAINED_MERGES_ARCHIVE_MAP = {
+    'gpt2': "pretrained/pytorch_pretrained_bert/gpt2-merges.txt",
+    "roberta": "pretrained/pytorch_pretrained_bert/roberta-merges.txt"
+}
+PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
+    'gpt2': 1024,
+}
+VOCAB_NAME = 'vocab.json'
+MERGES_NAME = 'merges.txt'
+SPECIAL_TOKENS_NAME = 'special_tokens.txt'
+
+@lru_cache()
+def bytes_to_unicode():
+    """
+    Returns a mapping from utf-8 bytes to unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings,
+    and we avoid mapping to whitespace/control characters the bpe code barfs on.
+    """
+    _chr = unichr if sys.version_info[0] == 2 else chr
+    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8+n)
+            n += 1
+    cs = [_chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+def get_pairs(word):
+    """Return set of symbol pairs in a word.
+
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+class GPT2Tokenizer(object):
+    """
+    GPT-2 BPE tokenizer. Peculiarities:
+        - Byte-level BPE
+    """
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
+        """
+        Instantiate a GPT2Tokenizer from a pre-trained vocabulary.
+        The vocabulary and merges files are loaded from the local paths configured above.
+        """
+        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
+            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
+            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
+            special_tokens_file = None
+        else:
+            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
+            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
+            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
+            if not os.path.exists(special_tokens_file):
+                special_tokens_file = None
+            else:
+                logger.info("loading special tokens file {}".format(special_tokens_file))
+        # redirect to the cache, if necessary
+        # try:
+        #     resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
+        #     resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
+        # except EnvironmentError:
+        #     logger.error(
+        #         "Model name '{}' was not found in model name list ({}). "
+        #         "We assumed '{}' was a path or url but couldn't find files {} and {} "
+        #         "at this path or url.".format(
+        #             pretrained_model_name_or_path,
+        #             ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
+        #             pretrained_model_name_or_path,
+        #             vocab_file, merges_file))
+        #     return None
+        # if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
+        #     logger.info("loading vocabulary file {}".format(vocab_file))
+        #     logger.info("loading merges file {}".format(merges_file))
+        # else:
+        #     logger.info("loading vocabulary file {} from cache at {}".format(
+        #         vocab_file, resolved_vocab_file))
+        #     logger.info("loading merges file {} from cache at {}".format(
+        #         merges_file, resolved_merges_file))
+        resolved_vocab_file = vocab_file
+        resolved_merges_file = merges_file
+        logger.info("loading vocabulary file {}".format(vocab_file))
+        logger.info("loading merges file {}".format(merges_file))
+        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
+            # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
+            # than the number of positional embeddings
+            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
+            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
+        # Instantiate tokenizer.
+        if special_tokens_file and 'special_tokens' not in kwargs:
+            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
+        else:
+            special_tokens = kwargs.pop('special_tokens', [])
+        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
+        return tokenizer
+
+    def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
+        self.max_len = max_len if max_len is not None else int(1e12)
+        self.encoder = json.load(open(vocab_file, encoding='utf-8'))
+        self.decoder = {v:k for k,v in self.encoder.items()}
+        self.errors = errors # how to handle errors in decoding
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
+        bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
+        bpe_merges = [tuple(merge.split()) for merge in bpe_data]
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        self.cache = {}
+
+        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
+        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
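+        # The pattern splits text into contractions ('s, 't, 're, ...), runs of
+        # letters, runs of digits, other non-space symbols, and whitespace.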
+
+        self.special_tokens = {}
+        self.special_tokens_decoder = {}
+        self.set_special_tokens(special_tokens)
+
+    def __len__(self):
+        return len(self.encoder) + len(self.special_tokens)
+
+    def set_special_tokens(self, special_tokens):
+        """ Add a list of additional tokens to the encoder.
+            The additional tokens are indexed starting from the last index of the
+            current vocabulary in the order of the `special_tokens` list.
+        """
+        if not special_tokens:
+            self.special_tokens = {}
+            self.special_tokens_decoder = {}
+            return
+        self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
+        self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
+        logger.info("Special tokens {}".format(self.special_tokens))
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token
+
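+        # Greedily apply the lowest-ranked (earliest-learned) merge until no
+        # learned merge applies to the remaining symbol pairs.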
+        while True:
+            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word)-1 and word[i+1] == second:
+                    new_word.append(first+second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def tokenize(self, text):
+        """ Tokenize a string. """
+        bpe_tokens = []
+        for token in re.findall(self.pat, text):
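+            # Map every utf-8 byte to a printable unicode character so BPE runs
+            # over a closed alphabet with no unknown symbols.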
+            if sys.version_info[0] == 2:
+                token = ''.join(self.byte_encoder[ord(b)] for b in token)
+            else:
+                token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
+
+    def convert_tokens_to_ids(self, tokens):
+        """ Converts a sequence of tokens into ids using the vocab. """
+        ids = []
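+        # Special tokens take precedence over the BPE vocab; unknown tokens map to id 0.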
+        if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
+            if tokens in self.special_tokens:
+                return self.special_tokens[tokens]
+            else:
+                return self.encoder.get(tokens, 0)
+        for token in tokens:
+            if token in self.special_tokens:
+                ids.append(self.special_tokens[token])
+            else:
+                ids.append(self.encoder.get(token, 0))
+        if len(ids) > self.max_len:
+            logger.warning(
+                "Token indices sequence length is longer than the specified maximum "
+                " sequence length for this OpenAI GPT model ({} > {}). Running this"
+                " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
+            )
+        return ids
+
+    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
+        """Converts a sequence of ids in BPE tokens using the vocab."""
+        tokens = []
+        for i in ids:
+            if i in self.special_tokens_decoder:
+                if not skip_special_tokens:
+                    tokens.append(self.special_tokens_decoder[i])
+            else:
+                tokens.append(self.decoder[i])
+        return tokens
+
+    def encode(self, text):
+        return self.convert_tokens_to_ids(self.tokenize(text))
+
+    def decode(self, tokens):
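+        # Reverse the byte-level encoding: ids -> unicode chars -> raw bytes -> utf-8 text.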
+        text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
+        return text
+
+    def save_vocabulary(self, vocab_path):
+        """Save the tokenizer vocabulary and merge files to a directory."""
+        if not os.path.isdir(vocab_path):
+            logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
+            return
+        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        merge_file = os.path.join(vocab_path, MERGES_NAME)
+        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
+
+        with open(vocab_file, 'w', encoding='utf-8') as f:
+            f.write(json.dumps(self.encoder, ensure_ascii=False))
+
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            writer.write(u'#version: 0.2\n')
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
+                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
+                    index = token_index
+                writer.write(' '.join(bpe_tokens) + u'\n')
+                index += 1
+
+        index = len(self.encoder)
+        with open(special_tokens_file, 'w', encoding='utf-8') as writer:
+            for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
+                                   " Please check that the tokenizer is not corrupted!".format(special_tokens_file))
+                    index = token_index
+                writer.write(token + u'\n')
+                index += 1
+
+        return vocab_file, merge_file, special_tokens_file
diff --git a/tokenization/text/wordpiece.py b/tokenization/text/wordpiece.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5ce2e27f31e216617ac8486f50204802b255c4f
--- /dev/null
+++ b/tokenization/text/wordpiece.py
@@ -0,0 +1,390 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes. Provided as is from https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/tokenization.py"""
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import collections
+import logging
+import os
+import unicodedata
+from io import open
+
+from ..file_utils import cached_path
+
+logger = logging.getLogger(__name__)
+
+PRETRAINED_VOCAB_ARCHIVE_MAP = {
+    'bert-base-uncased': "pretrained/pytorch_pretrained_bert/bert-base-uncased-vocab.txt",
+    'bert-large-uncased': "pretrained/pytorch_pretrained_bert/bert-large-uncased-vocab.txt",
+    'bert-base-cased': "pretrained/pytorch_pretrained_bert/bert-base-cased-vocab.txt",
+    'bert-large-cased': "pretrained/pytorch_pretrained_bert/bert-large-cased-vocab.txt",
+    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
+    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
+    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
+}
+PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
+    'bert-base-uncased': 512,
+    'bert-large-uncased': 512,
+    'bert-base-cased': 512,
+    'bert-large-cased': 512,
+    'bert-base-multilingual-uncased': 512,
+    'bert-base-multilingual-cased': 512,
+    'bert-base-chinese': 512,
+}
+VOCAB_NAME = 'vocab.txt'
+
+
+def load_vocab(vocab_file):
+    """Loads a vocabulary file into a dictionary."""
+    vocab = collections.OrderedDict()
+    index = 0
+    with open(vocab_file, "r", encoding="utf-8") as reader:
+        while True:
+            token = reader.readline()
+            if not token:
+                break
+            token = token.strip()
+            vocab[token] = index
+            index += 1
+    return vocab
+
+
+def whitespace_tokenize(text):
+    """Runs basic whitespace cleaning and splitting on a piece of text."""
+    text = text.strip()
+    if not text:
+        return []
+    tokens = text.split()
+    return tokens
+
+
+class BertTokenizer(object):
+    """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
+
+    def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
+                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
+        """Constructs a BertTokenizer.
+
+        Args:
+          vocab_file: Path to a one-wordpiece-per-line vocabulary file
+          do_lower_case: Whether to lower case the input.
+                         Only has an effect when do_basic_tokenize=True.
+          do_basic_tokenize: Whether to do basic tokenization before wordpiece.
+          max_len: An artificial maximum length to truncate tokenized sequences to;
+                         Effective maximum length is always the minimum of this
+                         value (if specified) and the underlying BERT model's
+                         sequence length.
+          never_split: List of tokens which will never be split during tokenization.
+                         Only has an effect when do_basic_tokenize=True.
+        """
+        if not os.path.isfile(vocab_file):
+            raise ValueError(
+                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
+                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
+        self.vocab = load_vocab(vocab_file)
+        self.ids_to_tokens = collections.OrderedDict(
+            [(ids, tok) for tok, ids in self.vocab.items()])
+        self.do_basic_tokenize = do_basic_tokenize
+        if do_basic_tokenize:
+            self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
+                                                  never_split=never_split)
+        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
+        self.max_len = max_len if max_len is not None else int(1e12)
+
+    def tokenize(self, text):
+        if self.do_basic_tokenize:
+            split_tokens = []
+            for token in self.basic_tokenizer.tokenize(text):
+                for sub_token in self.wordpiece_tokenizer.tokenize(token):
+                    split_tokens.append(sub_token)
+        else:
+            split_tokens = self.wordpiece_tokenizer.tokenize(text)
+        return split_tokens
+
+    def convert_tokens_to_ids(self, tokens):
+        """Converts a sequence of tokens into ids using the vocab."""
+        ids = []
+        for token in tokens:
+            ids.append(self.vocab[token])
+        if len(ids) > self.max_len:
+            logger.warning(
+                "Token indices sequence length is longer than the specified maximum "
+                " sequence length for this BERT model ({} > {}). Running this"
+                " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
+            )
+        return ids
+
+    def convert_ids_to_tokens(self, ids):
+        """Converts a sequence of ids in wordpiece tokens using the vocab."""
+        tokens = []
+        for i in ids:
+            tokens.append(self.ids_to_tokens[i])
+        return tokens
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
+        """
+        Instantiate a BertTokenizer from a pre-trained model vocabulary file.
+        Download and cache the vocabulary file if needed.
+        """
+        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
+            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
+        else:
+            vocab_file = pretrained_model_name_or_path
+        if os.path.isdir(vocab_file):
+            vocab_file = os.path.join(vocab_file, VOCAB_NAME)
+        # redirect to the cache, if necessary
+        try:
+            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
+        except EnvironmentError:
+            logger.error(
+                "Model name '{}' was not found in model name list ({}). "
+                "We assumed '{}' was a path or url but couldn't find any file "
+                "associated to this path or url.".format(
+                    pretrained_model_name_or_path,
+                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
+                    vocab_file))
+            return None
+        if resolved_vocab_file == vocab_file:
+            logger.info("loading vocabulary file {}".format(vocab_file))
+        else:
+            logger.info("loading vocabulary file {} from cache at {}".format(
+                vocab_file, resolved_vocab_file))
+        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
+            # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
+            # than the number of positional embeddings
+            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
+            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
+        # Instantiate tokenizer.
+        tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
+        return tokenizer
+
+
+class BasicTokenizer(object):
+    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
+
+    def __init__(self,
+                 do_lower_case=True,
+                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
+        """Constructs a BasicTokenizer.
+
+        Args:
+          do_lower_case: Whether to lower case the input.
+        """
+        self.do_lower_case = do_lower_case
+        self.never_split = never_split
+
+    def tokenize(self, text):
+        """Tokenizes a piece of text."""
+        text = self._clean_text(text)
+        # This was added on November 1st, 2018 for the multilingual and Chinese
+        # models. This is also applied to the English models now, but it doesn't
+        # matter since the English models were not trained on any Chinese data
+        # and generally don't have any Chinese data in them (there are Chinese
+        # characters in the vocabulary because Wikipedia does have some Chinese
+        # words in the English Wikipedia.).
+        text = self._tokenize_chinese_chars(text)
+        orig_tokens = whitespace_tokenize(text)
+        split_tokens = []
+        for token in orig_tokens:
+            if self.do_lower_case and token not in self.never_split:
+                token = token.lower()
+                token = self._run_strip_accents(token)
+            split_tokens.extend(self._run_split_on_punc(token))
+
+        output_tokens = whitespace_tokenize(" ".join(split_tokens))
+        return output_tokens
+
+    def _run_strip_accents(self, text):
+        """Strips accents from a piece of text."""
+        text = unicodedata.normalize("NFD", text)
+        output = []
+        for char in text:
+            cat = unicodedata.category(char)
+            if cat == "Mn":
+                continue
+            output.append(char)
+        return "".join(output)
+
+    def _run_split_on_punc(self, text):
+        """Splits punctuation on a piece of text."""
+        if text in self.never_split:
+            return [text]
+        chars = list(text)
+        i = 0
+        start_new_word = True
+        output = []
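+        # `output` is a list of character groups; each punctuation character
+        # forms its own group, and the following character starts a new group.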
+        while i < len(chars):
+            char = chars[i]
+            if _is_punctuation(char):
+                output.append([char])
+                start_new_word = True
+            else:
+                if start_new_word:
+                    output.append([])
+                start_new_word = False
+                output[-1].append(char)
+            i += 1
+
+        return ["".join(x) for x in output]
+
+    def _tokenize_chinese_chars(self, text):
+        """Adds whitespace around any CJK character."""
+        output = []
+        for char in text:
+            cp = ord(char)
+            if self._is_chinese_char(cp):
+                output.append(" ")
+                output.append(char)
+                output.append(" ")
+            else:
+                output.append(char)
+        return "".join(output)
+
+    def _is_chinese_char(self, cp):
+        """Checks whether CP is the codepoint of a CJK character."""
+        # This defines a "chinese character" as anything in the CJK Unicode block:
+        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+        #
+        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+        # despite its name. The modern Korean Hangul alphabet is a different block,
+        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+        # space-separated words, so they are not treated specially and handled
+        # like all of the other languages.
+        if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
+                (cp >= 0x3400 and cp <= 0x4DBF) or  #
+                (cp >= 0x20000 and cp <= 0x2A6DF) or  #
+                (cp >= 0x2A700 and cp <= 0x2B73F) or  #
+                (cp >= 0x2B740 and cp <= 0x2B81F) or  #
+                (cp >= 0x2B820 and cp <= 0x2CEAF) or
+                (cp >= 0xF900 and cp <= 0xFAFF) or  #
+                (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
+            return True
+
+        return False
+
+    def _clean_text(self, text):
+        """Performs invalid character removal and whitespace cleanup on text."""
+        output = []
+        for char in text:
+            cp = ord(char)
+            if cp == 0 or cp == 0xfffd or _is_control(char):
+                continue
+            if _is_whitespace(char):
+                output.append(" ")
+            else:
+                output.append(char)
+        return "".join(output)
+
+
+class WordpieceTokenizer(object):
+    """Runs WordPiece tokenization."""
+
+    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
+        self.vocab = vocab
+        self.unk_token = unk_token
+        self.max_input_chars_per_word = max_input_chars_per_word
+
+    def tokenize(self, text):
+        """Tokenizes a piece of text into its word pieces.
+
+        This uses a greedy longest-match-first algorithm to perform tokenization
+        using the given vocabulary.
+
+        For example:
+          input = "unaffable"
+          output = ["un", "##aff", "##able"]
+
+        Args:
+          text: A single token or whitespace separated tokens. This should have
+            already been passed through `BasicTokenizer`.
+
+        Returns:
+          A list of wordpiece tokens.
+        """
+
+        output_tokens = []
+        for token in whitespace_tokenize(text):
+            chars = list(token)
+            if len(chars) > self.max_input_chars_per_word:
+                output_tokens.append(self.unk_token)
+                continue
+
+            is_bad = False
+            start = 0
+            sub_tokens = []
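+            # Greedy longest-match-first: take the longest vocab entry that
+            # starts at `start`, prefixing non-initial pieces with '##'.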
+            while start < len(chars):
+                end = len(chars)
+                cur_substr = None
+                while start < end:
+                    substr = "".join(chars[start:end])
+                    if start > 0:
+                        substr = "##" + substr
+                    if substr in self.vocab:
+                        cur_substr = substr
+                        break
+                    end -= 1
+                if cur_substr is None:
+                    is_bad = True
+                    break
+                sub_tokens.append(cur_substr)
+                start = end
+
+            if is_bad:
+                output_tokens.append(self.unk_token)
+            else:
+                output_tokens.extend(sub_tokens)
+        return output_tokens
+
+
+def _is_whitespace(char):
+    """Checks whether `chars` is a whitespace character."""
+    # \t, \n, and \r are technically control characters but we treat them
+    # as whitespace since they are generally considered as such.
+    if char == " " or char == "\t" or char == "\n" or char == "\r":
+        return True
+    cat = unicodedata.category(char)
+    if cat == "Zs":
+        return True
+    return False
+
+
+def _is_control(char):
+    """Checks whether `chars` is a control character."""
+    # These are technically control characters but we count them as whitespace
+    # characters.
+    if char == "\t" or char == "\n" or char == "\r":
+        return False
+    cat = unicodedata.category(char)
+    if cat.startswith("C"):
+        return True
+    return False
+
+
+def _is_punctuation(char):
+    """Checks whether `chars` is a punctuation character."""
+    cp = ord(char)
+    # We treat all non-letter/number ASCII as punctuation.
+    # Characters such as "^", "$", and "`" are not in the Unicode
+    # Punctuation class but we treat them as punctuation anyways, for
+    # consistency.
+    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
+            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
+        return True
+    cat = unicodedata.category(char)
+    if cat.startswith("P"):
+        return True
+    return False
diff --git a/training/model_io.py b/training/model_io.py
index 20b53b2a1ac3246ff17f2b14dcbbce0e387b69eb..1655ed20999230346b0dd20c7c547c52a38fb246 100644
--- a/training/model_io.py
+++ b/training/model_io.py
@@ -28,8 +28,12 @@ def get_checkpoint_name(checkpoints_path, iteration, release=False, zero=False):
         d += '_zero_dp_rank_{}'.format(dp_rank)
     return os.path.join(checkpoints_path, d, 'mp_rank_{:02d}_model_states.pt'.format(mpu.get_model_parallel_rank()))
 
-def get_checkpoint_tracker_filename(checkpoints_path):
-    return os.path.join(checkpoints_path, 'latest')
+def get_checkpoint_tracker_filename(checkpoints_path, old_checkpoint=False):
+    if old_checkpoint:
+        return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
+    else:
+        return os.path.join(checkpoints_path, 'latest')
+
 
 def save_checkpoint(iteration, model, optimizer,
                     lr_scheduler, args):
@@ -85,7 +89,7 @@ def save_ds_checkpoint_no_optim(model, save_dir, tag=None, client_state={}, save
 
 def get_checkpoint_iteration(args):
     # Read the tracker file and set the iteration.
-    tracker_filename = get_checkpoint_tracker_filename(args.load)
+    tracker_filename = get_checkpoint_tracker_filename(args.load, old_checkpoint=args.old_checkpoint)
     if not os.path.isfile(tracker_filename):
         print_rank_0('WARNING: could not find the metadata file {} '.format(
             tracker_filename))
@@ -126,7 +130,12 @@ def load_checkpoint(model, args):
         module = model.module
     else: # inference without deepspeed
         module = model
-        
+
+    # Process the checkpoint for GLM
+    if args.block_lm and args.old_checkpoint:
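+        # Older GLM checkpoints store the embedding under 'word_embeddings.weight';
+        # rename it to the current 'transformer.word_embeddings.weight' key.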
+        sd['module']['transformer.word_embeddings.weight'] = sd['module']['word_embeddings.weight']
+        del sd['module']['word_embeddings.weight']
+
     # only load module, other hyperparameters are just for recording.
     missing_keys, unexpected_keys = module.load_state_dict(sd['module'], strict=False)
     if len(unexpected_keys) > 0: