From ba522b5e0b4774e3818f4927318a9670ccd0369e Mon Sep 17 00:00:00 2001
From: Zhengxiao Du <zx-du20@mails.tsinghua.edu.cn>
Date: Wed, 20 Oct 2021 19:36:39 +0800
Subject: [PATCH] Import GLM inference code

---
 arguments.py                           |   36 +-
 config/model_glm_roberta_large.sh      |   10 +
 inference_glm.py                       |  299 ++
 model/glm_model.py                     |    3 +-
 move_weights_glm.py                    |   38 +
 mpu/transformer.py                     |    4 +-
 scripts/generate_glm.sh                |   29 +
 tokenization/__init__.py               |   22 +-
 tokenization/file_utils.py             |  250 +++
 tokenization/text/__init__.py          |    1 +
 tokenization/text/sp_tokenizer.py      |  150 ++
 tokenization/text/tokenization.py      | 1254 ++++++++++++++++
 tokenization/text/tokenization_gpt2.py |  310 +++
 tokenization/text/wordpiece.py         |  390 ++++
 14 files changed, 2787 insertions(+), 9 deletions(-)
 create mode 100644 config/model_glm_roberta_large.sh
 create mode 100644 inference_glm.py
 create mode 100644 move_weights_glm.py
 create mode 100644 scripts/generate_glm.sh
 create mode 100644 tokenization/file_utils.py
 create mode 100644 tokenization/text/__init__.py
 create mode 100644 tokenization/text/sp_tokenizer.py
 create mode 100644 tokenization/text/tokenization.py
 create mode 100644 tokenization/text/tokenization_gpt2.py
 create mode 100644 tokenization/text/wordpiece.py

diff --git a/arguments.py b/arguments.py
index 026e406..dbdeea6 100755
--- a/arguments.py
+++ b/arguments.py
@@ -157,6 +157,7 @@ def add_text_generate_args(parser):
     group.add_argument("--temperature", type=float, default=1.0)
     group.add_argument("--top_p", type=float, default=0.0)
     group.add_argument("--top_k", type=int, default=0)
+    group.add_argument("--num-beams", type=int, default=1)
     group.add_argument("--out-seq-length", type=int, default=256)
     group.add_argument('--input-source', type=str, default='interactive',
                        help='what input mode to use, interactive or path')
@@ -214,12 +215,44 @@ def add_tokenization_args(parser):
     group = parser.add_argument_group('Tokenization', 'tokenization configurations')
     group.add_argument('--tokenizer-type', type=str, default='fake', help='type name of tokenizer')
-
+    group.add_argument('--tokenizer-model-type', type=str,
+                       default=None,
+                       help="Model type to use for sentencepiece tokenization \
+                       (one of ['bpe', 'char', 'unigram', 'word']) or \
+                       bert vocab to use for BertWordPieceTokenizer (one of \
+                       ['bert-large-uncased', 'bert-large-cased', etc.])")
     group.add_argument('--img-tokenizer-path', type=str, default=None,
                        help='The checkpoint file path of image tokenizer.')
     return parser
 
+
+def add_glm_args(parser):
+    """Arguments for GLM"""
+    group = parser.add_argument_group('GLM', 'GLM Configurations')
+    group.add_argument('--block-lm', action='store_true', help="whether use the BlockLM pre-training")
+    group.add_argument('--masked-lm', action='store_true', help='whether to use the mlm objective')
+    group.add_argument('--bert-prob', type=float, default=0.5)
+    group.add_argument('--gpt-infill-prob', type=float, default=0.5)
+    group.add_argument('--gpt-min-ratio', type=float, default=0.5)
+    group.add_argument('--gap-sentence-prob', type=float, default=0.0)
+    group.add_argument('--gap-sentence-ratio', type=float, default=0.15)
+    group.add_argument('--avg-block-length', type=int, default=3)
+    group.add_argument('--short-seq-prob', type=float, default=0.0)
+    group.add_argument('--single-span-prob', type=float, default=0.0)
+    group.add_argument('--task-mask', action='store_true', help="Use different mask for generation and blank filling")
+    group.add_argument('--no-shuffle-block', action='store_true', help="not shuffle the blocks
when filling the blank") + group.add_argument('--no-block-position', action='store_true', + help='Use (rough) absolute positions instead of block positions') + group.add_argument('--sentinel-token', action='store_true', + help="Use sentinel (mask) tokens to replace 2d position encoding") + group.add_argument('--block-mask-prob', type=float, default=0.0) + group.add_argument('--context-mask-ratio', type=float, default=0.0) + group.add_argument('--random-position', action='store_true', + help="Use random start position to cover all the position embeddings") + group.add_argument('--cloze-eval', action='store_true', help='Evaluation dataset with cloze task') + return parser + + def get_args(args_list=None): """Parse all the args.""" @@ -232,6 +265,7 @@ def get_args(args_list=None): parser = add_tokenization_args(parser) parser = add_text_generate_args(parser) parser = add_generation_api_args(parser) + parser = add_glm_args(parser) # Include DeepSpeed configuration arguments parser = deepspeed.add_config_arguments(parser) diff --git a/config/model_glm_roberta_large.sh b/config/model_glm_roberta_large.sh new file mode 100644 index 0000000..6fb7216 --- /dev/null +++ b/config/model_glm_roberta_large.sh @@ -0,0 +1,10 @@ +MODEL_TYPE="blocklm-roberta-large" +MODEL_ARGS="--block-lm \ + --cloze-eval \ + --num-layers 24 \ + --hidden-size 1024 \ + --num-attention-heads 16 \ + --max-sequence-length 513 \ + --tokenizer-model-type roberta \ + --tokenizer-type GPT2BPETokenizer \ + --load ${CHECKPOINT_PATH}/blocklm-roberta-large-blank" \ No newline at end of file diff --git a/inference_glm.py b/inference_glm.py new file mode 100644 index 0000000..b88c2a7 --- /dev/null +++ b/inference_glm.py @@ -0,0 +1,299 @@ +# -*- encoding: utf-8 -*- +''' +@File : inference_cogview.py +@Time : 2021/10/09 19:41:58 +@Author : Ming Ding +@Contact : dm18@mail.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import random +import time +from datetime import datetime +import torch +import torch.nn.functional as F + +import mpu +from arguments import get_args +from model.glm_model import GLMModel +from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer +from tokenization import get_tokenizer +from generation.sampling_strategies import BaseStrategy +from generation.autoregressive_sampling import filling_sequence +from generation.utils import timed_name, save_multiple_images, generate_continually + + +def read_context(tokenizer, args, output=None): + terminate_runs, skip_run = 0, 0 + if mpu.get_model_parallel_rank() == 0: + while True: + raw_text = input("\nContext prompt (stop to exit) >>> ") + if not raw_text: + print('Prompt should not be empty!') + continue + if raw_text == "stop": + terminate_runs = 1 + break + generation_mask = '[gMASK]' if args.task_mask else '[MASK]' + if args.block_lm and 'MASK]' not in raw_text: + raw_text += ' ' + generation_mask + if output is not None: + output.write(raw_text) + context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization + if args.block_lm: + context_tokens = [tokenizer.get_command('ENC').Id] + context_tokens + if not raw_text.endswith('MASK]'): + context_tokens = context_tokens + [tokenizer.get_command('eos').Id] + context_length = len(context_tokens) + + if context_length >= args.max_sequence_length: + print("\nContext length", context_length, + "\nPlease give smaller context than the window length!") + continue + break + else: + context_length = 0 + + terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs]) + 
torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group()) + terminate_runs = terminate_runs_tensor[0].item() + + if terminate_runs == 1: + return terminate_runs, None, None, None + + context_length_tensor = torch.cuda.LongTensor([context_length]) + + torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group()) + context_length = context_length_tensor[0].item() + if mpu.get_model_parallel_rank() == 0: + context_tokens_tensor = torch.cuda.LongTensor(context_tokens) + else: + context_tokens_tensor = torch.cuda.LongTensor([0] * context_length) + torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group()) + if mpu.get_model_parallel_rank() != 0: + raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist()) + return terminate_runs, raw_text, context_tokens_tensor, context_length + + +def get_batch(context_tokens, device, args): + tokens = context_tokens + tokens = tokens.view(1, -1).contiguous() + tokens = tokens.to(device) + + # Get the masks and postition ids. + if args.block_lm: + attention_mask = torch.ones(1, 1, tokens.size(1), tokens.size(1), device=device, dtype=torch.long) + if args.fp16: + attention_mask = attention_mask.half() + position_ids = torch.arange(tokens.size(1), device=device, dtype=torch.long) + if not args.no_block_position: + block_position_ids = torch.zeros(tokens.size(1), device=device, dtype=torch.long) + position_ids = torch.stack((position_ids, block_position_ids), dim=0) + position_ids = position_ids.unsqueeze(0) + else: + raise NotImplementedError + + return tokens, attention_mask, position_ids + + +def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): + # This function has been mostly taken from huggingface conversational ai code at + # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313 + + if top_k > 0: + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p > 0.0: + # convert to 1D + logits = logits.view(logits.size()[1]).contiguous() + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + indices_to_remove = sorted_indices[sorted_indices_to_remove] + logits[indices_to_remove] = filter_value + # going back to 2D + logits = logits.view(1, -1).contiguous() + + return logits + + +def sample_sequence(model, tokenizer, context_tokens, context_length, args, device, mems=None, end_tokens=None): + if not args.block_lm: + context_tokens, attention_mask, position_ids = get_batch(context_tokens, device, args) + tokens = torch.empty((args.num_beams, 0), device=context_tokens.device, dtype=torch.long) + else: + tokens = context_tokens.new_full((1, 1), tokenizer.get_command('sop').Id) + counter = 0 + if mems is None: + mems = [] + if end_tokens is None: + end_tokens = [args.eod_token] + if args.num_beams > 1: + beam_scorer = 
BeamSearchScorer( + batch_size=1, + max_length=args.out_seq_length, + num_beams=args.num_beams, + device=context_tokens.device, + length_penalty=args.length_penalty, + do_early_stopping=False, + ) + beam_scores = torch.zeros(1, dtype=torch.float, device=context_tokens.device) + last_beam_num = 1 + while counter < args.out_seq_length: + if counter == 0 and not args.block_lm: + next_token_logits, *mems = model(context_tokens, position_ids, attention_mask, *mems) + else: + if args.block_lm: + if args.no_block_position: + position_ids = context_tokens.new_full((last_beam_num, 1), context_length + counter) + else: + position_ids = context_tokens.new_ones(last_beam_num, 2, 1) + position_ids[:, 0] = context_length + position_ids[:, 1] = counter + 1 + attention_mask = context_tokens.new_zeros([1], device=context_tokens.device, dtype=torch.long) + else: + position_ids = context_tokens.new_ones((last_beam_num, 1)) * (context_length + counter - 1) + attention_mask = context_tokens.new_ones(last_beam_num, 1, 1, args.mem_length + 1, + device=context_tokens.device, dtype=torch.float) + last_token = tokens[:, -1:] + next_token_logits, *mems = model(last_token, position_ids, attention_mask, *mems) + next_token_logits = next_token_logits[:, -1] + if args.num_beams > 1: + next_token_scores = F.log_softmax(next_token_logits, dim=-1) + next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view(1, last_beam_num * vocab_size) + + probs = F.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=2 * args.num_beams) + next_token_scores = torch.gather(next_token_scores, -1, next_tokens) + next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1) + next_tokens = torch.gather(next_tokens, -1, _indices) + + next_indices = next_tokens // vocab_size + next_tokens = next_tokens % vocab_size + # stateless + tokens = tokens.expand((args.num_beams, -1)) + beam_outputs = beam_scorer.process( + tokens, + next_token_scores, + next_tokens, + next_indices, + eos_token_id=end_tokens, + mems=mems + ) + beam_scores = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + beam_next_tokens = beam_next_tokens.unsqueeze(-1) + tokens = torch.cat([tokens[beam_idx, :], beam_next_tokens], dim=-1) + mems = [mem[beam_idx] for mem in mems] if mems else None + if beam_scorer.is_done: + break + last_beam_num = args.num_beams + else: + next_token_logits /= args.temperature + next_token_logits = top_k_logits(next_token_logits, top_k=args.top_k, top_p=args.top_p) + log_probs = F.softmax(next_token_logits, dim=-1) + prev = torch.multinomial(log_probs, num_samples=1)[0] + is_end = prev.item() in end_tokens + if is_end: + break + prev = prev.view(1, 1) + tokens = prev if tokens is None else torch.cat((tokens, prev), dim=1) + counter += 1 + if not args.block_lm and mpu.get_model_parallel_rank() == 0 and counter % 16 == 0: + output_tokens_list = tokens.view(-1).contiguous() + decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist()) + if mpu.get_model_parallel_rank() == 0 and (counter % 128 == 0 or is_end): + os.system('clear') + trim_decode_tokens = decode_tokens + print(trim_decode_tokens, flush=True) + if args.num_beams > 1: + tokens, mems = beam_scorer.finalize(tokens, beam_scores, next_tokens, next_indices, eos_token_id=args.eod_token, + mems=mems) + return torch.cat((context_tokens, 
tokens), dim=1), mems + + +def generate_samples(model, tokenizer, args, device): + model.eval() + output_path = "./samples" + if not os.path.exists(output_path): + os.makedirs(output_path) + output_path = os.path.join(output_path, f"sample-{datetime.now().strftime('%m-%d-%H-%M')}.txt") + with torch.no_grad(), open(output_path, "w") as output: + while True: + torch.distributed.barrier(group=mpu.get_model_parallel_group()) + + terminate_runs, raw_text, context_tokens_tensor, context_length = read_context(tokenizer, args, output) + if terminate_runs == 1: + return + start_time = time.time() + if args.block_lm: + mems = [] + tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, device, args) + mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK'] + mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens] + end_tokens = [tokenizer.get_command('eop').Id, args.eod_token] + mask_positions = [] + for token in mask_tokens: + mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist() + mask_positions.sort() + if args.no_block_position: + for mask_position in mask_positions: + position_ids[0, mask_position + 1:] += args.out_seq_length + _, *mems = model(tokens, position_ids, attention_mask, *mems) + for mask_position in mask_positions: + if args.no_block_position: + position = position_ids[0, mask_position].item() + else: + position = mask_position + tokens, mems = sample_sequence(model, tokenizer, tokens, position, + args, device, mems=mems, end_tokens=end_tokens) + else: + tokens, _ = sample_sequence(model, tokenizer, context_tokens_tensor, context_length, args, device) + output_tokens_list = tokens.view(-1).contiguous() + if mpu.get_model_parallel_rank() == 0: + os.system('clear') + print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True) + print("\nContext:", raw_text, flush=True) + decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist()) + trim_decode_tokens = decode_tokens + print("\nGLM:", trim_decode_tokens, flush=True) + output.write(trim_decode_tokens + "\n") + + torch.distributed.barrier(group=mpu.get_model_parallel_group()) + + +def main(args): + initialize_distributed(args) + tokenizer = prepare_tokenizer(args) + args.eod_token = tokenizer.get_command('eos').Id + # build model + model = GLMModel(args) + if args.fp16: + model = model.half() + model = model.to(args.device) + load_checkpoint(model, args) + set_random_seed(args.seed) + model.eval() + generate_samples(model, tokenizer, args, torch.cuda.current_device()) + + +if __name__ == "__main__": + args = get_args() + + with torch.no_grad(): + main(args) diff --git a/model/glm_model.py b/model/glm_model.py index 0edf9b9..96502d1 100644 --- a/model/glm_model.py +++ b/model/glm_model.py @@ -2,9 +2,10 @@ import torch import torch.nn as nn from .base_model import BaseModel +from .cached_autoregressive_model import CachedAutoregressiveModel -class GLMModel(BaseModel): +class GLMModel(CachedAutoregressiveModel): def __init__(self, args, transformer=None): super().__init__(args, transformer=transformer) self.transformer.block_position_embeddings = torch.nn.Embedding(args.max_sequence_length, args.hidden_size) diff --git a/move_weights_glm.py b/move_weights_glm.py new file mode 100644 index 0000000..6d3755d --- /dev/null +++ b/move_weights_glm.py @@ -0,0 +1,38 @@ +import sys +import os +import torch +import copy + +checkpoint = sys.argv[1] +target_path = sys.argv[2] + +assert os.path.isdir(checkpoint) +iteration_file = 
os.path.join(checkpoint, 'latest_checkpointed_iteration.txt') +if os.path.exists(iteration_file): + with open(iteration_file) as fin: + iteration = int(fin.read().strip()) + checkpoint = os.path.join(checkpoint, str(iteration)) +else: + iteration = None + +os.makedirs(target_path, exist_ok=True) +if iteration is not None: + with open(os.path.join(target_path, "latest"), "w") as output: + output.write(str(iteration)) + target_path = os.path.join(target_path, str(iteration)) + os.makedirs(target_path, exist_ok=True) + + +filenames = os.listdir(checkpoint) +filenames = [filename for filename in filenames if filename.startswith("mp_rank_")] +filenames = sorted(filenames, + key=lambda x: int(x.split('_')[2])) +filenames = [os.path.join(checkpoint, x) for x in filenames] + +for filename in filenames: + data = torch.load(filename) + state_dict = data['module'] + state_dict['transformer.word_embeddings.weight'] = state_dict['word_embeddings.weight'] + del state_dict['word_embeddings.weight'] + # print(f"Target path: {os.path.join(target_path, os.path.basename(filename))}") + torch.save(data, os.path.join(target_path, os.path.basename(filename))) diff --git a/mpu/transformer.py b/mpu/transformer.py index e04485a..9751672 100755 --- a/mpu/transformer.py +++ b/mpu/transformer.py @@ -320,8 +320,6 @@ class BaseTransformer(torch.nn.Module): # sanity check assert len(input_ids.shape) == 2 batch_size, query_length = input_ids.shape - assert len(position_ids.shape) <= 2 - assert position_ids.shape[-1] == query_length assert len(attention_mask.shape) == 2 or \ len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1 @@ -334,6 +332,8 @@ class BaseTransformer(torch.nn.Module): if 'position_embedding_forward' in self.hooks: position_embeddings = self.hooks['position_embedding_forward'](position_ids, *other_tensors) else: + assert len(position_ids.shape) <= 2 + assert position_ids.shape[-1] == query_length position_embeddings = self.position_embeddings(position_ids) hidden_states = hidden_states + position_embeddings hidden_states = self.embedding_dropout(hidden_states) diff --git a/scripts/generate_glm.sh b/scripts/generate_glm.sh new file mode 100644 index 0000000..af5a59a --- /dev/null +++ b/scripts/generate_glm.sh @@ -0,0 +1,29 @@ +#!/bin/bash +CHECKPOINT_PATH=./checkpoints + +source $1 + +MPSIZE=1 +MAXSEQLEN=512 +MASTER_PORT=$(shuf -n 1 -i 10000-65535) + +#SAMPLING ARGS +TEMP=0.9 +#If TOPK/TOPP are 0 it defaults to greedy sampling, top-k will also override top-p +TOPK=40 +TOPP=0 + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) + +config_json="$script_dir/ds_config.json" + +python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTER_PORT inference_glm.py \ + --mode inference \ + --model-parallel-size $MPSIZE \ + $MODEL_ARGS \ + --fp16 \ + --out-seq-length $MAXSEQLEN \ + --temperature $TEMP \ + --top_k $TOPK \ + --top_p $TOPP diff --git a/tokenization/__init__.py b/tokenization/__init__.py index a6b481d..f465ec4 100644 --- a/tokenization/__init__.py +++ b/tokenization/__init__.py @@ -13,24 +13,36 @@ import math import random import torch + def get_tokenizer(args=None): + kwargs = {"add_block_symbols": args.block_lm, "add_task_mask": args.task_mask, + "add_decoder_mask": args.block_mask_prob > 0.0 or args.context_mask_ratio > 0.0} if not hasattr(get_tokenizer, 'tokenizer'): # the first time to load the tokenizer if args.tokenizer_type == 'cogview': from .cogview import UnifiedTokenizer get_tokenizer.tokenizer = UnifiedTokenizer( - args.img_tokenizer_path, + 
args.img_tokenizer_path, device=torch.cuda.current_device() - ) - elif args.tokenizer_type == 'glm': - pass + ) + elif args.tokenizer_type == "BertWordPieceTokenizer": + from .text import BertWordPieceTokenizer + get_tokenizer.tokenizer = BertWordPieceTokenizer(args.tokenizer_model_type, **kwargs) + elif args.tokenizer_type == "GPT2BPETokenizer": + from .text import GPT2BPETokenizer + get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs) + elif args.tokenizer_type == "ChineseSPTokenizer": + from .text import ChineseSPTokenizer + get_tokenizer.tokenizer = ChineseSPTokenizer(**kwargs) else: assert args.vocab_size > 0 get_tokenizer.tokenizer = FakeTokenizer(args.vocab_size) return get_tokenizer.tokenizer + class FakeTokenizer(object): def __init__(self, num_tokens): self.num_tokens = num_tokens + def __len__(self): - return self.num_tokens \ No newline at end of file + return self.num_tokens diff --git a/tokenization/file_utils.py b/tokenization/file_utils.py new file mode 100644 index 0000000..e4be142 --- /dev/null +++ b/tokenization/file_utils.py @@ -0,0 +1,250 @@ +# This file is provided as is from: +# https://github.com/huggingface/pytorch-pretrained-BERT +# Please refer to their repository for copyright. + +""" +Utilities for working with the local dataset cache. +This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp +Copyright by the AllenNLP authors. +""" +from __future__ import (absolute_import, division, print_function, unicode_literals) + +import json +import logging +import os +import shutil +import tempfile +from functools import wraps +from hashlib import sha256 +import sys +from io import open + +import boto3 +import requests +from botocore.exceptions import ClientError +from tqdm import tqdm + +from urllib.parse import urlparse + +try: + from pathlib import Path + PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', + Path.home() / '.pytorch_pretrained_bert')) +except (AttributeError, ImportError): + PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', + os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +def url_to_filename(url, etag=None): + """ + Convert `url` into a hashed filename in a repeatable way. + If `etag` is specified, append its hash to the url's, delimited + by a period. + """ + url_bytes = url.encode('utf-8') + url_hash = sha256(url_bytes) + filename = url_hash.hexdigest() + + if etag: + etag_bytes = etag.encode('utf-8') + etag_hash = sha256(etag_bytes) + filename += '.' + etag_hash.hexdigest() + + return filename + + +def filename_to_url(filename, cache_dir=None): + """ + Return the url and etag (which may be ``None``) stored for `filename`. + Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 
+ """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + cache_path = os.path.join(cache_dir, filename) + if not os.path.exists(cache_path): + raise EnvironmentError("file {} not found".format(cache_path)) + + meta_path = cache_path + '.json' + if not os.path.exists(meta_path): + raise EnvironmentError("file {} not found".format(meta_path)) + + with open(meta_path, encoding="utf-8") as meta_file: + metadata = json.load(meta_file) + url = metadata['url'] + etag = metadata['etag'] + + return url, etag + + +def cached_path(url_or_filename, cache_dir=None): + """ + Given something that might be a URL (or might be a local path), + determine which. If it's a URL, download the file and cache it, and + return the path to the cached file. If it's already a local path, + make sure the file exists and then return the path. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): + url_or_filename = str(url_or_filename) + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + parsed = urlparse(url_or_filename) + + if parsed.scheme in ('http', 'https', 's3'): + # URL, so get it from the cache (downloading if necessary) + return get_from_cache(url_or_filename, cache_dir) + elif os.path.exists(url_or_filename): + # File, and it exists. + return url_or_filename + elif parsed.scheme == '': + # File, but it doesn't exist. + raise EnvironmentError("file {} not found".format(url_or_filename)) + else: + # Something unknown + raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) + + +def split_s3_path(url): + """Split a full s3 path into the bucket name and path.""" + parsed = urlparse(url) + if not parsed.netloc or not parsed.path: + raise ValueError("bad s3 path {}".format(url)) + bucket_name = parsed.netloc + s3_path = parsed.path + # Remove '/' at beginning of path. + if s3_path.startswith("/"): + s3_path = s3_path[1:] + return bucket_name, s3_path + + +def s3_request(func): + """ + Wrapper function for s3 requests in order to create more helpful error + messages. + """ + + @wraps(func) + def wrapper(url, *args, **kwargs): + try: + return func(url, *args, **kwargs) + except ClientError as exc: + if int(exc.response["Error"]["Code"]) == 404: + raise EnvironmentError("file {} not found".format(url)) + else: + raise + + return wrapper + + +@s3_request +def s3_etag(url): + """Check ETag on S3 object.""" + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_object = s3_resource.Object(bucket_name, s3_path) + return s3_object.e_tag + + +@s3_request +def s3_get(url, temp_file): + """Pull a file directly from S3.""" + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) + + +def http_get(url, temp_file): + req = requests.get(url, stream=True) + content_length = req.headers.get('Content-Length') + total = int(content_length) if content_length is not None else None + progress = tqdm(unit="B", total=total) + for chunk in req.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + + +def get_from_cache(url, cache_dir=None): + """ + Given a URL, look for the corresponding dataset in the local cache. 
+ If it's not there, download it. Then return the path to the cached file. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + if not os.path.exists(cache_dir): + os.makedirs(cache_dir) + + # Get eTag to add to filename, if it exists. + if url.startswith("s3://"): + etag = s3_etag(url) + else: + response = requests.head(url, allow_redirects=True) + if response.status_code != 200: + raise IOError("HEAD request failed for url {} with status code {}" + .format(url, response.status_code)) + etag = response.headers.get("ETag") + + filename = url_to_filename(url, etag) + + # get cache path to put the file + cache_path = os.path.join(cache_dir, filename) + + if not os.path.exists(cache_path): + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. + with tempfile.NamedTemporaryFile() as temp_file: + logger.info("%s not found in cache, downloading to %s", url, temp_file.name) + + # GET file object + if url.startswith("s3://"): + s3_get(url, temp_file) + else: + http_get(url, temp_file) + + # we are copying the file before closing it, so flush to avoid truncation + temp_file.flush() + # shutil.copyfileobj() starts at the current position, so go to the start + temp_file.seek(0) + + logger.info("copying %s to cache at %s", temp_file.name, cache_path) + with open(cache_path, 'wb') as cache_file: + shutil.copyfileobj(temp_file, cache_file) + + logger.info("creating metadata file for %s", cache_path) + meta = {'url': url, 'etag': etag} + meta_path = cache_path + '.json' + with open(meta_path, 'w', encoding="utf-8") as meta_file: + json.dump(meta, meta_file) + + logger.info("removing temp file %s", temp_file.name) + + return cache_path + + +def read_set_from_file(filename): + ''' + Extract a de-duped collection (set) of text from a file. + Expected file format is one item per line. + ''' + collection = set() + with open(filename, 'r', encoding='utf-8') as file_: + for line in file_: + collection.add(line.rstrip()) + return collection + + +def get_file_extension(path, dot=True, lower=True): + ext = os.path.splitext(path)[1] + ext = ext if dot else ext[1:] + return ext.lower() if lower else ext diff --git a/tokenization/text/__init__.py b/tokenization/text/__init__.py new file mode 100644 index 0000000..d191be6 --- /dev/null +++ b/tokenization/text/__init__.py @@ -0,0 +1 @@ +from .tokenization import BertWordPieceTokenizer, GPT2BPETokenizer, ChineseSPTokenizer diff --git a/tokenization/text/sp_tokenizer.py b/tokenization/text/sp_tokenizer.py new file mode 100644 index 0000000..7b6430e --- /dev/null +++ b/tokenization/text/sp_tokenizer.py @@ -0,0 +1,150 @@ +""" +from https://github.com/openai/gpt-2/, changed for chinese +""" +import json +import os +import sentencepiece as spm + +""" +SentencePiece is an unsupervised text tokenizer and detokenizer mainly for Neural Network-based text generation +systems where the vocabulary size is predetermined prior to the neural model training. SentencePiece implements +subword units (e.g., byte-pair-encoding (BPE) [Sennrich et al.]) and unigram language model [Kudo.]) with the +extension of direct training from raw sentences. SentencePiece allows us to make a purely end-to-end +system that does not depend on language-specific pre/postprocessing. 
+https://github.com/google/sentencepiece + +pip install sentencepiece + +or git clone https://github.com/google/sentencepiece.git +python setup.py install + +""" +PRETRAINED_MODEL_FILE = "pretrained/chinese_sentencepiece/cog-pretrain.model" + + +def get_pairs(word): + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class Encoder: + def __init__(self, encoder, bpe_merges): + self.encoder = encoder + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + self.max_len = 0 + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + return [self.encoder.get(token, 1) for token in self.tokenize(text)] + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + return text + + def tokenize(self, text): + bpe_tokens = [] + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(text).split(' ')) + return bpe_tokens + + def convert_tokens_to_ids(self, tokens): + return [self.encoder.get(token, 1) for token in tokens] + + +class Encoder_SP: + def __init__(self, model_path): + self.sp = spm.SentencePieceProcessor() + self.sp.Load(model_path) + + def encode(self, text): + """ + text="...." + """ + return self.sp.EncodeAsIds(text) + + def decode(self, tokens): + """ + tokens=[x1,x2,...] + """ + text = [int(token) for token in tokens] + # print(text) + return self.sp.DecodeIds(text) + + def tokenize(self, text): + return self.sp.EncodeAsPieces(text) + + def convert_tokens_to_ids(self, tokens): + return [self.sp.PieceToId(token) for token in tokens] + + def convert_token_to_id(self, token): + return self.sp.PieceToId(token) + + def convert_id_to_token(self, idx): + return self.sp.IdToPiece(idx) + + +def get_encoder(encoder_file, bpe_file): + # 以下是为了同一个函数入兼容sentencepiece + filepath, filename = os.path.split(encoder_file) + shotname, extension = os.path.splitext(filename) + + if (".model" == extension) and (bpe_file == ""): + return Encoder_SP(encoder_file) + else: + with open(encoder_file, 'r', encoding="utf-8") as f: + encoder = json.load(f) + with open(bpe_file, 'r', encoding="utf-8") as f: + bpe_data = f.read() + bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] + return Encoder( + encoder=encoder, + bpe_merges=bpe_merges, + ) + + +def from_pretrained(): + return get_encoder(PRETRAINED_MODEL_FILE, "") diff --git a/tokenization/text/tokenization.py b/tokenization/text/tokenization.py new file mode 100644 index 0000000..51d2009 --- /dev/null +++ b/tokenization/text/tokenization.py @@ -0,0 +1,1254 @@ +# coding=utf-8 +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for using and training tokenizers (char, wordpiece, sentencepiece)""" +from collections import namedtuple +import random +import os +import csv +import torch +import itertools + +import nltk +from nltk import tokenize as nltk_tokenize +import sentencepiece as spm + +from .wordpiece import BertTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP + +from .tokenization_gpt2 import GPT2Tokenizer +from . import sp_tokenizer +import regex as re + + +def make_tokenizer(tokenizer_type, corpus, model_path=None, vocab_size=None, model_type=None, pad_token=0, + character_coverage=1.0, command_tokens=None, type_tokens=None, **kwargs): + """ + Helper function to instantiate a tokenizer given common combinations of options. + """ + tokenizer_class = tokenizer_type + if isinstance(tokenizer_class, str): + tokenizer_class = eval(tokenizer_class) + if tokenizer_class is BertWordPieceTokenizer: + return BertWordPieceTokenizer(model_type, **kwargs) + elif tokenizer_class is GPT2BPETokenizer: + if model_type is None: + model_type = 'gpt2' + return GPT2BPETokenizer(model_type, **kwargs) + elif tokenizer_class is ChineseSPTokenizer: + return ChineseSPTokenizer(**kwargs) + text_tokenizer = tokenizer_class(corpus=corpus, vocab_size=vocab_size, model_path=model_path, model_type=model_type, + pad_token=pad_token, character_coverage=character_coverage) + return Tokenizer(text_tokenizer, command_tokens, type_tokens) + + +class Tokenization(object): + """ + Tokenization object to hold tokenization, (processed text),and original + text. Can hold tokenization as Ids or tokens. + + It also holds command tokens (pad, unk, etc.) for the tokenization. + This allows functions to pad/operate on tokenizations without having + access to the full tokenizer, just the tokenization. + + Several standard array operations are implemented (insert, append, extend). 
+ """ + + def __init__(self, tokenization, text=None, original_text=None, command_tokens=None, asIds=True): + self.tokenization = tokenization + self.text = text + if self.text is None: + self.text = self.tokenization + self.original_text = original_text + if self.original_text is None: + self.original_text = self.text + self.command_tokens = command_tokens + self.asIds = asIds + self.parse_command_tokens() + + def set_command_tokens(self, command_tokens): + self.command_tokens = command_tokens + return self.parse_command_tokens() + + def parse_command_tokens(self): + if self.command_tokens is None: + return + for command_token in self.command_tokens: + if self.asIds: + setattr(self, command_token.name, command_token.Id) + else: + setattr(self, command_token.name, command_token.token) + + def __getitem__(self, index): + return self.tokenization[index] + + def __len__(self): + return len(self.tokenization) + + def insert(self, idx, other): + if isinstance(other, (CommandToken, TypeToken)): + self.tokenization.insert(idx, other.Id) + if idx == 0: + self.text = other.token + self.text + self.original_text = other.token + self.original_text + elif idx == len(self.tokenization) - 1: + self.text += other.token + self.original_text += other.token + elif isinstance(other, Tokenization): + self.tokenization = self.tokenization[:idx] + other.tokenization + self.tokenization[idx:] + else: + self.tokenization = self.tokenization[:idx] + other.tokenization + self.tokenization[idx:] + + def append(self, other): + if isinstance(other, (CommandToken, TypeToken)): + self.tokenization.append(other.Id) + self.text += other.token + self.original_text += other.token + elif isinstance(other, Tokenization): + self.tokenization.extend(other.tokenization) + self.text += other.text + self.original_text += other.original_text + else: + self.tokenization.append(other) + return self + + def extend(self, other): + if isinstance(other, (CommandToken, TypeToken)): + self.tokenization.append(other.Id) + self.text += other.token + self.original_text += other.token + elif isinstance(other, list) and isinstance(other[0], (CommandToken, TypeToken)): + self.tokenization.extend([o.Id for o in other]) + self.text += [o.token for o in other] + self.original_text += [o.token for o in other] + elif isinstance(other, Tokenization): + self.tokenization.extend(other.tokenization) + self.text += other.text + self.original_text += other.original_text + else: + self.tokenization.extend(other) + return self + + +"""define some default command tokens for the tokenizer to use""" +token_format = "<{0}>" + +COMMAND_TUPLE = namedtuple('CommandToken', ('name', 'token', 'Id')) + + +def prep_command_tokens(tokenlist, token_format=token_format): + return [CommandToken(tok[0], token_format.format(tok[0]), tok[1]) for tok in tokenlist] + + +class CommandToken(object): + def __init__(self, name, token, Id, lstrip=False, rstrip=False): + self.name = name + self.token = token + self.Id = Id + self.lstrip = lstrip + self.rstrip = rstrip + + def __str__(self): + return str(COMMAND_TUPLE(self.name, self.token, self.Id)) + + +DEFAULT_COMMAND_TOKENS = [ + ('pad', 0), + ('eos', 1), + ('bos', 2), + ('unk', 3), + ('sep', 4), + ('L2R', 5), + ('ENC', 6), + ('MASK', 7), +] +DEFAULT_COMMAND_TOKENS = prep_command_tokens(DEFAULT_COMMAND_TOKENS) + +"""define some default type tokens for bert training""" + +TYPE_TUPLE = namedtuple('TypeToken', ('name', 'token', 'Id')) + + +def prep_type_tokens(tokenlist, token_format=token_format): + return [TypeToken(tok[0], 
token_format.format(tok[0]), tok[1]) for tok in tokenlist] + + +class TypeToken(object): + def __init__(self, name, token, Id): + self.name = name + self.token = token + self.Id = Id + + def __str__(self): + return str(TYPE_TUPLE(self.name, self.token, self.Id)) + + +DEFAULT_TYPE_TOKENS = [ + ('function', 0), + ('command', 1), + ('str0', 2), + ('str1', 3), + ('str2', 4), + ('embedding0', 5), + ('embedding1', 6), + ('embedding2', 7), + ('arg0', 8), + ('arg1', 9), + ('arg2', 10), +] +DEFAULT_TYPE_TOKENS = prep_type_tokens(DEFAULT_TYPE_TOKENS) + + +class Tokenizer(object): + """ + Tokenizer object that handles text tokenization, command tokens, and type tokens. + + Command tokens and text tokens are stored together in one mapping of size + `len(text_tokenizer)+len(command_tokens)`. Command tokens are stored as first + `len(command_tokens)` tokens. Token idx is stored at `idx+len(command_tokens)`. + + Token types are stored in a separate mapping of size `len(type_tokens)`. + """ + + def __init__(self, text_tokenizer, command_tokens=None, type_tokens=None): + # set text tokenizer + self.text_tokenizer = text_tokenizer + if not hasattr(self, 'num_text_tokens'): + self.num_text_tokens = len(self.text_tokenizer) + + # set command tokens + if command_tokens is None: + command_tokens = DEFAULT_COMMAND_TOKENS + self._command_tokens = command_tokens + self.command_name_map = {tok.name: tok for tok in self._command_tokens} + self.command_token_map = {tok.token: tok for tok in self._command_tokens} + self.command_id_map = {tok.Id: tok for tok in self._command_tokens} + if not hasattr(self, 'num_command_tokens'): + self.num_command_tokens = len(self._command_tokens) + if not hasattr(self, 'num_tokens'): + self.num_tokens = self.num_command_tokens + self.num_text_tokens + + # set type tokens + if type_tokens is None: + type_tokens = DEFAULT_TYPE_TOKENS + self.type_tokens = type_tokens + self.type_name_map = {tok.name: tok for tok in self.type_tokens} + self.type_token_map = {tok.token: tok for tok in self.type_tokens} + self.type_id_map = {tok.Id: tok for tok in self.type_tokens} + if not hasattr(self, 'num_type_tokens'): + self.num_type_tokens = len(self.type_tokens) + + # parse tokens and vocabs from tokenizer + self._tokens = list(self.command_token_map.keys()) + list(self.text_tokenizer.tokens) + self._vocab = {t: Id for Id, t in self.command_id_map.items()} + self._vocab.update({t: Id + self.num_command_tokens for t, Id in self.text_tokenizer.vocab.items()}) + + self._text_tokens = list(self.text_tokenizer.tokens) + self._text_token_vocab = {t: Id + self.num_command_tokens for t, Id in self.text_tokenizer.vocab.items()} + + self._command_token_tokens = list(self.command_token_map.keys()) + self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()} + + self._token_types = list(self.type_token_map.keys()) + self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()} + + def __call__(self, text, process_fn=None): + """run preprocessing and encode text as Ids""" + return self.EncodeAsIds(text, process_fn=process_fn) + + def __len__(self): + """total number of tokens""" + return self.num_tokens + + def get_command(self, name): + """get command token corresponding to `name`""" + return self.command_name_map[name] + + def get_type(self, name): + """get type token corresponding to `name`""" + return self.type_name_map[name] + + @property + def tokens(self): + """list (or iterable) of all tokens for tokenizer""" + return self._tokens + + @property + def vocab(self): + 
"""dictionary mapping tokens to ids for tokenizer""" + return self._vocab + + @property + def token_types(self): + """list (or iterable) of all token types for tokenizer""" + return self._token_types + + @property + def token_type_vocab(self): + """dictionary mapping token types to ids for tokenizer""" + return self._token_type_vocab + + @property + def command_tokens(self): + """list (or iterable) of all command tokens for tokenizer""" + return self._command_token_tokens + + @property + def command_token_vocab(self): + """dictionary mapping command tokens to ids for tokenizer""" + return self._command_token_vocab + + @property + def text_tokens(self): + """list (or iterable) of text tokens for text tokenizer""" + return self._text_tokens + + @property + def text_token_vocab(self): + """dictionary mapping text tokens to ids for text tokenizer""" + return self._text_token_vocab + + def EncodeAsIds(self, text, process_fn=None): + """ + encode text using text tokenizer and shift Id values for command tokens + """ + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + + def split_on_token(tok_extended: CommandToken, text): + result = [] + tok = tok_extended.token + split_text = text.split(tok) + for i, sub_text in enumerate(split_text): + # CommandToken can control whitespace stripping around them. + # We use them for GPT2 and Roberta to have different behavior depending on the special token + # Cf. https://github.com/huggingface/transformers/pull/2778 + # and https://github.com/huggingface/transformers/issues/3788 + # Strip white spaces on the right + if tok_extended.rstrip and i > 0: + # A bit counter-intuitive but we strip the left of the string + # since tok_extended.rstrip means the special token is eating all white spaces on its right + sub_text = sub_text.lstrip() + # Strip white spaces on the left + if tok_extended.lstrip and i < len(split_text) - 1: + sub_text = sub_text.rstrip() # Opposite here + + if i == 0 and not sub_text: + result.append(tok) + elif i == len(split_text) - 1: + if sub_text: + result.append(sub_text) + else: + pass + else: + if sub_text: + result.append(sub_text) + result.append(tok) + return result + + def split_on_tokens(tok_list, text): + if not text.strip(): + return [] + if not tok_list: + return self.text_tokenizer.encode(text) + + tokenized_text = [] + text_list = [text] + for tok in tok_list: + tokenized_text = [] + for sub_text in text_list: + if sub_text not in self._command_token_tokens: + tokenized_text.extend(split_on_token(tok, sub_text)) + else: + tokenized_text.append(sub_text) + text_list = tokenized_text + + return list( + itertools.chain.from_iterable( + ( + self._encode(token) if token not in self._command_token_tokens else [ + self.command_token_map[token].Id] for token in tokenized_text + ) + ) + ) + + no_split_tokens = self._command_tokens + Ids = split_on_tokens(no_split_tokens, processed_text) + tokenization = Tokenization(Ids, processed_text, text) + tokenization.set_command_tokens(self._command_tokens) + return tokenization + + def _encode(self, text): + raise NotImplementedError + + def EncodeAsTokens(self, text, process_fn=None): + """ + encode text as tokens using text tokenizer + """ + tokenization = self.text_tokenizer.EncodeAsTokens(text, process_fn=process_fn) + tokenization.set_command_tokens(self._command_tokens) + return tokenization + + def IdToToken(self, Id, type_token=False): + """convert Id to token accounting for command and type tokens""" + if isinstance(Id, (TypeToken, 
CommandToken)): + return Id.token + if type_token: + return self.type_id_map[Id].token + if Id < self.num_command_tokens: + return self.command_id_map[Id].token + return self.text_tokenizer.IdToToken(Id - self.num_command_tokens) + + def TokenToId(self, token, type_token=False): + """convert token to Id accounting for command and type tokens""" + if isinstance(token, (TypeToken, CommandToken)): + return token.Id + if type_token: + return self.type_token_map[token].Id + if token in self.command_token_map: + return self.command_token_map[token].Id + return self.text_tokenizer.TokenToId(token) + self.num_command_tokens + + def DecodeIds(self, Ids, type_token=False): + """ + convert Ids to tokens accounting for command and type tokens, tokens + are joined and returned as a string. + """ + if type_token: + return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids) + rtn_strs = [] + current_str = [] + if isinstance(Ids, Tokenization): + Ids = Ids.tokenization + for Id in Ids: + if isinstance(Id, CommandToken): + rtn_strs.append(self.text_tokenizer.DecodeIds(current_str)) + current_str = [] + rtn_strs.append(Id.token) + elif Id < self.num_command_tokens: + rtn_strs.append(self.text_tokenizer.DecodeIds(current_str)) + current_str = [] + rtn_strs.append(self.command_id_map[Id].token) + else: + current_str.append(Id - self.num_command_tokens) + if current_str != []: + rtn_strs.append(self.text_tokenizer.DecodeIds(current_str)) + return ' '.join(rtn_strs) + + def DecodeTokens(self, Tokens, type_token=False): + """ + convert tokens to a string accounting for command and type tokens. + """ + if type_token: + return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens) + rtn_strs = [] + current_str = [] + if isinstance(Tokens, Tokenization): + Tokens = Tokens.tokenization + for t in Tokens: + if isinstance(t, CommandToken): + rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str)) + current_str = [] + rtn_strs.append(t.token) + elif t in self.command_token_map: + rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str)) + current_str = [] + rtn_strs.append(t) + else: + current_str.append(t) + if current_str != []: + rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str)) + return ' '.join(rtn_strs) + + +class TextTokenizer(object): + """ + Interface for text tokenizer + """ + + def __init__(self): + if not hasattr(self, 'num_text_tokens'): + self.num_text_tokens = 0 + if not hasattr(self, 'num_tokens'): + self.num_tokens = self.num_text_tokens + + def __call__(self, text, process_fn=None): + return self.EncodeAsIds(text, process_fn) + + def __len__(self): + return self.num_text_tokens + + @property + def tokens(self): + """list (or iterable) of text tokens for text tokenizer""" + raise NotImplementedError('TextTokenizer tokens property not implemented') + + @property + def vocab(self): + """dictionary mapping tokens to ids""" + raise NotImplementedError('TextTokenizer vocab property not implemented') + + @staticmethod + def exists(model_path): + """check if the filepath for a text tokenizer exists""" + raise NotImplementedError('TextTokenizer exists method not implemented') + + def Train(self, corpus): + """train a tokenizer on a data corpus and save model for future use""" + raise NotImplementedError('TextTokenizer Train not implemented') + + def EncodeAsIds(self, text, process_fn=None): + """ + Preprocess text and encode as ids. Return a tokenization object with + original text, processed text, and id tokenization. 
+ """ + raise NotImplementedError('TextTokenizer EncodeAsIds not implemented') + + def EncodeAsTokens(self, text, process_fn=None): + """ + Preprocess text and encode as tokens. Return a tokenization object with + original text, processed text, and token tokenization. + """ + raise NotImplementedError('TextTokenizer EncodeAsTokens not implemented') + + def IdToToken(self, Id): + """Convert an Id to Token. Reverse lookup of self.vocab""" + raise NotImplementedError('TextTokenizer IdToToken not implemented') + + def TokenToId(self, token): + """Convert a Token to Id. Lookup of self.vocab""" + raise NotImplementedError('TextTokenizer TokenToId not implemented') + + def DecodeIds(self, Ids): + """Convert a list or tokenization object of Ids to a text string""" + raise NotImplementedError('TextTokenizer DecodeIds not implemented') + + def DecodeTokens(self, Tokens): + """Convert a list or tokenization object of tokens to a text string""" + raise NotImplementedError('TextTokenizer DecodeTokens not implemented') + + +class CharacterLevelTokenizer(TextTokenizer): + """ + Text tokenizer for ASCII-256 Character Level Tokenization. + """ + + def __init__(self, **kwargs): + self.num_text_tokens = 256 + super(CharacterLevelTokenizer, self).__init__() + self._tokens = [self.IdToToken(Id) for Id in range(self.num_text_tokens)] + self._vocab = {t: i for i, t in enumerate(self._tokens)} + + def __len__(self): + return 256 + + @staticmethod + def exists(model_path): + return True + + def Train(self, corpus): + pass + + @property + def tokens(self): + return self._tokens + + @property + def vocab(self): + return self._vocab + + def EncodeAsIds(self, text, process_fn=None): + """convert text to ascii 256 Ids""" + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + processed_text = str(processed_text) + tokens = [self.TokenToId(c) for c in processed_text] + return Tokenization(tokens, processed_text, text) + + def EncodeAsTokens(self, text, process_fn=None): + """convert text to ascii 256 characters""" + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + processed_text = str(processed_text) + tokens = [c for c in processed_text] + return Tokenization(tokens, processed_text, text, asIds=False) + + def IdToToken(self, Id): + """ascii index to character""" + return chr(Id) + + def TokenToId(self, token): + """ascii character to index""" + return ord(token) + + def DecodeIds(self, Ids): + """converts ascii ids to tokens before joining them into text""" + if isinstance(Ids, Tokenization): + Ids = Ids.tokenization + return ''.join([self.IdToToken(tok) for tok in Ids]) + + def DecodeTokens(self, Tokens): + """just concatenates ascii tokens into text""" + if isinstance(Tokens, Tokenization): + Tokens = Tokens.tokenization + return ''.join(Tokens) + + +MAX_SENTENCEPIECE_SENTENCES = 100000000 + + +def get_corpus_freq(dataset, filepath, filetype='tsv'): + """ + Take corpus, split it into sentences, and extract word frequencies. + Write frequencies to `filepath` as a tsv. Only write the first + MAX_SENTENCEPIECE_SENTENCES most common words to the file. 
+ """ + nltk.download('punkt', download_dir="./nltk") + if filetype == 'tsv': + delimiter = '\t' + else: + delimiter = ',' + + print("compute corpus frequency\n", flush=True) + + total_sentence_count = 0 + maxlen = 0 + freqs = {} + for entry in dataset: + if isinstance(entry, dict): + entry = entry['text'] + lines = entry.strip().split('\n') + for line in lines: + sentences = nltk_tokenize.sent_tokenize(line) + total_sentence_count += len(sentences) + for sentence in sentences: + maxlen = max(len(line), maxlen) + for word in sentence.split(): + if word not in freqs: + freqs[word] = 0 + freqs[word] += 1 + + print("length of freqs before truncating " + str(len(freqs)), flush=True) + print("file path for freq " + str(filepath), flush=True) + + freqs_sorted = {} + counter = 0 + for word, count in sorted(freqs.items(), key=lambda x: x[1], reverse=True): + if counter >= MAX_SENTENCEPIECE_SENTENCES: + break + counter += 1 + freqs_sorted[word] = count + + print("length of freqs after trancating " + str(len(freqs_sorted)), flush=True) + + with open(filepath, 'w') as f: + writer = csv.writer(f, delimiter=delimiter) + for k, v in freqs_sorted.items(): + writer.writerow([str(k), str(v)]) + + return total_sentence_count, maxlen + + +class SentencePieceTokenizer(TextTokenizer): + """Trains and uses sentencepiece for text tokenization""" + + def __init__(self, model_type='bpe', vocab_size=None, corpus=None, model_path=None, character_coverage=1.0, + **kwargs): + self.character_coverage = character_coverage + self.model_type = model_type.lower() + self.spm_model = model_path + self.num_text_tokens = vocab_size + make_train = not SentencePieceTokenizer.exists(self.spm_model) + if make_train: + assert corpus is not None and self.num_text_tokens is not None + self.Train(corpus, self.num_text_tokens) + self._tokens = [] + self._vocab = {} + self.load_spm_model() + super(SentencePieceTokenizer, self).__init__() + + def __len__(self): + return self.num_text_tokens + + @property + def tokens(self): + return self._tokens + + @property + def vocab(self): + return self._vocab + + @staticmethod + def exists(model_path): + if model_path is None: + return False + # check if path exists + dne = not os.path.exists(model_path) + # check if path.model exists + if dne and not model_path.endswith('.model'): + dne = not os.path.exists(model_path + '.model') + return not dne + + def load_spm_model(self): + """load sentencepiece model and parse vocab""" + if not os.path.exists(self.spm_model) and not self.spm_model.endswith('.model'): + self.spm_model = self.spm_model + '.model' + self.sp = spm.SentencePieceProcessor() + self.sp.Load(self.spm_model) + self.vocab_size = self.num_text_tokens = len(self.sp) + self._tokens = [self.IdToToken(t) for t in range(self.vocab_size)] + self._vocab = {t: i for i, t in enumerate(self._tokens)} + + def Train(self, corpus, num_text_tokens): + """train sentencepiece model on corpus using word frequencies""" + self.num_text_tokens = num_text_tokens + use_model_path = self.spm_model + random_hash = str(random.randint(0, 2147483647)) + if use_model_path is None: + use_model_path = random_hash + if use_model_path.endswith('.model'): + use_model_path = use_model_path[:use_model_path.rfind('.model')] + input_path = use_model_path + '.tsv.' 
+ random_hash + line_count, maxlenline = get_corpus_freq(corpus, input_path) + line_count = min(line_count, MAX_SENTENCEPIECE_SENTENCES) + print('line count used as input_sentence_size ', line_count, flush=True) + print('training sentencepiece model', flush=True) + train_string = '--input={file_path} --model_prefix={model_prefix} --vocab_size={vocab_size}' \ + + ' --model_type={model_type} --character_coverage={character_coverage} ' \ + + '--input_sentence_size={input_sentence_size} ' \ + + '--input_format=tsv' + train_string = train_string.format(file_path=input_path, model_prefix=use_model_path, + vocab_size=num_text_tokens, + model_type=self.model_type, character_coverage=self.character_coverage, + input_sentence_size=int(line_count)) # , #)#, + print("calling spm.SentencePieceTrainer.Train(%s)" % (train_string), flush=True) + spm.SentencePieceTrainer.Train(train_string) + os.remove(input_path) + self.spm_model = use_model_path + '.model' + print('sentencepiece model written to ' + self.spm_model, flush=True) + + def EncodeAsIds(self, text, process_fn=None): + """convert text to sentencepiece Ids""" + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + tokens = self.sp.EncodeAsIds(processed_text) + return Tokenization(tokens, processed_text, text) + + def EncodeAsTokens(self, text, process_fn=None): + """convert text to sentencepiece tokens""" + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + tokens = self.sp.EncodeAsTokens(processed_text) + return Tokenization(tokens, processed_text, text, asIds=False) + + def IdToToken(self, Id): + """convert Id to sentencpiece token""" + return self.sp.IdToPiece(Id) + + def TokenToId(self, token): + """convert sentencpiece token to Id""" + return self.sp.PieceToId(token) + + def DecodeIds(self, Ids): + """converts ids to a text string""" + if isinstance(Ids, Tokenization): + Ids = Ids.tokenization + return self.sp.DecodeIds(Ids) + + def DecodeTokens(self, Tokens): + """converts sentencepiece tokens to a text string""" + if isinstance(Tokens, Tokenization): + Tokens = Tokens.tokenization + return self.sp.DecodeTokens(Tokens) + + +class BertWordPieceTokenizer(Tokenizer): + """ + Loads a pretrained WordPiece tokenizer from `cache_dir` for tokenization + in BERT training. Default to bert-large-uncased tokenizer. 
+ """ + + def __init__(self, tokenizer_model_type=None, cache_dir=None, add_block_symbols=False, add_sentinel_token=0, + add_task_mask=False, add_decoder_mask=False, **kwargs): + # default to bert-large-uncased tokenizer + if tokenizer_model_type not in PRETRAINED_VOCAB_ARCHIVE_MAP: + tokenizer_model_type = 'bert-large-uncased' + if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: + print('loading BertWordPieceTokenizer (', tokenizer_model_type, ') from cache_dir ', cache_dir) + do_lower_case = not ('-cased' in tokenizer_model_type or 'chinese' in tokenizer_model_type) + self.text_tokenizer = BertTokenizer.from_pretrained(tokenizer_model_type, do_lower_case=do_lower_case, + cache_dir=cache_dir) + if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: + print('loaded', tokenizer_model_type) + # disable max len warnings by increasing max len + self.text_tokenizer.max_len = int(1e12) + + # set command tokens from wordpiece tokenizer values + self.num_command_tokens = 6 + self.num_tokens = len(self.text_tokenizer.vocab) + self.num_text_tokens = self.num_tokens - 5 + self.num_type_tokens = 2 + + self._command_tokens = [ + CommandToken('pad', '[PAD]', self.text_tokenizer.vocab['[PAD]']), + CommandToken('ENC', '[CLS]', self.text_tokenizer.vocab['[CLS]']), + CommandToken('MASK', '[MASK]', self.text_tokenizer.vocab['[MASK]']), + CommandToken('unk', '[UNK]', self.text_tokenizer.vocab['[UNK]']), + CommandToken('sep', '[SEP]', self.text_tokenizer.vocab['[SEP]']), + CommandToken('eos', '[PAD]', self.text_tokenizer.vocab['[PAD]']), + ] + if add_block_symbols: + self._command_tokens.extend([ + CommandToken('sop', '<|startofpiece|>', self.num_tokens), + CommandToken('eop', '<|endofpiece|>', self.num_tokens + 1) + ]) + self.num_tokens += 2 + self.num_command_tokens += 2 + if add_task_mask: + self._command_tokens.extend([ + CommandToken('gMASK', '[gMASK]', self.num_tokens), + CommandToken('sMASK', '[sMASK]', self.num_tokens + 1) + ]) + self.num_tokens += 2 + self.num_command_tokens += 2 + if add_decoder_mask: + self._command_tokens.extend([ + CommandToken('dBLOCK', '[dBLOCK]', self.num_tokens) + ]) + self.num_tokens += 1 + self.num_command_tokens += 1 + if add_sentinel_token > 0: + for i in range(1, add_sentinel_token): + self._command_tokens.extend([CommandToken(f'MASK{i}', f'[MASK{i}]', self.num_tokens), + CommandToken(f'sop{i}', f'<|startofpiece{i}|>', self.num_tokens + 1)]) + self.num_tokens += 2 + self.num_command_tokens += 2 + self.command_name_map = {tok.name: tok for tok in self._command_tokens} + self.command_token_map = {tok.token: tok for tok in self._command_tokens} + self.command_id_map = {tok.Id: tok for tok in self._command_tokens} + + # set type tokens + self.type_tokens = [ + TypeToken('str0', '<str0>', 0), + TypeToken('str1', '<str1>', 1), + ] + self.type_name_map = {tok.name: tok for tok in self.type_tokens} + self.type_token_map = {tok.token: tok for tok in self.type_tokens} + self.type_id_map = {tok.Id: tok for tok in self.type_tokens} + + # parse tokens and vocabs from tokenizer + + self._tokens = list(self.text_tokenizer.vocab.keys()) + self._vocab = {k: v for k, v in self.text_tokenizer.vocab.items()} + + self._text_tokens = list(self._tokens) + self._text_token_vocab = {k: v for k, v in self.text_tokenizer.vocab.items()} + + self._command_token_tokens = list(self.command_token_map.keys()) + self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()} + + self._token_types = list(self.type_token_map.keys()) + 
self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()} + + def _encode(self, text): + tokens = self.text_tokenizer.tokenize(text) + ids = self.text_tokenizer.convert_tokens_to_ids(tokens) + return ids + + def EncodeAsTokens(self, text, process_fn=None): + """convert wordpiece token to Id""" + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + tokens = self.text_tokenizer.tokenize(processed_text) + return Tokenization(tokens, processed_text, text, asIds=False) + + def IdToToken(self, Id, type_token=False): + """convert Id to sentencpiece token""" + if isinstance(Id, (TypeToken, CommandToken)): + return Id.token + if type_token: + return self.type_id_map[Id].token + if Id in self.command_id_map: + return self.command_id_map[Id].token + return self.text_tokenizer.ids_to_tokens[Id] + + @staticmethod + def clean_up_tokenization(out_string: str) -> str: + """ + Clean up a list of simple English tokenization artifacts like spaces before punctuations and abbreviated forms. + + Args: + out_string (:obj:`str`): The text to clean up. + + Returns: + :obj:`str`: The cleaned-up string. + """ + out_string = ( + out_string.replace(" .", ".") + .replace(" ?", "?") + .replace(" !", "!") + .replace(" ,", ",") + .replace(" ' ", "'") + .replace(" n't", "n't") + .replace(" 'm", "'m") + .replace(" 's", "'s") + .replace(" 've", "'ve") + .replace(" 're", "'re") + ) + return out_string + + def TokenToId(self, token, type_token=False): + """convert sentencpiece token to Id""" + if isinstance(token, (TypeToken, CommandToken)): + return token.Id + if type_token: + return self.type_token_map[token].Id + return self.text_tokenizer.vocab[token] + + def DecodeIds(self, Ids, type_token=False): + """converts ids to wordpiece tokens and joins them as a text string""" + if type_token: + return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids) + if isinstance(Ids, Tokenization): + Ids = Ids.tokenization + Tokens = [] + for Id in Ids: + if Id in self.command_id_map: + Tokens.append(self.command_id_map[Id].token) + elif Id in self.text_tokenizer.ids_to_tokens: + Tokens.append(self.text_tokenizer.ids_to_tokens[Id]) + new_tokens = [] + for token in Tokens: + if token.startswith('##') and len(new_tokens) > 0: + new_tokens[-1] += token[2:] + else: + new_tokens.append(token) + output = ' '.join(new_tokens) + output = self.clean_up_tokenization(output) + return output + + def DecodeTokens(self, Tokens, type_token=False): + """converts wordpiece tokens to a text string""" + if type_token: + return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens) + if isinstance(Tokens, Tokenization): + Tokens = Tokens.tokenization + return ' '.join(Tokens) + + +class GPT2BPETokenizer(Tokenizer): + def __init__(self, model_type_or_path, cache_dir=None, add_block_symbols=False, add_task_mask=False, + add_decoder_mask=False, **kwargs): + self.text_tokenizer = GPT2Tokenizer.from_pretrained(model_type_or_path, + cache_dir=cache_dir) + + # disable max len warnings by increasing max len + self.text_tokenizer.max_len = int(1e12) + self.num_tokens = len(self.text_tokenizer.encoder) + self.num_type_tokens = 2 + if model_type_or_path.startswith('roberta'): + self.num_command_tokens = 6 + self.num_text_tokens = self.num_tokens - 3 + self._command_tokens = [ + CommandToken('pad', '<|endoftext|>', self.text_tokenizer.encoder['</s>']), + CommandToken('eos', '<|endoftext|>', self.text_tokenizer.encoder['</s>']), + CommandToken('sep', '[SEP]', 
self.text_tokenizer.encoder['<pad>']), + CommandToken('ENC', '[CLS]', self.text_tokenizer.encoder['<s>']), + CommandToken('MASK', '[MASK]', self.text_tokenizer.encoder['<mask>'], lstrip=True), + CommandToken('unk', '[UNK]', self.text_tokenizer.encoder['<unk>']) + ] + if add_block_symbols: + self._command_tokens.extend([ + CommandToken('sop', '<|startofpiece|>', self.num_tokens), + CommandToken('eop', '<|endofpiece|>', self.num_tokens + 1) + ]) + self.num_tokens += 2 + self.num_command_tokens += 2 + else: + self.num_command_tokens = 2 + self.num_text_tokens = self.num_tokens - 1 + self._command_tokens = [ + CommandToken('pad', '<|endoftext|>', self.text_tokenizer.encoder['<|endoftext|>']), + CommandToken('eos', '<|endoftext|>', self.text_tokenizer.encoder['<|endoftext|>']) + ] + if add_block_symbols: + self._command_tokens.extend([ + CommandToken('sop', '<|startofpiece|>', self.num_tokens), + CommandToken('eop', '<|endofpiece|>', self.num_tokens + 1), + CommandToken('ENC', '[CLS]', self.num_tokens + 2), + CommandToken('MASK', '[MASK]', self.num_tokens + 3, lstrip=True), + CommandToken('sep', '[SEP]', self.num_tokens + 4), + CommandToken('unk', '[UNK]', self.num_tokens + 5) + ]) + self.num_tokens += 6 + self.num_command_tokens += 6 + if add_block_symbols: + if add_task_mask: + self._command_tokens.extend([ + CommandToken('gMASK', '[gMASK]', self.num_tokens, lstrip=True), + CommandToken('sMASK', '[sMASK]', self.num_tokens + 1, lstrip=True) + ]) + self.num_tokens += 2 + self.num_command_tokens += 2 + if add_decoder_mask: + self._command_tokens.extend([ + CommandToken('dBLOCK', '[dBLOCK]', self.num_tokens) + ]) + self.num_tokens += 1 + self.num_command_tokens += 1 + self.command_name_map = {tok.name: tok for tok in self._command_tokens} + self.command_token_map = {tok.token: tok for tok in self._command_tokens} + self.command_id_map = {tok.Id: tok for tok in self._command_tokens} + + self.type_tokens = [ + TypeToken('str0', '<str0>', 0), + TypeToken('str1', '<str1>', 1), + ] + self.type_name_map = {tok.name: tok for tok in self.type_tokens} + self.type_token_map = {tok.token: tok for tok in self.type_tokens} + self.type_id_map = {tok.Id: tok for tok in self.type_tokens} + + self._tokens = list(self.text_tokenizer.encoder.keys()) + self._vocab = {k: v for k, v in self.text_tokenizer.encoder.items()} + + self._text_tokens = list(self._tokens) + self._text_token_vocab = {k: v for k, v in self.text_tokenizer.encoder.items()} + + self._command_token_tokens = list(self.command_token_map.keys()) + self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()} + + self._token_types = list(self.type_token_map.keys()) + self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()} + + for idx, tok in self.command_id_map.items(): + self.text_tokenizer.decoder[idx] = tok.token + + def EncodeAsIds(self, text, process_fn=None): + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + + def split_on_token(tok_extended: CommandToken, text): + result = [] + tok = tok_extended.token + split_text = text.split(tok) + for i, sub_text in enumerate(split_text): + # CommandToken can control whitespace stripping around them. + # We use them for GPT2 and Roberta to have different behavior depending on the special token + # Cf. 
https://github.com/huggingface/transformers/pull/2778
+                # and https://github.com/huggingface/transformers/issues/3788
+                # Strip white spaces on the right
+                if tok_extended.rstrip and i > 0:
+                    # A bit counter-intuitive but we strip the left of the string
+                    # since tok_extended.rstrip means the special token is eating all white spaces on its right
+                    sub_text = sub_text.lstrip()
+                # Strip white spaces on the left
+                if tok_extended.lstrip and i < len(split_text) - 1:
+                    sub_text = sub_text.rstrip()  # Opposite here
+
+                if i == 0 and not sub_text:
+                    result.append(tok)
+                elif i == len(split_text) - 1:
+                    if sub_text:
+                        result.append(sub_text)
+                    else:
+                        pass
+                else:
+                    if sub_text:
+                        result.append(sub_text)
+                    result.append(tok)
+            return result
+
+        def split_on_tokens(tok_list, text):
+            if not text.strip():
+                return []
+            if not tok_list:
+                return self.text_tokenizer.encode(text)
+
+            tokenized_text = []
+            text_list = [text]
+            for tok in tok_list:
+                tokenized_text = []
+                for sub_text in text_list:
+                    if sub_text not in self._command_token_tokens:
+                        tokenized_text.extend(split_on_token(tok, sub_text))
+                    else:
+                        tokenized_text.append(sub_text)
+                text_list = tokenized_text
+
+            return list(
+                itertools.chain.from_iterable(
+                    (
+                        self.text_tokenizer.encode(token) if token not in self._command_token_tokens else [
+                            self.command_token_map[token].Id] for token in tokenized_text
+                    )
+                )
+            )
+
+        no_split_tokens = self._command_tokens
+        Ids = split_on_tokens(no_split_tokens, processed_text)
+        tokenization = Tokenization(Ids, processed_text, text)
+        tokenization.set_command_tokens(self._command_tokens)
+        return tokenization
+
+    def _encode(self, text):
+        return self.text_tokenizer.encode(text)
+
+    def EncodeAsTokens(self, text, process_fn=None):
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+        tokens = []
+        for token in re.findall(self.text_tokenizer.pat, processed_text):
+            token = ''.join(self.text_tokenizer.byte_encoder[b] for b in token.encode('utf-8'))
+            tokens.extend(bpe_token for bpe_token in self.text_tokenizer.bpe(token).split(' '))
+        tokenization = Tokenization(tokens, processed_text, text, asIds=False)
+        tokenization.set_command_tokens(self._command_tokens)
+        return tokenization
+
+    def DecodeAsTokens(self, Ids):
+        return [self.IdToToken(x) for x in Ids]
+
+    def IdToToken(self, Id, type_token=False):
+        if isinstance(Id, (TypeToken, CommandToken)):
+            return Id.token
+        if type_token:
+            return self.type_id_map[Id].token
+        if Id in self.command_id_map:
+            return self.command_id_map[Id].token
+        return self.text_tokenizer.decoder[Id]
+
+    def TokenToId(self, token, type_token=False):
+        if isinstance(token, (TypeToken, CommandToken)):
+            return token.Id
+        if type_token:
+            return self.type_token_map[token].Id
+        return self.text_tokenizer.encoder[token]
+
+    def DecodeIds(self, Ids, type_token=False):
+        if type_token:
+            return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
+        if isinstance(Ids, Tokenization):
+            Ids = Ids.tokenization
+        return self.text_tokenizer.decode(Ids)
+
+    def DecodeTokens(self, Tokens, type_token=False):
+        if type_token:
+            return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens)
+        if isinstance(Tokens, Tokenization):
+            Tokens = Tokens.tokenization
+        return self.text_tokenizer.decode([self.TokenToId(tok) for tok in Tokens])
+
+
+class ChineseSPTokenizer(Tokenizer):
+    def __init__(self, add_block_symbols=False, **kwargs):
+        self.text_tokenizer = sp_tokenizer.from_pretrained()
+
+
self.num_command_tokens = 2 + self.num_text_tokens = self.text_tokenizer.sp.vocab_size() + self.num_tokens = self.num_text_tokens + 1 + self.num_type_tokens = 2 + + self._command_tokens = [ + CommandToken('pad', '<|endoftext|>', self.num_text_tokens), + CommandToken('eos', '<|endoftext|>', self.num_text_tokens), + ] + if add_block_symbols: + self._command_tokens.extend([ + CommandToken('sop', '<|startofpiece|>', self.num_text_tokens + 1), + CommandToken('eop', '<|endofpiece|>', self.num_text_tokens + 2) + ]) + self.num_tokens += 2 + self.num_command_tokens += 2 + self.command_name_map = {tok.name: tok for tok in self._command_tokens} + self.command_token_map = {tok.token: tok for tok in self._command_tokens} + self.command_id_map = {tok.Id: tok for tok in self._command_tokens} + + self.type_tokens = [ + TypeToken('str0', '<str0>', 0), + TypeToken('str1', '<str1>', 1), + ] + self.type_name_map = {tok.name: tok for tok in self.type_tokens} + self.type_token_map = {tok.token: tok for tok in self.type_tokens} + self.type_id_map = {tok.Id: tok for tok in self.type_tokens} + + # self._tokens = list(self.text_tokenizer.encoder.keys()) + # self._vocab = {k:v for k,v in self.text_tokenizer.encoder.items()} + # + # self._text_tokens = list(self._tokens) + # self._text_token_vocab = {k:v for k,v in self.text_tokenizer.encoder.items()} + + self._command_token_tokens = list(self.command_token_map.keys()) + self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()} + + self._token_types = list(self.type_token_map.keys()) + self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()} + + def _encode(self, text): + ids = self.text_tokenizer.encode(text) + return ids + + def EncodeAsTokens(self, text, process_fn=None): + processed_text = text + if process_fn is not None: + processed_text = process_fn(processed_text) + tokens = self.text_tokenizer.tokenize(processed_text) + tokenization = Tokenization(tokens, processed_text, text, asIds=False) + tokenization.set_command_tokens(self._command_tokens) + return tokenization + # return Tokenization(tokens, processed_text, text, asIds=False) + + def IdToToken(self, Id, type_token=False): + if isinstance(Id, (TypeToken, CommandToken)): + return Id.token + if type_token: + return self.type_id_map[Id].token + if Id in self.command_id_map: + return self.command_id_map[Id].token + elif Id in self.type_id_map: + return self.type_id_map[Id].token + else: + return self.text_tokenizer.convert_id_to_token(Id) + + def TokenToId(self, token, type_token=False): + if isinstance(token, (TypeToken, CommandToken)): + return token.Id + if type_token: + return self.type_token_map[token].Id + return self.text_tokenizer.convert_token_to_id(token) + + def DecodeIds(self, Ids, type_token=False): + if type_token: + return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids) + if isinstance(Ids, Tokenization): + Ids = Ids.tokenization + try: + first_eos = Ids.index(self.get_command('eos').Id) + eos_count = len(Ids) - first_eos + Ids = Ids[:first_eos] + except ValueError: + eos_count = 0 + return " ".join((self.text_tokenizer.decode(Ids), *(['<|endoftext|>'] * eos_count))) + + def DecodeTokens(self, Tokens, type_token=False): + if type_token: + return ' '.join(t.token if isinstance(t, TypeToken) else t for t in Tokens) + if isinstance(Tokens, Tokenization): + Tokens = Tokens.tokenization + return self.text_tokenizer.decode([self.TokenToId(tok) for tok in Tokens]) diff --git a/tokenization/text/tokenization_gpt2.py 
b/tokenization/text/tokenization_gpt2.py new file mode 100644 index 0000000..318d920 --- /dev/null +++ b/tokenization/text/tokenization_gpt2.py @@ -0,0 +1,310 @@ +# coding=utf-8 +# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes for OpenAI GPT.""" +from __future__ import (absolute_import, division, print_function, + unicode_literals) + +import sys +import json +import logging +import os +import regex as re +from io import open + +try: + from functools import lru_cache +except ImportError: + # Just a dummy decorator to get the checks to run on python2 + # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. + def lru_cache(): + return lambda func: func + +from ..file_utils import cached_path + +logger = logging.getLogger(__name__) + +PRETRAINED_VOCAB_ARCHIVE_MAP = { + 'gpt2': "pretrained/pytorch_pretrained_bert/gpt2-vocab.json", + "roberta": "pretrained/pytorch_pretrained_bert/roberta-vocab.json" +} +PRETRAINED_MERGES_ARCHIVE_MAP = { + 'gpt2': "pretrained/pytorch_pretrained_bert/gpt2-merges.txt", + "roberta": "pretrained/pytorch_pretrained_bert/roberta-merges.txt" +} +PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { + 'gpt2': 1024, +} +VOCAB_NAME = 'vocab.json' +MERGES_NAME = 'merges.txt' +SPECIAL_TOKENS_NAME = 'special_tokens.txt' + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + _chr = unichr if sys.version_info[0] == 2 else chr + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [_chr(n) for n in cs] + return dict(zip(bs, cs)) + +def get_pairs(word): + """Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + +class GPT2Tokenizer(object): + """ + GPT-2 BPE tokenizer. Peculiarities: + - Byte-level BPE + """ + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): + """ + Instantiate a PreTrainedBertModel from a pre-trained model file. + Download and cache the pre-trained model file if needed. 
+ """ + if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: + vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] + merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] + special_tokens_file = None + else: + vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) + merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) + special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME) + if not os.path.exists(special_tokens_file): + special_tokens_file = None + else: + logger.info("loading special tokens file {}".format(special_tokens_file)) + # redirect to the cache, if necessary + # try: + # resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) + # resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) + # except EnvironmentError: + # logger.error( + # "Model name '{}' was not found in model name list ({}). " + # "We assumed '{}' was a path or url but couldn't find files {} and {} " + # "at this path or url.".format( + # pretrained_model_name_or_path, + # ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), + # pretrained_model_name_or_path, + # vocab_file, merges_file)) + # return None + # if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: + # logger.info("loading vocabulary file {}".format(vocab_file)) + # logger.info("loading merges file {}".format(merges_file)) + # else: + # logger.info("loading vocabulary file {} from cache at {}".format( + # vocab_file, resolved_vocab_file)) + # logger.info("loading merges file {} from cache at {}".format( + # merges_file, resolved_merges_file)) + resolved_vocab_file = vocab_file + resolved_merges_file = merges_file + logger.info("loading vocabulary file {}".format(vocab_file)) + logger.info("loading merges file {}".format(merges_file)) + if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: + # if we're using a pretrained model, ensure the tokenizer wont index sequences longer + # than the number of positional embeddings + max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] + kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) + # Instantiate tokenizer. 
+ if special_tokens_file and 'special_tokens' not in kwargs: + special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] + else: + special_tokens = kwargs.pop('special_tokens', []) + tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) + return tokenizer + + def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None): + self.max_len = max_len if max_len is not None else int(1e12) + self.encoder = json.load(open(vocab_file)) + self.decoder = {v:k for k,v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} + bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] + bpe_merges = [tuple(merge.split()) for merge in bpe_data] + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + + # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + + self.special_tokens = {} + self.special_tokens_decoder = {} + self.set_special_tokens(special_tokens) + + def __len__(self): + return len(self.encoder) + len(self.special_tokens) + + def set_special_tokens(self, special_tokens): + """ Add a list of additional tokens to the encoder. + The additional tokens are indexed starting from the last index of the + current vocabulary in the order of the `special_tokens` list. + """ + if not special_tokens: + self.special_tokens = {} + self.special_tokens_decoder = {} + return + self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) + self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} + logger.info("Special tokens {}".format(self.special_tokens)) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def tokenize(self, text): + """ Tokenize a string. """ + bpe_tokens = [] + for token in re.findall(self.pat, text): + if sys.version_info[0] == 2: + token = ''.join(self.byte_encoder[ord(b)] for b in token) + else: + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def convert_tokens_to_ids(self, tokens): + """ Converts a sequence of tokens into ids using the vocab. 
""" + ids = [] + if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): + if tokens in self.special_tokens: + return self.special_tokens[tokens] + else: + return self.encoder.get(tokens, 0) + for token in tokens: + if token in self.special_tokens: + ids.append(self.special_tokens[token]) + else: + ids.append(self.encoder.get(token, 0)) + if len(ids) > self.max_len: + logger.warning( + "Token indices sequence length is longer than the specified maximum " + " sequence length for this OpenAI GPT model ({} > {}). Running this" + " sequence through the model will result in indexing errors".format(len(ids), self.max_len) + ) + return ids + + def convert_ids_to_tokens(self, ids, skip_special_tokens=False): + """Converts a sequence of ids in BPE tokens using the vocab.""" + tokens = [] + for i in ids: + if i in self.special_tokens_decoder: + if not skip_special_tokens: + tokens.append(self.special_tokens_decoder[i]) + else: + tokens.append(self.decoder[i]) + return tokens + + def encode(self, text): + return self.convert_tokens_to_ids(self.tokenize(text)) + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) + return text + + def save_vocabulary(self, vocab_path): + """Save the tokenizer vocabulary and merge files to a directory.""" + if not os.path.isdir(vocab_path): + logger.error("Vocabulary path ({}) should be a directory".format(vocab_path)) + return + vocab_file = os.path.join(vocab_path, VOCAB_NAME) + merge_file = os.path.join(vocab_path, MERGES_NAME) + special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) + + with open(vocab_file, 'w', encoding='utf-8') as f: + f.write(json.dumps(self.encoder, ensure_ascii=False)) + + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write(u'#version: 0.2\n') + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!".format(merge_file)) + index = token_index + writer.write(' '.join(bpe_tokens) + u'\n') + index += 1 + + index = len(self.encoder) + with open(special_tokens_file, 'w', encoding='utf-8') as writer: + for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive." + " Please check that the tokenizer is not corrupted!".format(special_tokens_file)) + index = token_index + writer.write(token + u'\n') + index += 1 + + return vocab_file, merge_file, special_tokens_file diff --git a/tokenization/text/wordpiece.py b/tokenization/text/wordpiece.py new file mode 100644 index 0000000..e5ce2e2 --- /dev/null +++ b/tokenization/text/wordpiece.py @@ -0,0 +1,390 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes. Provided as is from https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/tokenization.py""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import collections +import logging +import os +import unicodedata +from io import open + +from ..file_utils import cached_path + +logger = logging.getLogger(__name__) + +PRETRAINED_VOCAB_ARCHIVE_MAP = { + 'bert-base-uncased': "pretrained/pytorch_pretrained_bert/bert-base-uncased-vocab.txt", + 'bert-large-uncased': "pretrained/pytorch_pretrained_bert/bert-large-uncased-vocab.txt", + 'bert-base-cased': "pretrained/pytorch_pretrained_bert/bert-base-cased-vocab.txt", + 'bert-large-cased': "pretrained/pytorch_pretrained_bert/bert-large-cased-vocab.txt", + 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", + 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", + 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", +} +PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { + 'bert-base-uncased': 512, + 'bert-large-uncased': 512, + 'bert-base-cased': 512, + 'bert-large-cased': 512, + 'bert-base-multilingual-uncased': 512, + 'bert-base-multilingual-cased': 512, + 'bert-base-chinese': 512, +} +VOCAB_NAME = 'vocab.txt' + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + index = 0 + with open(vocab_file, "r", encoding="utf-8") as reader: + while True: + token = reader.readline() + if not token: + break + token = token.strip() + vocab[token] = index + index += 1 + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class BertTokenizer(object): + """Runs end-to-end tokenization: punctuation splitting + wordpiece""" + + def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, + never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): + """Constructs a BertTokenizer. + + Args: + vocab_file: Path to a one-wordpiece-per-line vocabulary file + do_lower_case: Whether to lower case the input + Only has an effect when do_wordpiece_only=False + do_basic_tokenize: Whether to do basic tokenization before wordpiece. + max_len: An artificial maximum length to truncate tokenized sequences to; + Effective maximum length is always the minimum of this + value (if specified) and the underlying BERT model's + sequence length. + never_split: List of tokens which will never be split during tokenization. + Only has an effect when do_wordpiece_only=False + """ + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'. 
To load the vocabulary from a Google pretrained " + "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict( + [(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, + never_split=never_split) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + self.max_len = max_len if max_len is not None else int(1e12) + + def tokenize(self, text): + if self.do_basic_tokenize: + split_tokens = [] + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def convert_tokens_to_ids(self, tokens): + """Converts a sequence of tokens into ids using the vocab.""" + ids = [] + for token in tokens: + ids.append(self.vocab[token]) + if len(ids) > self.max_len: + logger.warning( + "Token indices sequence length is longer than the specified maximum " + " sequence length for this BERT model ({} > {}). Running this" + " sequence through BERT will result in indexing errors".format(len(ids), self.max_len) + ) + return ids + + def convert_ids_to_tokens(self, ids): + """Converts a sequence of ids in wordpiece tokens using the vocab.""" + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): + """ + Instantiate a PreTrainedBertModel from a pre-trained model file. + Download and cache the pre-trained model file if needed. + """ + if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: + vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] + else: + vocab_file = pretrained_model_name_or_path + if os.path.isdir(vocab_file): + vocab_file = os.path.join(vocab_file, VOCAB_NAME) + # redirect to the cache, if necessary + try: + resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) + except EnvironmentError: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find any file " + "associated to this path or url.".format( + pretrained_model_name_or_path, + ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), + vocab_file)) + return None + if resolved_vocab_file == vocab_file: + logger.info("loading vocabulary file {}".format(vocab_file)) + else: + logger.info("loading vocabulary file {} from cache at {}".format( + vocab_file, resolved_vocab_file)) + if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: + # if we're using a pretrained model, ensure the tokenizer wont index sequences longer + # than the number of positional embeddings + max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] + kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) + # Instantiate tokenizer. + tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) + return tokenizer + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, + do_lower_case=True, + never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. 
+ """ + self.do_lower_case = do_lower_case + self.never_split = never_split + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = self._clean_text(text) + # This was added on November 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it doesn't + # matter since the English models were not trained on any Chinese data + # and generally don't have any Chinese data in them (there are Chinese + # characters in the vocabulary because Wikipedia does have some Chinese + # words in the English Wikipedia.). + text = self._tokenize_chinese_chars(text) + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case and token not in self.never_split: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + if text in self.never_split: + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. 
+ if ((cp >= 0x4E00 and cp <= 0x9FFF) or # + (cp >= 0x3400 and cp <= 0x4DBF) or # + (cp >= 0x20000 and cp <= 0x2A6DF) or # + (cp >= 0x2A700 and cp <= 0x2B73F) or # + (cp >= 0x2B740 and cp <= 0x2B81F) or # + (cp >= 0x2B820 and cp <= 0x2CEAF) or + (cp >= 0xF900 and cp <= 0xFAFF) or # + (cp >= 0x2F800 and cp <= 0x2FA1F)): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform tokenization + using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer`. + + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or + (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False -- GitLab
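Note on the WordPiece pieces added in tokenization/text/wordpiece.py: WordpieceTokenizer.tokenize uses the greedy longest-match-first scan described in its docstring. A minimal standalone sketch of that loop is shown below; the toy vocabulary is illustrative only and is not the real BERT vocab, and the helper name is mine, not part of the patch.

def wordpiece_tokenize(token, vocab, unk_token="[UNK]", max_chars=100):
    # Greedy longest-match-first, mirroring WordpieceTokenizer.tokenize.
    chars = list(token)
    if len(chars) > max_chars:
        return [unk_token]
    sub_tokens, start = [], 0
    while start < len(chars):
        end = len(chars)
        cur = None
        while start < end:
            piece = "".join(chars[start:end])
            if start > 0:
                piece = "##" + piece  # continuation pieces carry the '##' prefix
            if piece in vocab:
                cur = piece
                break
            end -= 1
        if cur is None:
            return [unk_token]  # no piece matched; emit UNK for the whole token
        sub_tokens.append(cur)
        start = end
    return sub_tokens

toy_vocab = {"un", "##aff", "##able"}  # assumed toy vocabulary
print(wordpiece_tokenize("unaffable", toy_vocab))  # ['un', '##aff', '##able']

This reproduces the docstring example: "unaffable" splits into ['un', '##aff', '##able'].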
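Note on the byte-level BPE in tokenization/text/tokenization_gpt2.py: the bytes_to_unicode table is what makes GPT2Tokenizer lossless, since every UTF-8 byte is mapped to a printable stand-in character and back. The following self-contained sketch mirrors (rather than imports) that table to show the round trip; it is an illustration, not the patch's code path.

def bytes_to_unicode():
    # Printable bytes map to themselves; the remaining bytes are shifted
    # past 255 so no byte ever maps to whitespace/control characters.
    bs = list(range(ord("!"), ord("~") + 1)) + \
         list(range(ord("¡"), ord("¬") + 1)) + \
         list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return dict(zip(bs, [chr(c) for c in cs]))

byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

text = "GLM café ☕"
encoded = "".join(byte_encoder[b] for b in text.encode("utf-8"))
decoded = bytearray(byte_decoder[c] for c in encoded).decode("utf-8")
assert decoded == text  # the mapping is reversible, so no UNKs are needed

Because every one of the 256 byte values has a stand-in, BPE merges operate on ordinary unicode strings while still being able to reconstruct the exact original bytes on decode.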
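Note on the GLM command tokens: in the RoBERTa branch of GPT2BPETokenizer (the setup used by config/model_glm_roberta_large.sh), pad/eos/sep/ENC/MASK/unk reuse the existing </s>, <pad>, <s>, <mask>, <unk> ids, while the block and task-mask symbols are appended after the base vocabulary. The sketch below assumes a base vocabulary size V (50265 for the released RoBERTa vocab files, used here only as a placeholder); it illustrates the resulting id layout rather than reading it from an actual vocab file.

# Hypothetical id layout for GPT2BPETokenizer('roberta', add_block_symbols=True,
# add_task_mask=True); V is an assumed base vocabulary size.
V = 50265
glm_command_ids = {
    'sop':   V,      # '<|startofpiece|>', added by add_block_symbols
    'eop':   V + 1,  # '<|endofpiece|>'
    'gMASK': V + 2,  # '[gMASK]', added when add_task_mask is set
    'sMASK': V + 3,  # '[sMASK]'
}

Keeping this ordering fixed matters in practice: the extra rows appended to the word embedding of a converted checkpoint must line up with these ids for blank filling and generation to work.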