From ba522b5e0b4774e3818f4927318a9670ccd0369e Mon Sep 17 00:00:00 2001
From: Zhengxiao Du <zx-du20@mails.tsinghua.edu.cn>
Date: Wed, 20 Oct 2021 19:36:39 +0800
Subject: [PATCH] Import GLM inference code

---
 arguments.py                           |   36 +-
 config/model_glm_roberta_large.sh      |   10 +
 inference_glm.py                       |  299 ++++++
 model/glm_model.py                     |    3 +-
 move_weights_glm.py                    |   38 +
 mpu/transformer.py                     |    4 +-
 scripts/generate_glm.sh                |   29 +
 tokenization/__init__.py               |   22 +-
 tokenization/file_utils.py             |  250 +++++
 tokenization/text/__init__.py          |    1 +
 tokenization/text/sp_tokenizer.py      |  150 +++
 tokenization/text/tokenization.py      | 1254 ++++++++++++++++++++++++
 tokenization/text/tokenization_gpt2.py |  310 ++++++
 tokenization/text/wordpiece.py         |  390 ++++++++
 14 files changed, 2787 insertions(+), 9 deletions(-)
 create mode 100644 config/model_glm_roberta_large.sh
 create mode 100644 inference_glm.py
 create mode 100644 move_weights_glm.py
 create mode 100644 scripts/generate_glm.sh
 create mode 100644 tokenization/file_utils.py
 create mode 100644 tokenization/text/__init__.py
 create mode 100644 tokenization/text/sp_tokenizer.py
 create mode 100644 tokenization/text/tokenization.py
 create mode 100644 tokenization/text/tokenization_gpt2.py
 create mode 100644 tokenization/text/wordpiece.py

diff --git a/arguments.py b/arguments.py
index 026e406..dbdeea6 100755
--- a/arguments.py
+++ b/arguments.py
@@ -157,6 +157,7 @@ def add_text_generate_args(parser):
     group.add_argument("--temperature", type=float, default=1.0)
     group.add_argument("--top_p", type=float, default=0.0)
     group.add_argument("--top_k", type=int, default=0)
+    group.add_argument("--num-beams", type=int, default=1)
     group.add_argument("--out-seq-length", type=int, default=256)
     group.add_argument('--input-source', type=str, default='interactive',
                        help='what input mode to use, interactive or path')
@@ -214,12 +215,44 @@ def add_tokenization_args(parser):
 
     group = parser.add_argument_group('Tokenization', 'tokenization configurations')
     group.add_argument('--tokenizer-type', type=str, default='fake', help='type name of tokenizer')
-
+    group.add_argument('--tokenizer-model-type', type=str,
+                       default=None,
+                       help="Model type to use for sentencepiece tokenization \
+                           (one of ['bpe', 'char', 'unigram', 'word']) or \
+                           bert vocab to use for BertWordPieceTokenizer (one of \
+                           ['bert-large-uncased', 'bert-large-cased', etc.])")
     group.add_argument('--img-tokenizer-path', type=str, default=None,
                        help='The checkpoint file path of image tokenizer.')
     return parser
 
 
+def add_glm_args(parser):
+    """Arguments for GLM"""
+    group = parser.add_argument_group('GLM', 'GLM Configurations')
+    group.add_argument('--block-lm', action='store_true', help="whether to use the BlockLM pre-training objective")
+    group.add_argument('--masked-lm', action='store_true', help='whether to use the MLM objective')
+    group.add_argument('--bert-prob', type=float, default=0.5)
+    group.add_argument('--gpt-infill-prob', type=float, default=0.5)
+    group.add_argument('--gpt-min-ratio', type=float, default=0.5)
+    group.add_argument('--gap-sentence-prob', type=float, default=0.0)
+    group.add_argument('--gap-sentence-ratio', type=float, default=0.15)
+    group.add_argument('--avg-block-length', type=int, default=3)
+    group.add_argument('--short-seq-prob', type=float, default=0.0)
+    group.add_argument('--single-span-prob', type=float, default=0.0)
+    group.add_argument('--task-mask', action='store_true', help="Use different mask for generation and blank filling")
+    group.add_argument('--no-shuffle-block', action='store_true', help="do not shuffle the blocks when filling in the blanks")
+    group.add_argument('--no-block-position', action='store_true',
+                       help='Use (rough) absolute positions instead of block positions')
+    group.add_argument('--sentinel-token', action='store_true',
+                       help="Use sentinel (mask) tokens to replace 2d position encoding")
+    group.add_argument('--block-mask-prob', type=float, default=0.0)
+    group.add_argument('--context-mask-ratio', type=float, default=0.0)
+    group.add_argument('--random-position', action='store_true',
+                       help="Use random start position to cover all the position embeddings")
+    group.add_argument('--cloze-eval', action='store_true', help='evaluate the dataset with a cloze task')
+    return parser
+
+
     
 def get_args(args_list=None):
     """Parse all the args."""
@@ -232,6 +265,7 @@ def get_args(args_list=None):
     parser = add_tokenization_args(parser)
     parser = add_text_generate_args(parser)
     parser = add_generation_api_args(parser)
+    parser = add_glm_args(parser)
 
     # Include DeepSpeed configuration arguments
     parser = deepspeed.add_config_arguments(parser)
diff --git a/config/model_glm_roberta_large.sh b/config/model_glm_roberta_large.sh
new file mode 100644
index 0000000..6fb7216
--- /dev/null
+++ b/config/model_glm_roberta_large.sh
@@ -0,0 +1,10 @@
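+# Model config sourced by scripts/generate_glm.sh; CHECKPOINT_PATH is expected to be set by the caller.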
+MODEL_TYPE="blocklm-roberta-large"
+MODEL_ARGS="--block-lm \
+            --cloze-eval \
+            --num-layers 24 \
+            --hidden-size 1024 \
+            --num-attention-heads 16 \
+            --max-sequence-length 513 \
+            --tokenizer-model-type roberta \
+            --tokenizer-type GPT2BPETokenizer \
+            --load ${CHECKPOINT_PATH}/blocklm-roberta-large-blank"
\ No newline at end of file
diff --git a/inference_glm.py b/inference_glm.py
new file mode 100644
index 0000000..b88c2a7
--- /dev/null
+++ b/inference_glm.py
@@ -0,0 +1,299 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   inference_glm.py
+@Time    :   2021/10/09 19:41:58
+@Author  :   Ming Ding
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import random
+import time
+from datetime import datetime
+import torch
+import torch.nn.functional as F
+
+import mpu
+from arguments import get_args
+from model.glm_model import GLMModel
+from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
+from tokenization import get_tokenizer
+from generation.sampling_strategies import BaseStrategy
+from generation.autoregressive_sampling import filling_sequence
+from generation.utils import timed_name, save_multiple_images, generate_continually
+
+
+def read_context(tokenizer, args, output=None):
+    terminate_runs, skip_run = 0, 0
+    if mpu.get_model_parallel_rank() == 0:
+        while True:
+            raw_text = input("\nContext prompt (stop to exit) >>> ")
+            if not raw_text:
+                print('Prompt should not be empty!')
+                continue
+            if raw_text == "stop":
+                terminate_runs = 1
+                break
+            generation_mask = '[gMASK]' if args.task_mask else '[MASK]'
+            if args.block_lm and 'MASK]' not in raw_text:
+                raw_text += ' ' + generation_mask
+            if output is not None:
+                output.write(raw_text)
+            context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
+            if args.block_lm:
+                context_tokens = [tokenizer.get_command('ENC').Id] + context_tokens
+                if not raw_text.endswith('MASK]'):
+                    context_tokens = context_tokens + [tokenizer.get_command('eos').Id]
+            context_length = len(context_tokens)
+
+            if context_length >= args.max_sequence_length:
+                print("\nContext length", context_length,
+                      "\nPlease give smaller context than the window length!")
+                continue
+            break
+    else:
+        context_length = 0
+
+    terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
+    torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(),
+                                group=mpu.get_model_parallel_group())
+    terminate_runs = terminate_runs_tensor[0].item()
+
+    if terminate_runs == 1:
+        return terminate_runs, None, None, None
+
+    context_length_tensor = torch.cuda.LongTensor([context_length])
+
+    torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(),
+                                group=mpu.get_model_parallel_group())
+    context_length = context_length_tensor[0].item()
+    if mpu.get_model_parallel_rank() == 0:
+        context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
+    else:
+        context_tokens_tensor = torch.cuda.LongTensor([0] * context_length)
+    torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(),
+                                group=mpu.get_model_parallel_group())
+    if mpu.get_model_parallel_rank() != 0:
+        raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist())
+    return terminate_runs, raw_text, context_tokens_tensor, context_length
+
+
+def get_batch(context_tokens, device, args):
+    tokens = context_tokens
+    tokens = tokens.view(1, -1).contiguous()
+    tokens = tokens.to(device)
+
+    # Get the masks and position ids.
+    if args.block_lm:
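+        # GLM context pass: a full (all-ones) attention mask over the prompt and, unless
+        # --no-block-position is set, a second (block) position row initialized to zeros.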
+        attention_mask = torch.ones(1, 1, tokens.size(1), tokens.size(1), device=device, dtype=torch.long)
+        if args.fp16:
+            attention_mask = attention_mask.half()
+        position_ids = torch.arange(tokens.size(1), device=device, dtype=torch.long)
+        if not args.no_block_position:
+            block_position_ids = torch.zeros(tokens.size(1), device=device, dtype=torch.long)
+            position_ids = torch.stack((position_ids, block_position_ids), dim=0)
+        position_ids = position_ids.unsqueeze(0)
+    else:
+        raise NotImplementedError
+
+    return tokens, attention_mask, position_ids
+
+
+def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+    # This function has been mostly taken from huggingface conversational ai code at
+    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
+
+    if top_k > 0:
+        # Remove all tokens with a probability less than the last token of the top-k
+        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p > 0.0:
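+        # note: the top-p filtering below assumes a single sequence (batch size 1)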
+        # convert to 1D
+        logits = logits.view(logits.size()[1]).contiguous()
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+        # Remove tokens with cumulative probability above the threshold
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Shift the indices to the right to keep also the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+        indices_to_remove = sorted_indices[sorted_indices_to_remove]
+        logits[indices_to_remove] = filter_value
+        # going back to 2D
+        logits = logits.view(1, -1).contiguous()
+
+    return logits
+
+
+def sample_sequence(model, tokenizer, context_tokens, context_length, args, device, mems=None, end_tokens=None):
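+    # Decode autoregressively from the given context. For GLM (--block-lm) decoding starts
+    # from the 'sop' token and fills the blank at the given position; with --num-beams > 1
+    # a beam search is used (BeamSearchScorer and args.length_penalty are assumed to be
+    # available elsewhere in the codebase; this patch does not add them).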
+    if not args.block_lm:
+        context_tokens, attention_mask, position_ids = get_batch(context_tokens, device, args)
+        tokens = torch.empty((args.num_beams, 0), device=context_tokens.device, dtype=torch.long)
+    else:
+        tokens = context_tokens.new_full((1, 1), tokenizer.get_command('sop').Id)
+    counter = 0
+    is_end = False  # set when greedy/top-k sampling hits an end token
+    if mems is None:
+        mems = []
+    if end_tokens is None:
+        end_tokens = [args.eod_token]
+    if args.num_beams > 1:
+        beam_scorer = BeamSearchScorer(
+            batch_size=1,
+            max_length=args.out_seq_length,
+            num_beams=args.num_beams,
+            device=context_tokens.device,
+            length_penalty=args.length_penalty,
+            do_early_stopping=False,
+        )
+        beam_scores = torch.zeros(1, dtype=torch.float, device=context_tokens.device)
+    last_beam_num = 1
+    while counter < args.out_seq_length:
+        if counter == 0 and not args.block_lm:
+            next_token_logits, *mems = model(context_tokens, position_ids, attention_mask, *mems)
+        else:
+            if args.block_lm:
+                if args.no_block_position:
+                    position_ids = context_tokens.new_full((last_beam_num, 1), context_length + counter)
+                else:
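+                    # 2D position ids: row 0 holds the position of the mask being filled,
+                    # row 1 counts the offset within the generated block (starting at 1)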
+                    position_ids = context_tokens.new_ones(last_beam_num, 2, 1)
+                    position_ids[:, 0] = context_length
+                    position_ids[:, 1] = counter + 1
+                attention_mask = context_tokens.new_zeros([1], device=context_tokens.device, dtype=torch.long)
+            else:
+                position_ids = context_tokens.new_ones((last_beam_num, 1)) * (context_length + counter - 1)
+                attention_mask = context_tokens.new_ones(last_beam_num, 1, 1, args.mem_length + 1,
+                                                         device=context_tokens.device, dtype=torch.float)
+            last_token = tokens[:, -1:]
+            next_token_logits, *mems = model(last_token, position_ids, attention_mask, *mems)
+        next_token_logits = next_token_logits[:, -1]
+        if args.num_beams > 1:
+            next_token_scores = F.log_softmax(next_token_logits, dim=-1)
+            next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
+            vocab_size = next_token_scores.shape[-1]
+            next_token_scores = next_token_scores.view(1, last_beam_num * vocab_size)
+
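+            # sample twice as many candidates as beams so that enough unfinished hypotheses
+            # remain after the scorer sets aside beams that hit an end token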
+            probs = F.softmax(next_token_scores, dim=-1)
+            next_tokens = torch.multinomial(probs, num_samples=2 * args.num_beams)
+            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
+            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
+            next_tokens = torch.gather(next_tokens, -1, _indices)
+
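+            # scores were flattened to shape (1, last_beam_num * vocab_size), so floor
+            # division recovers the source beam and the remainder the token id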
+            next_indices = next_tokens // vocab_size
+            next_tokens = next_tokens % vocab_size
+            # stateless
+            tokens = tokens.expand((args.num_beams, -1))
+            beam_outputs = beam_scorer.process(
+                tokens,
+                next_token_scores,
+                next_tokens,
+                next_indices,
+                eos_token_id=end_tokens,
+                mems=mems
+            )
+            beam_scores = beam_outputs["next_beam_scores"]
+            beam_next_tokens = beam_outputs["next_beam_tokens"]
+            beam_idx = beam_outputs["next_beam_indices"]
+            beam_next_tokens = beam_next_tokens.unsqueeze(-1)
+            tokens = torch.cat([tokens[beam_idx, :], beam_next_tokens], dim=-1)
+            mems = [mem[beam_idx] for mem in mems] if mems else None
+            if beam_scorer.is_done:
+                break
+            last_beam_num = args.num_beams
+        else:
+            next_token_logits /= args.temperature
+            next_token_logits = top_k_logits(next_token_logits, top_k=args.top_k, top_p=args.top_p)
+            probs = F.softmax(next_token_logits, dim=-1)
+            prev = torch.multinomial(probs, num_samples=1)[0]
+            is_end = prev.item() in end_tokens
+            if is_end:
+                break
+            prev = prev.view(1, 1)
+            tokens = prev if tokens is None else torch.cat((tokens, prev), dim=1)
+        counter += 1
+        if not args.block_lm and mpu.get_model_parallel_rank() == 0 and counter % 16 == 0:
+            output_tokens_list = tokens.view(-1).contiguous()
+            decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
+            if mpu.get_model_parallel_rank() == 0 and (counter % 128 == 0 or is_end):
+                os.system('clear')
+                trim_decode_tokens = decode_tokens
+                print(trim_decode_tokens, flush=True)
+    if args.num_beams > 1:
+        tokens, mems = beam_scorer.finalize(tokens, beam_scores, next_tokens, next_indices, eos_token_id=args.eod_token,
+                                            mems=mems)
+    return torch.cat((context_tokens, tokens), dim=1), mems
+
+
+def generate_samples(model, tokenizer, args, device):
+    model.eval()
+    output_path = "./samples"
+    if not os.path.exists(output_path):
+        os.makedirs(output_path)
+    output_path = os.path.join(output_path, f"sample-{datetime.now().strftime('%m-%d-%H-%M')}.txt")
+    with torch.no_grad(), open(output_path, "w") as output:
+        while True:
+            torch.distributed.barrier(group=mpu.get_model_parallel_group())
+
+            terminate_runs, raw_text, context_tokens_tensor, context_length = read_context(tokenizer, args, output)
+            if terminate_runs == 1:
+                return
+            start_time = time.time()
+            if args.block_lm:
+                mems = []
+                tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, device, args)
+                mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
+                mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
+                end_tokens = [tokenizer.get_command('eop').Id, args.eod_token]
+                mask_positions = []
+                for token in mask_tokens:
+                    mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist()
+                mask_positions.sort()
+                if args.no_block_position:
+                    for mask_position in mask_positions:
+                        position_ids[0, mask_position + 1:] += args.out_seq_length
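+                # a single forward pass over the full context fills the cached mems;
+                # each mask position is then expanded into its own generated span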
+                _, *mems = model(tokens, position_ids, attention_mask, *mems)
+                for mask_position in mask_positions:
+                    if args.no_block_position:
+                        position = position_ids[0, mask_position].item()
+                    else:
+                        position = mask_position
+                    tokens, mems = sample_sequence(model, tokenizer, tokens, position,
+                                                   args, device, mems=mems, end_tokens=end_tokens)
+            else:
+                tokens, _ = sample_sequence(model, tokenizer, context_tokens_tensor, context_length, args, device)
+            output_tokens_list = tokens.view(-1).contiguous()
+            if mpu.get_model_parallel_rank() == 0:
+                os.system('clear')
+                print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
+                print("\nContext:", raw_text, flush=True)
+                decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
+                trim_decode_tokens = decode_tokens
+                print("\nGLM:", trim_decode_tokens, flush=True)
+                output.write(trim_decode_tokens + "\n")
+
+            torch.distributed.barrier(group=mpu.get_model_parallel_group())
+
+
+def main(args):
+    initialize_distributed(args)
+    tokenizer = prepare_tokenizer(args)
+    args.eod_token = tokenizer.get_command('eos').Id
+    # build model
+    model = GLMModel(args)
+    if args.fp16:
+        model = model.half()
+    model = model.to(args.device)
+    load_checkpoint(model, args)
+    set_random_seed(args.seed)
+    model.eval()
+    generate_samples(model, tokenizer, args, torch.cuda.current_device())
+
+
+if __name__ == "__main__":
+    args = get_args()
+
+    with torch.no_grad():
+        main(args)
diff --git a/model/glm_model.py b/model/glm_model.py
index 0edf9b9..96502d1 100644
--- a/model/glm_model.py
+++ b/model/glm_model.py
@@ -2,9 +2,10 @@ import torch
 import torch.nn as nn
 
 from .base_model import BaseModel
+from .cached_autoregressive_model import CachedAutoregressiveModel
 
 
-class GLMModel(BaseModel):
+class GLMModel(CachedAutoregressiveModel):
     def __init__(self, args, transformer=None):
         super().__init__(args, transformer=transformer)
         self.transformer.block_position_embeddings = torch.nn.Embedding(args.max_sequence_length, args.hidden_size)
diff --git a/move_weights_glm.py b/move_weights_glm.py
new file mode 100644
index 0000000..6d3755d
--- /dev/null
+++ b/move_weights_glm.py
@@ -0,0 +1,38 @@
+import sys
+import os
+import torch
+import copy
+
+checkpoint = sys.argv[1]
+target_path = sys.argv[2]
+
+assert os.path.isdir(checkpoint)
+iteration_file = os.path.join(checkpoint, 'latest_checkpointed_iteration.txt')
+if os.path.exists(iteration_file):
+    with open(iteration_file) as fin:
+        iteration = int(fin.read().strip())
+    checkpoint = os.path.join(checkpoint, str(iteration))
+else:
+    iteration = None
+
+os.makedirs(target_path, exist_ok=True)
+if iteration is not None:
+    with open(os.path.join(target_path, "latest"), "w") as output:
+        output.write(str(iteration))
+    target_path = os.path.join(target_path, str(iteration))
+    os.makedirs(target_path, exist_ok=True)
+
+
+filenames = os.listdir(checkpoint)
+filenames = [filename for filename in filenames if filename.startswith("mp_rank_")]
+filenames = sorted(filenames,
+                   key=lambda x: int(x.split('_')[2]))
+filenames = [os.path.join(checkpoint, x) for x in filenames]
+
+for filename in filenames:
+    data = torch.load(filename)
+    state_dict = data['module']
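+    # rename the word embedding key to the layout expected by GLMModel in this repo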
+    state_dict['transformer.word_embeddings.weight'] = state_dict['word_embeddings.weight']
+    del state_dict['word_embeddings.weight']
+    # print(f"Target path: {os.path.join(target_path, os.path.basename(filename))}")
+    torch.save(data, os.path.join(target_path, os.path.basename(filename)))
diff --git a/mpu/transformer.py b/mpu/transformer.py
index e04485a..9751672 100755
--- a/mpu/transformer.py
+++ b/mpu/transformer.py
@@ -320,8 +320,6 @@ class BaseTransformer(torch.nn.Module):
         # sanity check 
         assert len(input_ids.shape) == 2 
         batch_size, query_length = input_ids.shape
-        assert len(position_ids.shape) <= 2
-        assert position_ids.shape[-1] == query_length
         assert len(attention_mask.shape) == 2 or \
             len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
 
@@ -334,6 +332,8 @@ class BaseTransformer(torch.nn.Module):
         if 'position_embedding_forward' in self.hooks:
             position_embeddings = self.hooks['position_embedding_forward'](position_ids, *other_tensors)
         else:
+            assert len(position_ids.shape) <= 2
+            assert position_ids.shape[-1] == query_length
             position_embeddings = self.position_embeddings(position_ids)    
         hidden_states = hidden_states + position_embeddings
         hidden_states = self.embedding_dropout(hidden_states)
diff --git a/scripts/generate_glm.sh b/scripts/generate_glm.sh
new file mode 100644
index 0000000..af5a59a
--- /dev/null
+++ b/scripts/generate_glm.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
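+# Usage: bash scripts/generate_glm.sh <model config>, e.g. config/model_glm_roberta_large.sh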
+CHECKPOINT_PATH=./checkpoints
+
+source $1
+
+MPSIZE=1
+MAXSEQLEN=512
+MASTER_PORT=$(shuf -n 1 -i 10000-65535)
+
+#SAMPLING ARGS
+TEMP=0.9
+# If TOPK/TOPP are 0, sampling defaults to greedy; top-k also overrides top-p
+TOPK=40
+TOPP=0
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+
+config_json="$script_dir/ds_config.json"
+
+python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTER_PORT inference_glm.py \
+       --mode inference \
+       --model-parallel-size $MPSIZE \
+       $MODEL_ARGS \
+       --fp16 \
+       --out-seq-length $MAXSEQLEN \
+       --temperature $TEMP \
+       --top_k $TOPK \
+       --top_p $TOPP
diff --git a/tokenization/__init__.py b/tokenization/__init__.py
index a6b481d..f465ec4 100644
--- a/tokenization/__init__.py
+++ b/tokenization/__init__.py
@@ -13,24 +13,36 @@ import math
 import random
 import torch
 
+
 def get_tokenizer(args=None):
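+    # GLM-specific tokenizer options: add block-infilling symbols (e.g. sop/eop), task mask
+    # tokens (e.g. [gMASK]/[sMASK]) and a decoder mask token when the corresponding flags are set.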
+    kwargs = {"add_block_symbols": args.block_lm, "add_task_mask": args.task_mask,
+              "add_decoder_mask": args.block_mask_prob > 0.0 or args.context_mask_ratio > 0.0}
     if not hasattr(get_tokenizer, 'tokenizer'):
         # the first time to load the tokenizer
         if args.tokenizer_type == 'cogview':
             from .cogview import UnifiedTokenizer
             get_tokenizer.tokenizer = UnifiedTokenizer(
-                args.img_tokenizer_path, 
+                args.img_tokenizer_path,
                 device=torch.cuda.current_device()
-                )
-        elif args.tokenizer_type == 'glm':
-            pass
+            )
+        elif args.tokenizer_type == "BertWordPieceTokenizer":
+            from .text import BertWordPieceTokenizer
+            get_tokenizer.tokenizer = BertWordPieceTokenizer(args.tokenizer_model_type, **kwargs)
+        elif args.tokenizer_type == "GPT2BPETokenizer":
+            from .text import GPT2BPETokenizer
+            get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs)
+        elif args.tokenizer_type == "ChineseSPTokenizer":
+            from .text import ChineseSPTokenizer
+            get_tokenizer.tokenizer = ChineseSPTokenizer(**kwargs)
         else:
             assert args.vocab_size > 0
             get_tokenizer.tokenizer = FakeTokenizer(args.vocab_size)
     return get_tokenizer.tokenizer
 
+
 class FakeTokenizer(object):
     def __init__(self, num_tokens):
         self.num_tokens = num_tokens
+
     def __len__(self):
-        return self.num_tokens
\ No newline at end of file
+        return self.num_tokens
diff --git a/tokenization/file_utils.py b/tokenization/file_utils.py
new file mode 100644
index 0000000..e4be142
--- /dev/null
+++ b/tokenization/file_utils.py
@@ -0,0 +1,250 @@
+# This file is provided as is from:
+#   https://github.com/huggingface/pytorch-pretrained-BERT
+# Please refer to their repository for copyright.
+
+"""
+Utilities for working with the local dataset cache.
+This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
+Copyright by the AllenNLP authors.
+"""
+from __future__ import (absolute_import, division, print_function, unicode_literals)
+
+import json
+import logging
+import os
+import shutil
+import tempfile
+from functools import wraps
+from hashlib import sha256
+import sys
+from io import open
+
+import boto3
+import requests
+from botocore.exceptions import ClientError
+from tqdm import tqdm
+
+from urllib.parse import urlparse
+
+try:
+    from pathlib import Path
+    PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
+                                                   Path.home() / '.pytorch_pretrained_bert'))
+except (AttributeError, ImportError):
+    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
+                                              os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+
+def url_to_filename(url, etag=None):
+    """
+    Convert `url` into a hashed filename in a repeatable way.
+    If `etag` is specified, append its hash to the url's, delimited
+    by a period.
+    """
+    url_bytes = url.encode('utf-8')
+    url_hash = sha256(url_bytes)
+    filename = url_hash.hexdigest()
+
+    if etag:
+        etag_bytes = etag.encode('utf-8')
+        etag_hash = sha256(etag_bytes)
+        filename += '.' + etag_hash.hexdigest()
+
+    return filename
+
+
+def filename_to_url(filename, cache_dir=None):
+    """
+    Return the url and etag (which may be ``None``) stored for `filename`.
+    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
+    """
+    if cache_dir is None:
+        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
+        cache_dir = str(cache_dir)
+
+    cache_path = os.path.join(cache_dir, filename)
+    if not os.path.exists(cache_path):
+        raise EnvironmentError("file {} not found".format(cache_path))
+
+    meta_path = cache_path + '.json'
+    if not os.path.exists(meta_path):
+        raise EnvironmentError("file {} not found".format(meta_path))
+
+    with open(meta_path, encoding="utf-8") as meta_file:
+        metadata = json.load(meta_file)
+    url = metadata['url']
+    etag = metadata['etag']
+
+    return url, etag
+
+
+def cached_path(url_or_filename, cache_dir=None):
+    """
+    Given something that might be a URL (or might be a local path),
+    determine which. If it's a URL, download the file and cache it, and
+    return the path to the cached file. If it's already a local path,
+    make sure the file exists and then return the path.
+    """
+    if cache_dir is None:
+        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+    if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
+        url_or_filename = str(url_or_filename)
+    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
+        cache_dir = str(cache_dir)
+
+    parsed = urlparse(url_or_filename)
+
+    if parsed.scheme in ('http', 'https', 's3'):
+        # URL, so get it from the cache (downloading if necessary)
+        return get_from_cache(url_or_filename, cache_dir)
+    elif os.path.exists(url_or_filename):
+        # File, and it exists.
+        return url_or_filename
+    elif parsed.scheme == '':
+        # File, but it doesn't exist.
+        raise EnvironmentError("file {} not found".format(url_or_filename))
+    else:
+        # Something unknown
+        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
+
+
+def split_s3_path(url):
+    """Split a full s3 path into the bucket name and path."""
+    parsed = urlparse(url)
+    if not parsed.netloc or not parsed.path:
+        raise ValueError("bad s3 path {}".format(url))
+    bucket_name = parsed.netloc
+    s3_path = parsed.path
+    # Remove '/' at beginning of path.
+    if s3_path.startswith("/"):
+        s3_path = s3_path[1:]
+    return bucket_name, s3_path
+
+
+def s3_request(func):
+    """
+    Wrapper function for s3 requests in order to create more helpful error
+    messages.
+    """
+
+    @wraps(func)
+    def wrapper(url, *args, **kwargs):
+        try:
+            return func(url, *args, **kwargs)
+        except ClientError as exc:
+            if int(exc.response["Error"]["Code"]) == 404:
+                raise EnvironmentError("file {} not found".format(url))
+            else:
+                raise
+
+    return wrapper
+
+
+@s3_request
+def s3_etag(url):
+    """Check ETag on S3 object."""
+    s3_resource = boto3.resource("s3")
+    bucket_name, s3_path = split_s3_path(url)
+    s3_object = s3_resource.Object(bucket_name, s3_path)
+    return s3_object.e_tag
+
+
+@s3_request
+def s3_get(url, temp_file):
+    """Pull a file directly from S3."""
+    s3_resource = boto3.resource("s3")
+    bucket_name, s3_path = split_s3_path(url)
+    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
+
+
+def http_get(url, temp_file):
+    req = requests.get(url, stream=True)
+    content_length = req.headers.get('Content-Length')
+    total = int(content_length) if content_length is not None else None
+    progress = tqdm(unit="B", total=total)
+    for chunk in req.iter_content(chunk_size=1024):
+        if chunk: # filter out keep-alive new chunks
+            progress.update(len(chunk))
+            temp_file.write(chunk)
+    progress.close()
+
+
+def get_from_cache(url, cache_dir=None):
+    """
+    Given a URL, look for the corresponding dataset in the local cache.
+    If it's not there, download it. Then return the path to the cached file.
+    """
+    if cache_dir is None:
+        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
+    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
+        cache_dir = str(cache_dir)
+
+    if not os.path.exists(cache_dir):
+        os.makedirs(cache_dir)
+
+    # Get eTag to add to filename, if it exists.
+    if url.startswith("s3://"):
+        etag = s3_etag(url)
+    else:
+        response = requests.head(url, allow_redirects=True)
+        if response.status_code != 200:
+            raise IOError("HEAD request failed for url {} with status code {}"
+                          .format(url, response.status_code))
+        etag = response.headers.get("ETag")
+
+    filename = url_to_filename(url, etag)
+
+    # get cache path to put the file
+    cache_path = os.path.join(cache_dir, filename)
+
+    if not os.path.exists(cache_path):
+        # Download to temporary file, then copy to cache dir once finished.
+        # Otherwise you get corrupt cache entries if the download gets interrupted.
+        with tempfile.NamedTemporaryFile() as temp_file:
+            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
+
+            # GET file object
+            if url.startswith("s3://"):
+                s3_get(url, temp_file)
+            else:
+                http_get(url, temp_file)
+
+            # we are copying the file before closing it, so flush to avoid truncation
+            temp_file.flush()
+            # shutil.copyfileobj() starts at the current position, so go to the start
+            temp_file.seek(0)
+
+            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
+            with open(cache_path, 'wb') as cache_file:
+                shutil.copyfileobj(temp_file, cache_file)
+
+            logger.info("creating metadata file for %s", cache_path)
+            meta = {'url': url, 'etag': etag}
+            meta_path = cache_path + '.json'
+            with open(meta_path, 'w', encoding="utf-8") as meta_file:
+                json.dump(meta, meta_file)
+
+            logger.info("removing temp file %s", temp_file.name)
+
+    return cache_path
+
+
+def read_set_from_file(filename):
+    '''
+    Extract a de-duped collection (set) of text from a file.
+    Expected file format is one item per line.
+    '''
+    collection = set()
+    with open(filename, 'r', encoding='utf-8') as file_:
+        for line in file_:
+            collection.add(line.rstrip())
+    return collection
+
+
+def get_file_extension(path, dot=True, lower=True):
+    ext = os.path.splitext(path)[1]
+    ext = ext if dot else ext[1:]
+    return ext.lower() if lower else ext
diff --git a/tokenization/text/__init__.py b/tokenization/text/__init__.py
new file mode 100644
index 0000000..d191be6
--- /dev/null
+++ b/tokenization/text/__init__.py
@@ -0,0 +1 @@
+from .tokenization import BertWordPieceTokenizer, GPT2BPETokenizer, ChineseSPTokenizer
diff --git a/tokenization/text/sp_tokenizer.py b/tokenization/text/sp_tokenizer.py
new file mode 100644
index 0000000..7b6430e
--- /dev/null
+++ b/tokenization/text/sp_tokenizer.py
@@ -0,0 +1,150 @@
+"""
+from https://github.com/openai/gpt-2/, changed for chinese
+"""
+import json
+import os
+import sentencepiece as spm
+
+"""
+SentencePiece is an unsupervised text tokenizer and detokenizer mainly for Neural Network-based text generation 
+systems where the vocabulary size is predetermined prior to the neural model training. SentencePiece implements 
+subword units (e.g., byte-pair-encoding (BPE) [Sennrich et al.] and unigram language model [Kudo]) with the
+extension of direct training from raw sentences. SentencePiece allows us to make a purely end-to-end 
+system that does not depend on language-specific pre/postprocessing.
+https://github.com/google/sentencepiece
+
+pip install sentencepiece
+
+or  git clone https://github.com/google/sentencepiece.git
+python setup.py install
+
+"""
+PRETRAINED_MODEL_FILE = "pretrained/chinese_sentencepiece/cog-pretrain.model"
+
+
+def get_pairs(word):
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+class Encoder:
+    def __init__(self, encoder, bpe_merges):
+        self.encoder = encoder
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        self.cache = {}
+        self.max_len = 0
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+        if not pairs:
+            return token
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        return [self.encoder.get(token, 1) for token in self.tokenize(text)]
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        return text
+
+    def tokenize(self, text):
+        bpe_tokens = []
+        bpe_tokens.extend(bpe_token for bpe_token in self.bpe(text).split(' '))
+        return bpe_tokens
+
+    def convert_tokens_to_ids(self, tokens):
+        return [self.encoder.get(token, 1) for token in tokens]
+
+
+class Encoder_SP:
+    def __init__(self, model_path):
+        self.sp = spm.SentencePieceProcessor()
+        self.sp.Load(model_path)
+
+    def encode(self, text):
+        """
+        text="...."
+        """
+        return self.sp.EncodeAsIds(text)
+
+    def decode(self, tokens):
+        """
+        tokens=[x1,x2,...]
+        """
+        text = [int(token) for token in tokens]
+        # print(text)
+        return self.sp.DecodeIds(text)
+
+    def tokenize(self, text):
+        return self.sp.EncodeAsPieces(text)
+
+    def convert_tokens_to_ids(self, tokens):
+        return [self.sp.PieceToId(token) for token in tokens]
+
+    def convert_token_to_id(self, token):
+        return self.sp.PieceToId(token)
+
+    def convert_id_to_token(self, idx):
+        return self.sp.IdToPiece(idx)
+
+
+def get_encoder(encoder_file, bpe_file):
+    # the following makes the same entry function also compatible with sentencepiece models
+    filepath, filename = os.path.split(encoder_file)
+    shortname, extension = os.path.splitext(filename)
+
+    if (".model" == extension) and (bpe_file == ""):
+        return Encoder_SP(encoder_file)
+    else:
+        with open(encoder_file, 'r', encoding="utf-8") as f:
+            encoder = json.load(f)
+        with open(bpe_file, 'r', encoding="utf-8") as f:
+            bpe_data = f.read()
+        bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
+        return Encoder(
+            encoder=encoder,
+            bpe_merges=bpe_merges,
+        )
+
+
+def from_pretrained():
+    return get_encoder(PRETRAINED_MODEL_FILE, "")
diff --git a/tokenization/text/tokenization.py b/tokenization/text/tokenization.py
new file mode 100644
index 0000000..51d2009
--- /dev/null
+++ b/tokenization/text/tokenization.py
@@ -0,0 +1,1254 @@
+# coding=utf-8
+# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for using and training tokenizers (char, wordpiece, sentencepiece)"""
+from collections import namedtuple
+import random
+import os
+import csv
+import torch
+import itertools
+
+import nltk
+from nltk import tokenize as nltk_tokenize
+import sentencepiece as spm
+
+from .wordpiece import BertTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
+
+from .tokenization_gpt2 import GPT2Tokenizer
+from . import sp_tokenizer
+import regex as re
+
+
+def make_tokenizer(tokenizer_type, corpus, model_path=None, vocab_size=None, model_type=None, pad_token=0,
+                   character_coverage=1.0, command_tokens=None, type_tokens=None, **kwargs):
+    """
+    Helper function to instantiate a tokenizer given common combinations of options.
+    """
+    tokenizer_class = tokenizer_type
+    if isinstance(tokenizer_class, str):
+        tokenizer_class = eval(tokenizer_class)
+    if tokenizer_class is BertWordPieceTokenizer:
+        return BertWordPieceTokenizer(model_type, **kwargs)
+    elif tokenizer_class is GPT2BPETokenizer:
+        if model_type is None:
+            model_type = 'gpt2'
+        return GPT2BPETokenizer(model_type, **kwargs)
+    elif tokenizer_class is ChineseSPTokenizer:
+        return ChineseSPTokenizer(**kwargs)
+    text_tokenizer = tokenizer_class(corpus=corpus, vocab_size=vocab_size, model_path=model_path, model_type=model_type,
+                                     pad_token=pad_token, character_coverage=character_coverage)
+    return Tokenizer(text_tokenizer, command_tokens, type_tokens)
+
+
+class Tokenization(object):
+    """
+    Tokenization object to hold the tokenization (of the processed text) and the
+    original text. Can hold the tokenization as Ids or tokens.
+
+    It also holds command tokens (pad, unk, etc.) for the tokenization.
+    This allows functions to pad/operate on tokenizations without having
+    access to the full tokenizer, just the tokenization.
+
+    Several standard array operations are implemented (insert, append, extend).
+    """
+
+    def __init__(self, tokenization, text=None, original_text=None, command_tokens=None, asIds=True):
+        self.tokenization = tokenization
+        self.text = text
+        if self.text is None:
+            self.text = self.tokenization
+        self.original_text = original_text
+        if self.original_text is None:
+            self.original_text = self.text
+        self.command_tokens = command_tokens
+        self.asIds = asIds
+        self.parse_command_tokens()
+
+    def set_command_tokens(self, command_tokens):
+        self.command_tokens = command_tokens
+        return self.parse_command_tokens()
+
+    def parse_command_tokens(self):
+        if self.command_tokens is None:
+            return
+        for command_token in self.command_tokens:
+            if self.asIds:
+                setattr(self, command_token.name, command_token.Id)
+            else:
+                setattr(self, command_token.name, command_token.token)
+
+    def __getitem__(self, index):
+        return self.tokenization[index]
+
+    def __len__(self):
+        return len(self.tokenization)
+
+    def insert(self, idx, other):
+        if isinstance(other, (CommandToken, TypeToken)):
+            self.tokenization.insert(idx, other.Id)
+            if idx == 0:
+                self.text = other.token + self.text
+                self.original_text = other.token + self.original_text
+            elif idx == len(self.tokenization) - 1:
+                self.text += other.token
+                self.original_text += other.token
+        elif isinstance(other, Tokenization):
+            self.tokenization = self.tokenization[:idx] + other.tokenization + self.tokenization[idx:]
+        else:
+            # `other` is assumed to be a plain list of ids here
+            self.tokenization = self.tokenization[:idx] + other + self.tokenization[idx:]
+
+    def append(self, other):
+        if isinstance(other, (CommandToken, TypeToken)):
+            self.tokenization.append(other.Id)
+            self.text += other.token
+            self.original_text += other.token
+        elif isinstance(other, Tokenization):
+            self.tokenization.extend(other.tokenization)
+            self.text += other.text
+            self.original_text += other.original_text
+        else:
+            self.tokenization.append(other)
+        return self
+
+    def extend(self, other):
+        if isinstance(other, (CommandToken, TypeToken)):
+            self.tokenization.append(other.Id)
+            self.text += other.token
+            self.original_text += other.token
+        elif isinstance(other, list) and isinstance(other[0], (CommandToken, TypeToken)):
+            self.tokenization.extend([o.Id for o in other])
+            self.text += [o.token for o in other]
+            self.original_text += [o.token for o in other]
+        elif isinstance(other, Tokenization):
+            self.tokenization.extend(other.tokenization)
+            self.text += other.text
+            self.original_text += other.original_text
+        else:
+            self.tokenization.extend(other)
+        return self
+
+
+"""define some default command tokens for the tokenizer to use"""
+token_format = "<{0}>"
+
+COMMAND_TUPLE = namedtuple('CommandToken', ('name', 'token', 'Id'))
+
+
+def prep_command_tokens(tokenlist, token_format=token_format):
+    return [CommandToken(tok[0], token_format.format(tok[0]), tok[1]) for tok in tokenlist]
+
+
+class CommandToken(object):
+    def __init__(self, name, token, Id, lstrip=False, rstrip=False):
+        self.name = name
+        self.token = token
+        self.Id = Id
+        self.lstrip = lstrip
+        self.rstrip = rstrip
+
+    def __str__(self):
+        return str(COMMAND_TUPLE(self.name, self.token, self.Id))
+
+
+DEFAULT_COMMAND_TOKENS = [
+    ('pad', 0),
+    ('eos', 1),
+    ('bos', 2),
+    ('unk', 3),
+    ('sep', 4),
+    ('L2R', 5),
+    ('ENC', 6),
+    ('MASK', 7),
+]
+DEFAULT_COMMAND_TOKENS = prep_command_tokens(DEFAULT_COMMAND_TOKENS)
+
+"""define some default type tokens for bert training"""
+
+TYPE_TUPLE = namedtuple('TypeToken', ('name', 'token', 'Id'))
+
+
+def prep_type_tokens(tokenlist, token_format=token_format):
+    return [TypeToken(tok[0], token_format.format(tok[0]), tok[1]) for tok in tokenlist]
+
+
+class TypeToken(object):
+    def __init__(self, name, token, Id):
+        self.name = name
+        self.token = token
+        self.Id = Id
+
+    def __str__(self):
+        return str(TYPE_TUPLE(self.name, self.token, self.Id))
+
+
+DEFAULT_TYPE_TOKENS = [
+    ('function', 0),
+    ('command', 1),
+    ('str0', 2),
+    ('str1', 3),
+    ('str2', 4),
+    ('embedding0', 5),
+    ('embedding1', 6),
+    ('embedding2', 7),
+    ('arg0', 8),
+    ('arg1', 9),
+    ('arg2', 10),
+]
+DEFAULT_TYPE_TOKENS = prep_type_tokens(DEFAULT_TYPE_TOKENS)
+
+
+class Tokenizer(object):
+    """
+    Tokenizer object that handles text tokenization, command tokens, and type tokens.
+
+    Command tokens and text tokens are stored together in one mapping of size
+    `len(text_tokenizer)+len(command_tokens)`. Command tokens are stored as first
+    `len(command_tokens)` tokens. Token idx is stored at `idx+len(command_tokens)`.
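+    For example, with the 8 default command tokens, text-token id 0 is stored at combined id 8.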
+
+    Token types are stored in a separate mapping of size `len(type_tokens)`.
+    """
+
+    def __init__(self, text_tokenizer, command_tokens=None, type_tokens=None):
+        # set text tokenizer
+        self.text_tokenizer = text_tokenizer
+        if not hasattr(self, 'num_text_tokens'):
+            self.num_text_tokens = len(self.text_tokenizer)
+
+        # set command tokens
+        if command_tokens is None:
+            command_tokens = DEFAULT_COMMAND_TOKENS
+        self._command_tokens = command_tokens
+        self.command_name_map = {tok.name: tok for tok in self._command_tokens}
+        self.command_token_map = {tok.token: tok for tok in self._command_tokens}
+        self.command_id_map = {tok.Id: tok for tok in self._command_tokens}
+        if not hasattr(self, 'num_command_tokens'):
+            self.num_command_tokens = len(self._command_tokens)
+        if not hasattr(self, 'num_tokens'):
+            self.num_tokens = self.num_command_tokens + self.num_text_tokens
+
+        # set type tokens
+        if type_tokens is None:
+            type_tokens = DEFAULT_TYPE_TOKENS
+        self.type_tokens = type_tokens
+        self.type_name_map = {tok.name: tok for tok in self.type_tokens}
+        self.type_token_map = {tok.token: tok for tok in self.type_tokens}
+        self.type_id_map = {tok.Id: tok for tok in self.type_tokens}
+        if not hasattr(self, 'num_type_tokens'):
+            self.num_type_tokens = len(self.type_tokens)
+
+        # parse tokens and vocabs from tokenizer
+        self._tokens = list(self.command_token_map.keys()) + list(self.text_tokenizer.tokens)
+        self._vocab = {t: Id for Id, t in self.command_id_map.items()}
+        self._vocab.update({t: Id + self.num_command_tokens for t, Id in self.text_tokenizer.vocab.items()})
+
+        self._text_tokens = list(self.text_tokenizer.tokens)
+        self._text_token_vocab = {t: Id + self.num_command_tokens for t, Id in self.text_tokenizer.vocab.items()}
+
+        self._command_token_tokens = list(self.command_token_map.keys())
+        self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}
+
+        self._token_types = list(self.type_token_map.keys())
+        self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}
+
+    def __call__(self, text, process_fn=None):
+        """run preprocessing and encode text as Ids"""
+        return self.EncodeAsIds(text, process_fn=process_fn)
+
+    def __len__(self):
+        """total number of tokens"""
+        return self.num_tokens
+
+    def get_command(self, name):
+        """get command token corresponding to `name`"""
+        return self.command_name_map[name]
+
+    def get_type(self, name):
+        """get type token corresponding to `name`"""
+        return self.type_name_map[name]
+
+    @property
+    def tokens(self):
+        """list (or iterable) of all tokens for tokenizer"""
+        return self._tokens
+
+    @property
+    def vocab(self):
+        """dictionary mapping tokens to ids for tokenizer"""
+        return self._vocab
+
+    @property
+    def token_types(self):
+        """list (or iterable) of all token types for tokenizer"""
+        return self._token_types
+
+    @property
+    def token_type_vocab(self):
+        """dictionary mapping token types to ids for tokenizer"""
+        return self._token_type_vocab
+
+    @property
+    def command_tokens(self):
+        """list (or iterable) of all command tokens for tokenizer"""
+        return self._command_token_tokens
+
+    @property
+    def command_token_vocab(self):
+        """dictionary mapping command tokens to ids for tokenizer"""
+        return self._command_token_vocab
+
+    @property
+    def text_tokens(self):
+        """list (or iterable) of text tokens for text tokenizer"""
+        return self._text_tokens
+
+    @property
+    def text_token_vocab(self):
+        """dictionary mapping text tokens to ids for text tokenizer"""
+        return self._text_token_vocab
+
+    def EncodeAsIds(self, text, process_fn=None):
+        """
+        encode text using text tokenizer and shift Id values for command tokens
+        """
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+
+        def split_on_token(tok_extended: CommandToken, text):
+            result = []
+            tok = tok_extended.token
+            split_text = text.split(tok)
+            for i, sub_text in enumerate(split_text):
+                # CommandToken can control whitespace stripping around them.
+                # We use them for GPT2 and Roberta to have different behavior depending on the special token
+                # Cf. https://github.com/huggingface/transformers/pull/2778
+                # and https://github.com/huggingface/transformers/issues/3788
+                # Strip white spaces on the right
+                if tok_extended.rstrip and i > 0:
+                    # A bit counter-intuitive but we strip the left of the string
+                    # since tok_extended.rstrip means the special token is eating all white spaces on its right
+                    sub_text = sub_text.lstrip()
+                # Strip white spaces on the left
+                if tok_extended.lstrip and i < len(split_text) - 1:
+                    sub_text = sub_text.rstrip()  # Opposite here
+
+                if i == 0 and not sub_text:
+                    result.append(tok)
+                elif i == len(split_text) - 1:
+                    if sub_text:
+                        result.append(sub_text)
+                    else:
+                        pass
+                else:
+                    if sub_text:
+                        result.append(sub_text)
+                    result.append(tok)
+            return result
+
+        def split_on_tokens(tok_list, text):
+            if not text.strip():
+                return []
+            if not tok_list:
+                return self.text_tokenizer.encode(text)
+
+            tokenized_text = []
+            text_list = [text]
+            for tok in tok_list:
+                tokenized_text = []
+                for sub_text in text_list:
+                    if sub_text not in self._command_token_tokens:
+                        tokenized_text.extend(split_on_token(tok, sub_text))
+                    else:
+                        tokenized_text.append(sub_text)
+                text_list = tokenized_text
+
+            return list(
+                itertools.chain.from_iterable(
+                    (
+                        self._encode(token) if token not in self._command_token_tokens else [
+                            self.command_token_map[token].Id] for token in tokenized_text
+                    )
+                )
+            )
+
+        no_split_tokens = self._command_tokens
+        Ids = split_on_tokens(no_split_tokens, processed_text)
+        tokenization = Tokenization(Ids, processed_text, text)
+        tokenization.set_command_tokens(self._command_tokens)
+        return tokenization
+
+    def _encode(self, text):
+        raise NotImplementedError
+
+    def EncodeAsTokens(self, text, process_fn=None):
+        """
+        encode text as tokens using text tokenizer
+        """
+        tokenization = self.text_tokenizer.EncodeAsTokens(text, process_fn=process_fn)
+        tokenization.set_command_tokens(self._command_tokens)
+        return tokenization
+
+    def IdToToken(self, Id, type_token=False):
+        """convert Id to token accounting for command and type tokens"""
+        if isinstance(Id, (TypeToken, CommandToken)):
+            return Id.token
+        if type_token:
+            return self.type_id_map[Id].token
+        if Id < self.num_command_tokens:
+            return self.command_id_map[Id].token
+        return self.text_tokenizer.IdToToken(Id - self.num_command_tokens)
+
+    def TokenToId(self, token, type_token=False):
+        """convert token to Id accounting for command and type tokens"""
+        if isinstance(token, (TypeToken, CommandToken)):
+            return token.Id
+        if type_token:
+            return self.type_token_map[token].Id
+        if token in self.command_token_map:
+            return self.command_token_map[token].Id
+        return self.text_tokenizer.TokenToId(token) + self.num_command_tokens
+
+    def DecodeIds(self, Ids, type_token=False):
+        """
+        convert Ids to tokens accounting for command and type tokens, tokens
+        are joined and returned as a string.
+        """
+        if type_token:
+            return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
+        rtn_strs = []
+        current_str = []
+        if isinstance(Ids, Tokenization):
+            Ids = Ids.tokenization
+        for Id in Ids:
+            if isinstance(Id, CommandToken):
+                rtn_strs.append(self.text_tokenizer.DecodeIds(current_str))
+                current_str = []
+                rtn_strs.append(Id.token)
+            elif Id < self.num_command_tokens:
+                rtn_strs.append(self.text_tokenizer.DecodeIds(current_str))
+                current_str = []
+                rtn_strs.append(self.command_id_map[Id].token)
+            else:
+                current_str.append(Id - self.num_command_tokens)
+        if current_str != []:
+            rtn_strs.append(self.text_tokenizer.DecodeIds(current_str))
+        return ' '.join(rtn_strs)
+
+    def DecodeTokens(self, Tokens, type_token=False):
+        """
+        convert tokens to a string accounting for command and type tokens.
+        """
+        if type_token:
+            return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens)
+        rtn_strs = []
+        current_str = []
+        if isinstance(Tokens, Tokenization):
+            Tokens = Tokens.tokenization
+        for t in Tokens:
+            if isinstance(t, CommandToken):
+                rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str))
+                current_str = []
+                rtn_strs.append(t.token)
+            elif t in self.command_token_map:
+                rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str))
+                current_str = []
+                rtn_strs.append(t)
+            else:
+                current_str.append(t)
+        if current_str != []:
+            rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str))
+        return ' '.join(rtn_strs)
+
+
+class TextTokenizer(object):
+    """
+    Interface for text tokenizer
+    """
+
+    def __init__(self):
+        if not hasattr(self, 'num_text_tokens'):
+            self.num_text_tokens = 0
+        if not hasattr(self, 'num_tokens'):
+            self.num_tokens = self.num_text_tokens
+
+    def __call__(self, text, process_fn=None):
+        return self.EncodeAsIds(text, process_fn)
+
+    def __len__(self):
+        return self.num_text_tokens
+
+    @property
+    def tokens(self):
+        """list (or iterable) of text tokens for text tokenizer"""
+        raise NotImplementedError('TextTokenizer tokens property not implemented')
+
+    @property
+    def vocab(self):
+        """dictionary mapping tokens to ids"""
+        raise NotImplementedError('TextTokenizer vocab property not implemented')
+
+    @staticmethod
+    def exists(model_path):
+        """check if the filepath for a text tokenizer exists"""
+        raise NotImplementedError('TextTokenizer exists method not implemented')
+
+    def Train(self, corpus):
+        """train a tokenizer on a data corpus and save model for future use"""
+        raise NotImplementedError('TextTokenizer Train not implemented')
+
+    def EncodeAsIds(self, text, process_fn=None):
+        """
+        Preprocess text and encode as ids. Return a tokenization object with
+        original text, processed text, and id tokenization.
+        """
+        raise NotImplementedError('TextTokenizer EncodeAsIds not implemented')
+
+    def EncodeAsTokens(self, text, process_fn=None):
+        """
+        Preprocess text and encode as tokens. Return a tokenization object with
+        original text, processed text, and token tokenization.
+        """
+        raise NotImplementedError('TextTokenizer EncodeAsTokens not implemented')
+
+    def IdToToken(self, Id):
+        """Convert an Id to Token. Reverse lookup of self.vocab"""
+        raise NotImplementedError('TextTokenizer IdToToken not implemented')
+
+    def TokenToId(self, token):
+        """Convert a Token to Id. Lookup of self.vocab"""
+        raise NotImplementedError('TextTokenizer TokenToId not implemented')
+
+    def DecodeIds(self, Ids):
+        """Convert a list or tokenization object of Ids to a text string"""
+        raise NotImplementedError('TextTokenizer DecodeIds not implemented')
+
+    def DecodeTokens(self, Tokens):
+        """Convert a list or tokenization object of tokens to a text string"""
+        raise NotImplementedError('TextTokenizer DecodeTokens not implemented')
+
+
+class CharacterLevelTokenizer(TextTokenizer):
+    """
+    Text tokenizer for ASCII-256 Character Level Tokenization.
+    """
+
+    def __init__(self, **kwargs):
+        self.num_text_tokens = 256
+        super(CharacterLevelTokenizer, self).__init__()
+        self._tokens = [self.IdToToken(Id) for Id in range(self.num_text_tokens)]
+        self._vocab = {t: i for i, t in enumerate(self._tokens)}
+
+    def __len__(self):
+        return 256
+
+    @staticmethod
+    def exists(model_path):
+        return True
+
+    def Train(self, corpus):
+        pass
+
+    @property
+    def tokens(self):
+        return self._tokens
+
+    @property
+    def vocab(self):
+        return self._vocab
+
+    def EncodeAsIds(self, text, process_fn=None):
+        """convert text to ascii 256 Ids"""
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+            processed_text = str(processed_text)
+        tokens = [self.TokenToId(c) for c in processed_text]
+        return Tokenization(tokens, processed_text, text)
+
+    def EncodeAsTokens(self, text, process_fn=None):
+        """convert text to ascii 256 characters"""
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+        processed_text = str(processed_text)
+        tokens = [c for c in processed_text]
+        return Tokenization(tokens, processed_text, text, asIds=False)
+
+    def IdToToken(self, Id):
+        """ascii index to character"""
+        return chr(Id)
+
+    def TokenToId(self, token):
+        """ascii character to index"""
+        return ord(token)
+
+    def DecodeIds(self, Ids):
+        """converts ascii ids to tokens before joining them into text"""
+        if isinstance(Ids, Tokenization):
+            Ids = Ids.tokenization
+        return ''.join([self.IdToToken(tok) for tok in Ids])
+
+    def DecodeTokens(self, Tokens):
+        """just concatenates ascii tokens into text"""
+        if isinstance(Tokens, Tokenization):
+            Tokens = Tokens.tokenization
+        return ''.join(Tokens)
+
+
+MAX_SENTENCEPIECE_SENTENCES = 100000000
+
+
+def get_corpus_freq(dataset, filepath, filetype='tsv'):
+    """
+    Take corpus, split it into sentences, and extract word frequencies.
+    Write frequencies to `filepath` as a tsv. Only write the first
+    MAX_SENTENCEPIECE_SENTENCES most common words to the file.
+    """
+    nltk.download('punkt', download_dir="./nltk")
+    if filetype == 'tsv':
+        delimiter = '\t'
+    else:
+        delimiter = ','
+
+    print("compute corpus frequency\n", flush=True)
+
+    total_sentence_count = 0
+    maxlen = 0
+    freqs = {}
+    for entry in dataset:
+        if isinstance(entry, dict):
+            entry = entry['text']
+        lines = entry.strip().split('\n')
+        for line in lines:
+            sentences = nltk_tokenize.sent_tokenize(line)
+            total_sentence_count += len(sentences)
+            for sentence in sentences:
+                maxlen = max(len(line), maxlen)
+                for word in sentence.split():
+                    if word not in freqs:
+                        freqs[word] = 0
+                    freqs[word] += 1
+
+    print("length of freqs before truncating " + str(len(freqs)), flush=True)
+    print("file path for freq " + str(filepath), flush=True)
+
+    freqs_sorted = {}
+    counter = 0
+    for word, count in sorted(freqs.items(), key=lambda x: x[1], reverse=True):
+        if counter >= MAX_SENTENCEPIECE_SENTENCES:
+            break
+        counter += 1
+        freqs_sorted[word] = count
+
+    print("length of freqs after trancating " + str(len(freqs_sorted)), flush=True)
+
+    with open(filepath, 'w') as f:
+        writer = csv.writer(f, delimiter=delimiter)
+        for k, v in freqs_sorted.items():
+            writer.writerow([str(k), str(v)])
+
+    return total_sentence_count, maxlen
+
+
+class SentencePieceTokenizer(TextTokenizer):
+    """Trains and uses sentencepiece for text tokenization"""
+
+    def __init__(self, model_type='bpe', vocab_size=None, corpus=None, model_path=None, character_coverage=1.0,
+                 **kwargs):
+        self.character_coverage = character_coverage
+        self.model_type = model_type.lower()
+        self.spm_model = model_path
+        self.num_text_tokens = vocab_size
+        make_train = not SentencePieceTokenizer.exists(self.spm_model)
+        if make_train:
+            assert corpus is not None and self.num_text_tokens is not None
+            self.Train(corpus, self.num_text_tokens)
+        self._tokens = []
+        self._vocab = {}
+        self.load_spm_model()
+        super(SentencePieceTokenizer, self).__init__()
+
+    def __len__(self):
+        return self.num_text_tokens
+
+    @property
+    def tokens(self):
+        return self._tokens
+
+    @property
+    def vocab(self):
+        return self._vocab
+
+    @staticmethod
+    def exists(model_path):
+        if model_path is None:
+            return False
+        # check if path exists
+        dne = not os.path.exists(model_path)
+        # check if path.model exists
+        if dne and not model_path.endswith('.model'):
+            dne = not os.path.exists(model_path + '.model')
+        return not dne
+
+    def load_spm_model(self):
+        """load sentencepiece model and parse vocab"""
+        if not os.path.exists(self.spm_model) and not self.spm_model.endswith('.model'):
+            self.spm_model = self.spm_model + '.model'
+        self.sp = spm.SentencePieceProcessor()
+        self.sp.Load(self.spm_model)
+        self.vocab_size = self.num_text_tokens = len(self.sp)
+        self._tokens = [self.IdToToken(t) for t in range(self.vocab_size)]
+        self._vocab = {t: i for i, t in enumerate(self._tokens)}
+
+    def Train(self, corpus, num_text_tokens):
+        """train sentencepiece model on corpus using word frequencies"""
+        self.num_text_tokens = num_text_tokens
+        use_model_path = self.spm_model
+        random_hash = str(random.randint(0, 2147483647))
+        if use_model_path is None:
+            use_model_path = random_hash
+        if use_model_path.endswith('.model'):
+            use_model_path = use_model_path[:use_model_path.rfind('.model')]
+        input_path = use_model_path + '.tsv.' + random_hash
+        line_count, maxlenline = get_corpus_freq(corpus, input_path)
+        line_count = min(line_count, MAX_SENTENCEPIECE_SENTENCES)
+        print('line count used as input_sentence_size ', line_count, flush=True)
+        print('training sentencepiece model', flush=True)
+        train_string = '--input={file_path} --model_prefix={model_prefix} --vocab_size={vocab_size}' \
+                       + ' --model_type={model_type} --character_coverage={character_coverage} ' \
+                       + '--input_sentence_size={input_sentence_size} ' \
+                       + '--input_format=tsv'
+        train_string = train_string.format(file_path=input_path, model_prefix=use_model_path,
+                                           vocab_size=num_text_tokens,
+                                           model_type=self.model_type, character_coverage=self.character_coverage,
+                                           input_sentence_size=int(line_count))
+        print("calling spm.SentencePieceTrainer.Train(%s)" % (train_string), flush=True)
+        spm.SentencePieceTrainer.Train(train_string)
+        os.remove(input_path)
+        self.spm_model = use_model_path + '.model'
+        print('sentencepiece model written to ' + self.spm_model, flush=True)
+
+    def EncodeAsIds(self, text, process_fn=None):
+        """convert text to sentencepiece Ids"""
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+        tokens = self.sp.EncodeAsIds(processed_text)
+        return Tokenization(tokens, processed_text, text)
+
+    def EncodeAsTokens(self, text, process_fn=None):
+        """convert text to sentencepiece tokens"""
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+        tokens = self.sp.EncodeAsTokens(processed_text)
+        return Tokenization(tokens, processed_text, text, asIds=False)
+
+    def IdToToken(self, Id):
+        """convert Id to sentencpiece token"""
+        return self.sp.IdToPiece(Id)
+
+    def TokenToId(self, token):
+        """convert sentencpiece token to Id"""
+        return self.sp.PieceToId(token)
+
+    def DecodeIds(self, Ids):
+        """converts ids to a text string"""
+        if isinstance(Ids, Tokenization):
+            Ids = Ids.tokenization
+        return self.sp.DecodeIds(Ids)
+
+    def DecodeTokens(self, Tokens):
+        """converts sentencepiece tokens to a text string"""
+        if isinstance(Tokens, Tokenization):
+            Tokens = Tokens.tokenization
+        return self.sp.DecodeTokens(Tokens)
+
+
+class BertWordPieceTokenizer(Tokenizer):
+    """
+    Loads a pretrained WordPiece tokenizer from `cache_dir` for tokenization
+    in BERT training. Default to bert-large-uncased tokenizer.
+    """
+
+    def __init__(self, tokenizer_model_type=None, cache_dir=None, add_block_symbols=False, add_sentinel_token=0,
+                 add_task_mask=False, add_decoder_mask=False, **kwargs):
+        # default to bert-large-uncased tokenizer
+        if tokenizer_model_type not in PRETRAINED_VOCAB_ARCHIVE_MAP:
+            tokenizer_model_type = 'bert-large-uncased'
+        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+            print('loading BertWordPieceTokenizer (', tokenizer_model_type, ') from cache_dir ', cache_dir)
+        do_lower_case = not ('-cased' in tokenizer_model_type or 'chinese' in tokenizer_model_type)
+        self.text_tokenizer = BertTokenizer.from_pretrained(tokenizer_model_type, do_lower_case=do_lower_case,
+                                                            cache_dir=cache_dir)
+        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+            print('loaded', tokenizer_model_type)
+        # disable max len warnings by increasing max len
+        self.text_tokenizer.max_len = int(1e12)
+
+        # set command tokens from wordpiece tokenizer values
+        self.num_command_tokens = 6
+        self.num_tokens = len(self.text_tokenizer.vocab)
+        self.num_text_tokens = self.num_tokens - 5
+        self.num_type_tokens = 2
+
+        self._command_tokens = [
+            CommandToken('pad', '[PAD]', self.text_tokenizer.vocab['[PAD]']),
+            CommandToken('ENC', '[CLS]', self.text_tokenizer.vocab['[CLS]']),
+            CommandToken('MASK', '[MASK]', self.text_tokenizer.vocab['[MASK]']),
+            CommandToken('unk', '[UNK]', self.text_tokenizer.vocab['[UNK]']),
+            CommandToken('sep', '[SEP]', self.text_tokenizer.vocab['[SEP]']),
+            CommandToken('eos', '[PAD]', self.text_tokenizer.vocab['[PAD]']),
+        ]
+        if add_block_symbols:
+            self._command_tokens.extend([
+                CommandToken('sop', '<|startofpiece|>', self.num_tokens),
+                CommandToken('eop', '<|endofpiece|>', self.num_tokens + 1)
+            ])
+            self.num_tokens += 2
+            self.num_command_tokens += 2
+            if add_task_mask:
+                self._command_tokens.extend([
+                    CommandToken('gMASK', '[gMASK]', self.num_tokens),
+                    CommandToken('sMASK', '[sMASK]', self.num_tokens + 1)
+                ])
+                self.num_tokens += 2
+                self.num_command_tokens += 2
+            if add_decoder_mask:
+                self._command_tokens.extend([
+                    CommandToken('dBLOCK', '[dBLOCK]', self.num_tokens)
+                ])
+                self.num_tokens += 1
+                self.num_command_tokens += 1
+        if add_sentinel_token > 0:
+            for i in range(1, add_sentinel_token):
+                self._command_tokens.extend([CommandToken(f'MASK{i}', f'[MASK{i}]', self.num_tokens),
+                                             CommandToken(f'sop{i}', f'<|startofpiece{i}|>', self.num_tokens + 1)])
+                self.num_tokens += 2
+                self.num_command_tokens += 2
+        self.command_name_map = {tok.name: tok for tok in self._command_tokens}
+        self.command_token_map = {tok.token: tok for tok in self._command_tokens}
+        self.command_id_map = {tok.Id: tok for tok in self._command_tokens}
+
+        # set type tokens
+        self.type_tokens = [
+            TypeToken('str0', '<str0>', 0),
+            TypeToken('str1', '<str1>', 1),
+        ]
+        self.type_name_map = {tok.name: tok for tok in self.type_tokens}
+        self.type_token_map = {tok.token: tok for tok in self.type_tokens}
+        self.type_id_map = {tok.Id: tok for tok in self.type_tokens}
+
+        # parse tokens and vocabs from tokenizer
+
+        self._tokens = list(self.text_tokenizer.vocab.keys())
+        self._vocab = {k: v for k, v in self.text_tokenizer.vocab.items()}
+
+        self._text_tokens = list(self._tokens)
+        self._text_token_vocab = {k: v for k, v in self.text_tokenizer.vocab.items()}
+
+        self._command_token_tokens = list(self.command_token_map.keys())
+        self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}
+
+        self._token_types = list(self.type_token_map.keys())
+        self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}
+
+    def _encode(self, text):
+        tokens = self.text_tokenizer.tokenize(text)
+        ids = self.text_tokenizer.convert_tokens_to_ids(tokens)
+        return ids
+
+    def EncodeAsTokens(self, text, process_fn=None):
+        """convert wordpiece token to Id"""
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+        tokens = self.text_tokenizer.tokenize(processed_text)
+        return Tokenization(tokens, processed_text, text, asIds=False)
+
+    def IdToToken(self, Id, type_token=False):
+        """convert Id to sentencpiece token"""
+        if isinstance(Id, (TypeToken, CommandToken)):
+            return Id.token
+        if type_token:
+            return self.type_id_map[Id].token
+        if Id in self.command_id_map:
+            return self.command_id_map[Id].token
+        return self.text_tokenizer.ids_to_tokens[Id]
+
+    @staticmethod
+    def clean_up_tokenization(out_string: str) -> str:
+        """
+        Clean up a list of simple English tokenization artifacts, such as spaces before punctuation and abbreviated forms.
+
+        Args:
+            out_string (:obj:`str`): The text to clean up.
+
+        Returns:
+            :obj:`str`: The cleaned-up string.
+        """
+        out_string = (
+            out_string.replace(" .", ".")
+                .replace(" ?", "?")
+                .replace(" !", "!")
+                .replace(" ,", ",")
+                .replace(" ' ", "'")
+                .replace(" n't", "n't")
+                .replace(" 'm", "'m")
+                .replace(" 's", "'s")
+                .replace(" 've", "'ve")
+                .replace(" 're", "'re")
+        )
+        return out_string
+
+    def TokenToId(self, token, type_token=False):
+        """convert sentencpiece token to Id"""
+        if isinstance(token, (TypeToken, CommandToken)):
+            return token.Id
+        if type_token:
+            return self.type_token_map[token].Id
+        return self.text_tokenizer.vocab[token]
+
+    def DecodeIds(self, Ids, type_token=False):
+        """converts ids to wordpiece tokens and joins them as a text string"""
+        if type_token:
+            return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
+        if isinstance(Ids, Tokenization):
+            Ids = Ids.tokenization
+        Tokens = []
+        for Id in Ids:
+            if Id in self.command_id_map:
+                Tokens.append(self.command_id_map[Id].token)
+            elif Id in self.text_tokenizer.ids_to_tokens:
+                Tokens.append(self.text_tokenizer.ids_to_tokens[Id])
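+        # merge wordpiece continuations, e.g. ['un', '##break', '##able'] -> 'unbreakable'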
+        new_tokens = []
+        for token in Tokens:
+            if token.startswith('##') and len(new_tokens) > 0:
+                new_tokens[-1] += token[2:]
+            else:
+                new_tokens.append(token)
+        output = ' '.join(new_tokens)
+        output = self.clean_up_tokenization(output)
+        return output
+
+    def DecodeTokens(self, Tokens, type_token=False):
+        """converts wordpiece tokens to a text string"""
+        if type_token:
+            return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens)
+        if isinstance(Tokens, Tokenization):
+            Tokens = Tokens.tokenization
+        return ' '.join(Tokens)
+
+
+class GPT2BPETokenizer(Tokenizer):
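+    """Byte-level BPE tokenizer (GPT-2 or RoBERTa vocabulary) extended with GLM command and type tokens."""
+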
+    def __init__(self, model_type_or_path, cache_dir=None, add_block_symbols=False, add_task_mask=False,
+                 add_decoder_mask=False, **kwargs):
+        self.text_tokenizer = GPT2Tokenizer.from_pretrained(model_type_or_path,
+                                                            cache_dir=cache_dir)
+
+        # disable max len warnings by increasing max len
+        self.text_tokenizer.max_len = int(1e12)
+        self.num_tokens = len(self.text_tokenizer.encoder)
+        self.num_type_tokens = 2
+        if model_type_or_path.startswith('roberta'):
+            self.num_command_tokens = 6
+            self.num_text_tokens = self.num_tokens - 3
+            self._command_tokens = [
+                CommandToken('pad', '<|endoftext|>', self.text_tokenizer.encoder['</s>']),
+                CommandToken('eos', '<|endoftext|>', self.text_tokenizer.encoder['</s>']),
+                CommandToken('sep', '[SEP]', self.text_tokenizer.encoder['<pad>']),
+                CommandToken('ENC', '[CLS]', self.text_tokenizer.encoder['<s>']),
+                CommandToken('MASK', '[MASK]', self.text_tokenizer.encoder['<mask>'], lstrip=True),
+                CommandToken('unk', '[UNK]', self.text_tokenizer.encoder['<unk>'])
+            ]
+            if add_block_symbols:
+                self._command_tokens.extend([
+                    CommandToken('sop', '<|startofpiece|>', self.num_tokens),
+                    CommandToken('eop', '<|endofpiece|>', self.num_tokens + 1)
+                ])
+                self.num_tokens += 2
+                self.num_command_tokens += 2
+        else:
+            self.num_command_tokens = 2
+            self.num_text_tokens = self.num_tokens - 1
+            self._command_tokens = [
+                CommandToken('pad', '<|endoftext|>', self.text_tokenizer.encoder['<|endoftext|>']),
+                CommandToken('eos', '<|endoftext|>', self.text_tokenizer.encoder['<|endoftext|>'])
+            ]
+            if add_block_symbols:
+                self._command_tokens.extend([
+                    CommandToken('sop', '<|startofpiece|>', self.num_tokens),
+                    CommandToken('eop', '<|endofpiece|>', self.num_tokens + 1),
+                    CommandToken('ENC', '[CLS]', self.num_tokens + 2),
+                    CommandToken('MASK', '[MASK]', self.num_tokens + 3, lstrip=True),
+                    CommandToken('sep', '[SEP]', self.num_tokens + 4),
+                    CommandToken('unk', '[UNK]', self.num_tokens + 5)
+                ])
+                self.num_tokens += 6
+                self.num_command_tokens += 6
+        if add_block_symbols:
+            if add_task_mask:
+                self._command_tokens.extend([
+                    CommandToken('gMASK', '[gMASK]', self.num_tokens, lstrip=True),
+                    CommandToken('sMASK', '[sMASK]', self.num_tokens + 1, lstrip=True)
+                ])
+                self.num_tokens += 2
+                self.num_command_tokens += 2
+            if add_decoder_mask:
+                self._command_tokens.extend([
+                    CommandToken('dBLOCK', '[dBLOCK]', self.num_tokens)
+                ])
+                self.num_tokens += 1
+                self.num_command_tokens += 1
+        self.command_name_map = {tok.name: tok for tok in self._command_tokens}
+        self.command_token_map = {tok.token: tok for tok in self._command_tokens}
+        self.command_id_map = {tok.Id: tok for tok in self._command_tokens}
+
+        self.type_tokens = [
+            TypeToken('str0', '<str0>', 0),
+            TypeToken('str1', '<str1>', 1),
+        ]
+        self.type_name_map = {tok.name: tok for tok in self.type_tokens}
+        self.type_token_map = {tok.token: tok for tok in self.type_tokens}
+        self.type_id_map = {tok.Id: tok for tok in self.type_tokens}
+
+        self._tokens = list(self.text_tokenizer.encoder.keys())
+        self._vocab = {k: v for k, v in self.text_tokenizer.encoder.items()}
+
+        self._text_tokens = list(self._tokens)
+        self._text_token_vocab = {k: v for k, v in self.text_tokenizer.encoder.items()}
+
+        self._command_token_tokens = list(self.command_token_map.keys())
+        self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}
+
+        self._token_types = list(self.type_token_map.keys())
+        self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}
+
+        for idx, tok in self.command_id_map.items():
+            self.text_tokenizer.decoder[idx] = tok.token
+
+    def EncodeAsIds(self, text, process_fn=None):
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+
+        def split_on_token(tok_extended: CommandToken, text):
+            result = []
+            tok = tok_extended.token
+            split_text = text.split(tok)
+            for i, sub_text in enumerate(split_text):
+                # CommandToken can control whitespace stripping around them.
+                # We use them for GPT2 and Roberta to have different behavior depending on the special token
+                # Cf. https://github.com/huggingface/transformers/pull/2778
+                # and https://github.com/huggingface/transformers/issues/3788
+                # Strip white spaces on the right
+                if tok_extended.rstrip and i > 0:
+                    # A bit counter-intuitive but we strip the left of the string
+                    # since tok_extended.rstrip means the special token is eating all white spaces on its right
+                    sub_text = sub_text.lstrip()
+                # Strip white spaces on the left
+                if tok_extended.lstrip and i < len(split_text) - 1:
+                    sub_text = sub_text.rstrip()  # Opposite here
+
+                if i == 0 and not sub_text:
+                    result.append(tok)
+                elif i == len(split_text) - 1:
+                    if sub_text:
+                        result.append(sub_text)
+                else:
+                    if sub_text:
+                        result.append(sub_text)
+                    result.append(tok)
+            return result
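+        # e.g. for the MASK command token (lstrip=True), split_on_token would map
+        # "Hello [MASK] world" to ["Hello", "[MASK]", " world"], absorbing the space
+        # to the left of the special token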
+
+        def split_on_tokens(tok_list, text):
+            if not text.strip():
+                return []
+            if not tok_list:
+                return self.text_tokenizer.encode(text)
+
+            tokenized_text = []
+            text_list = [text]
+            for tok in tok_list:
+                tokenized_text = []
+                for sub_text in text_list:
+                    if sub_text not in self._command_token_tokens:
+                        tokenized_text.extend(split_on_token(tok, sub_text))
+                    else:
+                        tokenized_text.append(sub_text)
+                text_list = tokenized_text
+
+            return list(
+                itertools.chain.from_iterable(
+                    (
+                        self.text_tokenizer.encode(token) if token not in self._command_token_tokens else [
+                            self.command_token_map[token].Id] for token in tokenized_text
+                    )
+                )
+            )
+
+        no_split_tokens = self._command_tokens
+        Ids = split_on_tokens(no_split_tokens, processed_text)
+        tokenization = Tokenization(Ids, processed_text, text)
+        tokenization.set_command_tokens(self._command_tokens)
+        return tokenization
+
+    def _encode(self, text):
+        return self.text_tokenizer.encode(text)
+
+    def EncodeAsTokens(self, text, process_fn=None):
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+        tokens = []
+        for token in re.findall(self.text_tokenizer.pat, processed_text):
+            token = ''.join(self.text_tokenizer.byte_encoder[b] for b in token.encode('utf-8'))
+            tokens.extend(bpe_token for bpe_token in self.text_tokenizer.bpe(token).split(' '))
+        tokenization = Tokenization(tokens, processed_text, text, asIds=False)
+        tokenization.set_command_tokens(self._command_tokens)
+        return tokenization
+
+    def DecodeAsTokens(self, Ids):
+        return [self.IdToToken(x) for x in Ids]
+
+    def IdToToken(self, Id, type_token=False):
+        if isinstance(Id, (TypeToken, CommandToken)):
+            return Id.token
+        if type_token:
+            return self.type_id_map[Id].token
+        if Id in self.command_id_map:
+            return self.command_id_map[Id].token
+        return self.text_tokenizer.decoder[Id]
+
+    def TokenToId(self, token, type_token=False):
+        if isinstance(token, (TypeToken, CommandToken)):
+            return token.Id
+        if type_token:
+            return self.type_token_map[token].Id
+        return self.text_tokenizer.encoder[token]
+
+    def DecodeIds(self, Ids, type_token=False):
+        if type_token:
+            return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
+        if isinstance(Ids, Tokenization):
+            Ids = Ids.tokenization
+        return self.text_tokenizer.decode(Ids)
+
+    def DecodeTokens(self, Tokens, type_token=False):
+        if type_token:
+            return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens)
+        if isinstance(Tokens, Tokenization):
+            Tokens = Tokens.tokenization
+        return self.text_tokenizer.decode([self.TokenToId(tok) for tok in Tokens])
+
+
+class ChineseSPTokenizer(Tokenizer):
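+    """Chinese sentencepiece tokenizer (via sp_tokenizer.from_pretrained) extended with GLM command and type tokens."""
+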
+    def __init__(self, add_block_symbols=False, **kwargs):
+        self.text_tokenizer = sp_tokenizer.from_pretrained()
+
+        self.num_command_tokens = 2
+        self.num_text_tokens = self.text_tokenizer.sp.vocab_size()
+        self.num_tokens = self.num_text_tokens + 1
+        self.num_type_tokens = 2
+
+        self._command_tokens = [
+            CommandToken('pad', '<|endoftext|>', self.num_text_tokens),
+            CommandToken('eos', '<|endoftext|>', self.num_text_tokens),
+        ]
+        if add_block_symbols:
+            self._command_tokens.extend([
+                CommandToken('sop', '<|startofpiece|>', self.num_text_tokens + 1),
+                CommandToken('eop', '<|endofpiece|>', self.num_text_tokens + 2)
+            ])
+            self.num_tokens += 2
+            self.num_command_tokens += 2
+        self.command_name_map = {tok.name: tok for tok in self._command_tokens}
+        self.command_token_map = {tok.token: tok for tok in self._command_tokens}
+        self.command_id_map = {tok.Id: tok for tok in self._command_tokens}
+
+        self.type_tokens = [
+            TypeToken('str0', '<str0>', 0),
+            TypeToken('str1', '<str1>', 1),
+        ]
+        self.type_name_map = {tok.name: tok for tok in self.type_tokens}
+        self.type_token_map = {tok.token: tok for tok in self.type_tokens}
+        self.type_id_map = {tok.Id: tok for tok in self.type_tokens}
+
+        # self._tokens = list(self.text_tokenizer.encoder.keys())
+        # self._vocab = {k:v for k,v in self.text_tokenizer.encoder.items()}
+        #
+        # self._text_tokens = list(self._tokens)
+        # self._text_token_vocab = {k:v for k,v in self.text_tokenizer.encoder.items()}
+
+        self._command_token_tokens = list(self.command_token_map.keys())
+        self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}
+
+        self._token_types = list(self.type_token_map.keys())
+        self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}
+
+    def _encode(self, text):
+        ids = self.text_tokenizer.encode(text)
+        return ids
+
+    def EncodeAsTokens(self, text, process_fn=None):
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+        tokens = self.text_tokenizer.tokenize(processed_text)
+        tokenization = Tokenization(tokens, processed_text, text, asIds=False)
+        tokenization.set_command_tokens(self._command_tokens)
+        return tokenization
+        # return Tokenization(tokens, processed_text, text, asIds=False)
+
+    def IdToToken(self, Id, type_token=False):
+        if isinstance(Id, (TypeToken, CommandToken)):
+            return Id.token
+        if type_token:
+            return self.type_id_map[Id].token
+        if Id in self.command_id_map:
+            return self.command_id_map[Id].token
+        elif Id in self.type_id_map:
+            return self.type_id_map[Id].token
+        else:
+            return self.text_tokenizer.convert_id_to_token(Id)
+
+    def TokenToId(self, token, type_token=False):
+        if isinstance(token, (TypeToken, CommandToken)):
+            return token.Id
+        if type_token:
+            return self.type_token_map[token].Id
+        return self.text_tokenizer.convert_token_to_id(token)
+
+    def DecodeIds(self, Ids, type_token=False):
+        if type_token:
+            return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
+        if isinstance(Ids, Tokenization):
+            Ids = Ids.tokenization
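+        # cut the sequence at the first eos and append one '<|endoftext|>' per trimmed id (including the eos itself)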
+        try:
+            first_eos = Ids.index(self.get_command('eos').Id)
+            eos_count = len(Ids) - first_eos
+            Ids = Ids[:first_eos]
+        except ValueError:
+            eos_count = 0
+        return " ".join((self.text_tokenizer.decode(Ids), *(['<|endoftext|>'] * eos_count)))
+
+    def DecodeTokens(self, Tokens, type_token=False):
+        if type_token:
+            return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens)
+        if isinstance(Tokens, Tokenization):
+            Tokens = Tokens.tokenization
+        return self.text_tokenizer.decode([self.TokenToId(tok) for tok in Tokens])
diff --git a/tokenization/text/tokenization_gpt2.py b/tokenization/text/tokenization_gpt2.py
new file mode 100644
index 0000000..318d920
--- /dev/null
+++ b/tokenization/text/tokenization_gpt2.py
@@ -0,0 +1,310 @@
+# coding=utf-8
+# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for OpenAI GPT."""
+from __future__ import (absolute_import, division, print_function,
+                        unicode_literals)
+
+import sys
+import json
+import logging
+import os
+import regex as re
+from io import open
+
+try:
+    from functools import lru_cache
+except ImportError:
+    # Just a dummy decorator to get the checks to run on python2
+    # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
+    def lru_cache():
+        return lambda func: func
+
+from ..file_utils import cached_path
+
+logger = logging.getLogger(__name__)
+
+PRETRAINED_VOCAB_ARCHIVE_MAP = {
+    'gpt2': "pretrained/pytorch_pretrained_bert/gpt2-vocab.json",
+    "roberta": "pretrained/pytorch_pretrained_bert/roberta-vocab.json"
+}
+PRETRAINED_MERGES_ARCHIVE_MAP = {
+    'gpt2': "pretrained/pytorch_pretrained_bert/gpt2-merges.txt",
+    "roberta": "pretrained/pytorch_pretrained_bert/roberta-merges.txt"
+}
+PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
+    'gpt2': 1024,
+}
+VOCAB_NAME = 'vocab.json'
+MERGES_NAME = 'merges.txt'
+SPECIAL_TOKENS_NAME = 'special_tokens.txt'
+
+@lru_cache()
+def bytes_to_unicode():
+    """
+    Returns a mapping between utf-8 bytes and corresponding unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings,
+    while avoiding mapping to whitespace/control characters the bpe code barfs on.
+    """
+    _chr = unichr if sys.version_info[0] == 2 else chr
+    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8+n)
+            n += 1
+    cs = [_chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+def get_pairs(word):
+    """Return set of symbol pairs in a word.
+
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+class GPT2Tokenizer(object):
+    """
+    GPT-2 BPE tokenizer. Peculiarities:
+        - Byte-level BPE
+    """
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
+        """
+        Instantiate a GPT2Tokenizer from a pre-trained model file.
+        Download and cache the pre-trained model file if needed.
+        """
+        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
+            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
+            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
+            special_tokens_file = None
+        else:
+            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
+            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
+            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
+            if not os.path.exists(special_tokens_file):
+                special_tokens_file = None
+            else:
+                logger.info("loading special tokens file {}".format(special_tokens_file))
+        # redirect to the cache, if necessary
+        # try:
+        #     resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
+        #     resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
+        # except EnvironmentError:
+        #     logger.error(
+        #         "Model name '{}' was not found in model name list ({}). "
+        #         "We assumed '{}' was a path or url but couldn't find files {} and {} "
+        #         "at this path or url.".format(
+        #             pretrained_model_name_or_path,
+        #             ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
+        #             pretrained_model_name_or_path,
+        #             vocab_file, merges_file))
+        #     return None
+        # if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
+        #     logger.info("loading vocabulary file {}".format(vocab_file))
+        #     logger.info("loading merges file {}".format(merges_file))
+        # else:
+        #     logger.info("loading vocabulary file {} from cache at {}".format(
+        #         vocab_file, resolved_vocab_file))
+        #     logger.info("loading merges file {} from cache at {}".format(
+        #         merges_file, resolved_merges_file))
+        resolved_vocab_file = vocab_file
+        resolved_merges_file = merges_file
+        logger.info("loading vocabulary file {}".format(vocab_file))
+        logger.info("loading merges file {}".format(merges_file))
+        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
+            # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
+            # than the number of positional embeddings
+            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
+            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
+        # Instantiate tokenizer.
+        if special_tokens_file and 'special_tokens' not in kwargs:
+            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
+        else:
+            special_tokens = kwargs.pop('special_tokens', [])
+        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
+        return tokenizer
+
+    def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
+        self.max_len = max_len if max_len is not None else int(1e12)
+        self.encoder = json.load(open(vocab_file))
+        self.decoder = {v:k for k,v in self.encoder.items()}
+        self.errors = errors # how to handle errors in decoding
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
+        bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
+        bpe_merges = [tuple(merge.split()) for merge in bpe_data]
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        self.cache = {}
+
+        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
+        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+
+        self.special_tokens = {}
+        self.special_tokens_decoder = {}
+        self.set_special_tokens(special_tokens)
+
+    def __len__(self):
+        return len(self.encoder) + len(self.special_tokens)
+
+    def set_special_tokens(self, special_tokens):
+        """ Add a list of additional tokens to the encoder.
+            The additional tokens are indexed starting from the last index of the
+            current vocabulary in the order of the `special_tokens` list.
+        """
+        if not special_tokens:
+            self.special_tokens = {}
+            self.special_tokens_decoder = {}
+            return
+        self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
+        self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
+        logger.info("Special tokens {}".format(self.special_tokens))
+
+    def bpe(self, token):
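+        # repeatedly apply the highest-priority (lowest-rank) merge until no adjacent pair is in bpe_ranks;
+        # e.g. with the standard GPT-2 merges, 'Ġ w o r l d' typically collapses to the single symbol 'Ġworld'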
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token
+
+        while True:
+            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except ValueError:  # `first` does not occur in the remainder of the word
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word)-1 and word[i+1] == second:
+                    new_word.append(first+second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def tokenize(self, text):
+        """ Tokenize a string. """
+        bpe_tokens = []
+        for token in re.findall(self.pat, text):
+            if sys.version_info[0] == 2:
+                token = ''.join(self.byte_encoder[ord(b)] for b in token)
+            else:
+                token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
+
+    def convert_tokens_to_ids(self, tokens):
+        """ Converts a sequence of tokens into ids using the vocab. """
+        ids = []
+        if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
+            if tokens in self.special_tokens:
+                return self.special_tokens[tokens]
+            else:
+                return self.encoder.get(tokens, 0)
+        for token in tokens:
+            if token in self.special_tokens:
+                ids.append(self.special_tokens[token])
+            else:
+                ids.append(self.encoder.get(token, 0))
+        if len(ids) > self.max_len:
+            logger.warning(
+                "Token indices sequence length is longer than the specified maximum "
+                " sequence length for this OpenAI GPT model ({} > {}). Running this"
+                " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
+            )
+        return ids
+
+    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
+        """Converts a sequence of ids in BPE tokens using the vocab."""
+        tokens = []
+        for i in ids:
+            if i in self.special_tokens_decoder:
+                if not skip_special_tokens:
+                    tokens.append(self.special_tokens_decoder[i])
+            else:
+                tokens.append(self.decoder[i])
+        return tokens
+
+    def encode(self, text):
+        return self.convert_tokens_to_ids(self.tokenize(text))
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
+        return text
+
+    def save_vocabulary(self, vocab_path):
+        """Save the tokenizer vocabulary and merge files to a directory."""
+        if not os.path.isdir(vocab_path):
+            logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
+            return
+        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        merge_file = os.path.join(vocab_path, MERGES_NAME)
+        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
+
+        with open(vocab_file, 'w', encoding='utf-8') as f:
+            f.write(json.dumps(self.encoder, ensure_ascii=False))
+
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            writer.write(u'#version: 0.2\n')
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
+                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
+                    index = token_index
+                writer.write(' '.join(bpe_tokens) + u'\n')
+                index += 1
+
+        index = len(self.encoder)
+        with open(special_tokens_file, 'w', encoding='utf-8') as writer:
+            for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
+                                   " Please check that the tokenizer is not corrupted!".format(special_tokens_file))
+                    index = token_index
+                writer.write(token + u'\n')
+                index += 1
+
+        return vocab_file, merge_file, special_tokens_file
diff --git a/tokenization/text/wordpiece.py b/tokenization/text/wordpiece.py
new file mode 100644
index 0000000..e5ce2e2
--- /dev/null
+++ b/tokenization/text/wordpiece.py
@@ -0,0 +1,390 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes. Provided as is from https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/tokenization.py"""
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import collections
+import logging
+import os
+import unicodedata
+from io import open
+
+from ..file_utils import cached_path
+
+logger = logging.getLogger(__name__)
+
+PRETRAINED_VOCAB_ARCHIVE_MAP = {
+    'bert-base-uncased': "pretrained/pytorch_pretrained_bert/bert-base-uncased-vocab.txt",
+    'bert-large-uncased': "pretrained/pytorch_pretrained_bert/bert-large-uncased-vocab.txt",
+    'bert-base-cased': "pretrained/pytorch_pretrained_bert/bert-base-cased-vocab.txt",
+    'bert-large-cased': "pretrained/pytorch_pretrained_bert/bert-large-cased-vocab.txt",
+    'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
+    'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
+    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
+}
+PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
+    'bert-base-uncased': 512,
+    'bert-large-uncased': 512,
+    'bert-base-cased': 512,
+    'bert-large-cased': 512,
+    'bert-base-multilingual-uncased': 512,
+    'bert-base-multilingual-cased': 512,
+    'bert-base-chinese': 512,
+}
+VOCAB_NAME = 'vocab.txt'
+
+
+def load_vocab(vocab_file):
+    """Loads a vocabulary file into a dictionary."""
+    vocab = collections.OrderedDict()
+    index = 0
+    with open(vocab_file, "r", encoding="utf-8") as reader:
+        while True:
+            token = reader.readline()
+            if not token:
+                break
+            token = token.strip()
+            vocab[token] = index
+            index += 1
+    return vocab
+
+
+def whitespace_tokenize(text):
+    """Runs basic whitespace cleaning and splitting on a piece of text."""
+    text = text.strip()
+    if not text:
+        return []
+    tokens = text.split()
+    return tokens
+
+
+class BertTokenizer(object):
+    """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
+
+    def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
+                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
+        """Constructs a BertTokenizer.
+
+        Args:
+          vocab_file: Path to a one-wordpiece-per-line vocabulary file
+          do_lower_case: Whether to lower case the input
+                         Only has an effect when do_basic_tokenize=True.
+          do_basic_tokenize: Whether to do basic tokenization before wordpiece.
+          max_len: An artificial maximum length to truncate tokenized sequences to;
+                         Effective maximum length is always the minimum of this
+                         value (if specified) and the underlying BERT model's
+                         sequence length.
+          never_split: List of tokens which will never be split during tokenization.
+                         Only has an effect when do_basic_tokenize=True.
+        """
+        if not os.path.isfile(vocab_file):
+            raise ValueError(
+                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
+                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
+        self.vocab = load_vocab(vocab_file)
+        self.ids_to_tokens = collections.OrderedDict(
+            [(ids, tok) for tok, ids in self.vocab.items()])
+        self.do_basic_tokenize = do_basic_tokenize
+        if do_basic_tokenize:
+            self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
+                                                  never_split=never_split)
+        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
+        self.max_len = max_len if max_len is not None else int(1e12)
+
+    def tokenize(self, text):
+        if self.do_basic_tokenize:
+            split_tokens = []
+            for token in self.basic_tokenizer.tokenize(text):
+                for sub_token in self.wordpiece_tokenizer.tokenize(token):
+                    split_tokens.append(sub_token)
+        else:
+            split_tokens = self.wordpiece_tokenizer.tokenize(text)
+        return split_tokens
+
+    def convert_tokens_to_ids(self, tokens):
+        """Converts a sequence of tokens into ids using the vocab."""
+        ids = []
+        for token in tokens:
+            ids.append(self.vocab[token])
+        if len(ids) > self.max_len:
+            logger.warning(
+                "Token indices sequence length is longer than the specified maximum "
+                " sequence length for this BERT model ({} > {}). Running this"
+                " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
+            )
+        return ids
+
+    def convert_ids_to_tokens(self, ids):
+        """Converts a sequence of ids in wordpiece tokens using the vocab."""
+        tokens = []
+        for i in ids:
+            tokens.append(self.ids_to_tokens[i])
+        return tokens
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
+        """
+        Instantiate a BertTokenizer from a pre-trained vocabulary file.
+        Download and cache the pre-trained model file if needed.
+        """
+        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
+            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
+        else:
+            vocab_file = pretrained_model_name_or_path
+        if os.path.isdir(vocab_file):
+            vocab_file = os.path.join(vocab_file, VOCAB_NAME)
+        # redirect to the cache, if necessary
+        try:
+            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
+        except EnvironmentError:
+            logger.error(
+                "Model name '{}' was not found in model name list ({}). "
+                "We assumed '{}' was a path or url but couldn't find any file "
+                "associated to this path or url.".format(
+                    pretrained_model_name_or_path,
+                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
+                    vocab_file))
+            return None
+        if resolved_vocab_file == vocab_file:
+            logger.info("loading vocabulary file {}".format(vocab_file))
+        else:
+            logger.info("loading vocabulary file {} from cache at {}".format(
+                vocab_file, resolved_vocab_file))
+        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
+            # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
+            # than the number of positional embeddings
+            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
+            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
+        # Instantiate tokenizer.
+        tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
+        return tokenizer
+
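+# Illustrative usage sketch (not part of the original interface docs): the vocab
+# path below is a placeholder; `from_pretrained` also accepts any key present in
+# PRETRAINED_VOCAB_ARCHIVE_MAP, in which case the vocab file is downloaded and
+# cached via `cached_path`.
+#
+#     tokenizer = BertTokenizer.from_pretrained('/path/to/vocab.txt', do_lower_case=True)
+#     tokens = tokenizer.tokenize("The quick brown fox")
+#     ids = tokenizer.convert_tokens_to_ids(tokens)
+#     assert tokenizer.convert_ids_to_tokens(ids) == tokens  # round-trips for in-vocab tokens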
+
+class BasicTokenizer(object):
+    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
+
+    def __init__(self,
+                 do_lower_case=True,
+                 never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
+        """Constructs a BasicTokenizer.
+
+        Args:
+          do_lower_case: Whether to lower case the input.
+        """
+        self.do_lower_case = do_lower_case
+        self.never_split = never_split
+
+    def tokenize(self, text):
+        """Tokenizes a piece of text."""
+        text = self._clean_text(text)
+        # This was added on November 1st, 2018 for the multilingual and Chinese
+        # models. This is also applied to the English models now, but it doesn't
+        # matter since the English models were not trained on any Chinese data
+        # and generally don't have any Chinese data in them (there are Chinese
+        # characters in the vocabulary because the English Wikipedia does contain
+        # some Chinese words).
+        text = self._tokenize_chinese_chars(text)
+        orig_tokens = whitespace_tokenize(text)
+        split_tokens = []
+        for token in orig_tokens:
+            if self.do_lower_case and token not in self.never_split:
+                token = token.lower()
+                token = self._run_strip_accents(token)
+            split_tokens.extend(self._run_split_on_punc(token))
+
+        output_tokens = whitespace_tokenize(" ".join(split_tokens))
+        return output_tokens
+
+    def _run_strip_accents(self, text):
+        """Strips accents from a piece of text."""
+        text = unicodedata.normalize("NFD", text)
+        output = []
+        for char in text:
+            cat = unicodedata.category(char)
+            if cat == "Mn":
+                continue
+            output.append(char)
+        return "".join(output)
+
+    def _run_split_on_punc(self, text):
+        """Splits punctuation on a piece of text."""
+        if text in self.never_split:
+            return [text]
+        chars = list(text)
+        i = 0
+        start_new_word = True
+        output = []
+        while i < len(chars):
+            char = chars[i]
+            if _is_punctuation(char):
+                output.append([char])
+                start_new_word = True
+            else:
+                if start_new_word:
+                    output.append([])
+                start_new_word = False
+                output[-1].append(char)
+            i += 1
+
+        return ["".join(x) for x in output]
+
+    def _tokenize_chinese_chars(self, text):
+        """Adds whitespace around any CJK character."""
+        output = []
+        for char in text:
+            cp = ord(char)
+            if self._is_chinese_char(cp):
+                output.append(" ")
+                output.append(char)
+                output.append(" ")
+            else:
+                output.append(char)
+        return "".join(output)
+
+    def _is_chinese_char(self, cp):
+        """Checks whether CP is the codepoint of a CJK character."""
+        # This defines a "chinese character" as anything in the CJK Unicode block:
+        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+        #
+        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+        # despite its name. The modern Korean Hangul alphabet is a different block,
+        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+        # space-separated words, so they are not treated specially and handled
+        # like all of the other languages.
+        if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
+                (cp >= 0x3400 and cp <= 0x4DBF) or  #
+                (cp >= 0x20000 and cp <= 0x2A6DF) or  #
+                (cp >= 0x2A700 and cp <= 0x2B73F) or  #
+                (cp >= 0x2B740 and cp <= 0x2B81F) or  #
+                (cp >= 0x2B820 and cp <= 0x2CEAF) or
+                (cp >= 0xF900 and cp <= 0xFAFF) or  #
+                (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
+            return True
+
+        return False
+
+    def _clean_text(self, text):
+        """Performs invalid character removal and whitespace cleanup on text."""
+        output = []
+        for char in text:
+            cp = ord(char)
+            if cp == 0 or cp == 0xfffd or _is_control(char):
+                continue
+            if _is_whitespace(char):
+                output.append(" ")
+            else:
+                output.append(char)
+        return "".join(output)
+
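+# A small, illustrative trace of BasicTokenizer (example input chosen here, not
+# taken from the original code): lower casing, accent stripping and punctuation
+# splitting give
+#
+#     BasicTokenizer(do_lower_case=True).tokenize(u"Héllo, WORLD!")
+#     # -> ["hello", ",", "world", "!"]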
+
+class WordpieceTokenizer(object):
+    """Runs WordPiece tokenization."""
+
+    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
+        self.vocab = vocab
+        self.unk_token = unk_token
+        self.max_input_chars_per_word = max_input_chars_per_word
+
+    def tokenize(self, text):
+        """Tokenizes a piece of text into its word pieces.
+
+        This uses a greedy longest-match-first algorithm to perform tokenization
+        using the given vocabulary.
+
+        For example:
+          input = "unaffable"
+          output = ["un", "##aff", "##able"]
+
+        Args:
+          text: A single token or whitespace separated tokens. This should have
+            already been passed through `BasicTokenizer`.
+
+        Returns:
+          A list of wordpiece tokens.
+        """
+
+        output_tokens = []
+        for token in whitespace_tokenize(text):
+            chars = list(token)
+            if len(chars) > self.max_input_chars_per_word:
+                output_tokens.append(self.unk_token)
+                continue
+
+            is_bad = False
+            start = 0
+            sub_tokens = []
+            while start < len(chars):
+                end = len(chars)
+                cur_substr = None
+                while start < end:
+                    substr = "".join(chars[start:end])
+                    if start > 0:
+                        substr = "##" + substr
+                    if substr in self.vocab:
+                        cur_substr = substr
+                        break
+                    end -= 1
+                if cur_substr is None:
+                    is_bad = True
+                    break
+                sub_tokens.append(cur_substr)
+                start = end
+
+            if is_bad:
+                output_tokens.append(self.unk_token)
+            else:
+                output_tokens.extend(sub_tokens)
+        return output_tokens
+
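+# Worked example of the greedy longest-match-first loop above, assuming a toy
+# vocab that contains "un", "##aff" and "##able" but not "unaffable":
+#
+#     wp = WordpieceTokenizer(vocab={"un": 0, "##aff": 1, "##able": 2})
+#     wp.tokenize("unaffable")  # -> ["un", "##aff", "##able"]
+#
+# If at some position no substring starting there (prefixed with the "##"
+# continuation marker) is in the vocab, the whole token is replaced by
+# unk_token ("[UNK]").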
+
+def _is_whitespace(char):
+    """Checks whether `char` is a whitespace character."""
+    # \t, \n, and \r are technically control characters but we treat them
+    # as whitespace since they are generally considered as such.
+    if char == " " or char == "\t" or char == "\n" or char == "\r":
+        return True
+    cat = unicodedata.category(char)
+    if cat == "Zs":
+        return True
+    return False
+
+
+def _is_control(char):
+    """Checks whether `char` is a control character."""
+    # These are technically control characters but we count them as whitespace
+    # characters.
+    if char == "\t" or char == "\n" or char == "\r":
+        return False
+    cat = unicodedata.category(char)
+    if cat.startswith("C"):
+        return True
+    return False
+
+
+def _is_punctuation(char):
+    """Checks whether `char` is a punctuation character."""
+    cp = ord(char)
+    # We treat all non-letter/number ASCII as punctuation.
+    # Characters such as "^", "$", and "`" are not in the Unicode
+    # Punctuation class but we treat them as punctuation anyways, for
+    # consistency.
+    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
+            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
+        return True
+    cat = unicodedata.category(char)
+    if cat.startswith("P"):
+        return True
+    return False
-- 
GitLab