Commit 049b6ffc authored by duzx16

Add T5 Tokenizer

parent 3fc1af82
@@ -29,7 +29,8 @@ def _export_vocab_size_to_args(args, original_num_tokens):
     print_rank_0('> padded vocab (size: {}) with {} dummy '
                  'tokens (new size: {})'.format(
                      before, after - before, after))
-    args.vocab_size = after
+    if not args.vocab_size:
+        args.vocab_size = after
     print_rank_0("prepare tokenizer done")
     return tokenizer
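
For context: the "padded vocab" message in this hunk reports the raw vocabulary being rounded up with dummy tokens before args.vocab_size is set. A minimal sketch of that rounding, assuming a divisibility requirement (the `multiple` value is illustrative, not taken from this diff):

def pad_vocab_size(original_num_tokens, multiple=128):
    # Round the vocabulary up to the next multiple so the embedding table
    # splits evenly across model-parallel partitions.
    after = original_num_tokens
    while after % multiple != 0:
        after += 1
    return after

# pad_vocab_size(50257) == 50304, i.e. 47 dummy tokens are appended.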
@@ -63,6 +64,10 @@ def get_tokenizer(args=None, outer_tokenizer=None):
     elif args.tokenizer_type == "glm_ChineseSPTokenizer":
         from .glm import ChineseSPTokenizer
         get_tokenizer.tokenizer = ChineseSPTokenizer(args.tokenizer_model_type, **kwargs)
+    elif args.tokenizer_type.startswith('hf'):
+        from .hf_tokenizer import HFT5Tokenizer
+        if args.tokenizer_type == "hf_T5Tokenizer":
+            get_tokenizer.tokenizer = HFT5Tokenizer(args.tokenizer_model_type)
     else:
         assert args.vocab_size > 0
         get_tokenizer.tokenizer = FakeTokenizer(args.vocab_size)
......
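
The dispatch above caches the tokenizer on the function object itself (get_tokenizer.tokenizer), so every later call returns the same instance. A stripped-down sketch of that pattern, with hypothetical names:

def get_singleton(factory=None):
    # The first call stores the instance as an attribute of the function;
    # subsequent calls return the cached instance regardless of arguments.
    if not hasattr(get_singleton, "instance"):
        assert factory is not None, "first call must supply a factory"
        get_singleton.instance = factory()
    return get_singleton.instance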
@@ -312,11 +312,11 @@ class Tokenizer(object):
             tokenization.tokenization = [self.IdToToken(idx) for idx in tokenization.tokenization]
         return tokenization
 
-    def IdToToken(self, Id):
+    def IdToToken(self, idx):
         """convert Id to token accounting for command tokens"""
-        if isinstance(Id, CommandToken):
-            return Id.token
-        return self.tokens[Id]
+        if isinstance(idx, CommandToken):
+            return idx.token
+        return self.tokens[idx]
 
     def TokenToId(self, token):
         """convert token to Id accounting for command tokens"""
@@ -324,16 +324,16 @@ class Tokenizer(object):
             return token.Id
         return self.vocab[token]
 
-    def DecodeIds(self, Ids):
+    def DecodeIds(self, ids):
         """
         convert Ids to tokens accounting for command tokens, tokens
         are joined and returned as a string.
         """
         rtn_strs = []
         current_str = []
-        if isinstance(Ids, Tokenization):
-            Ids = Ids.tokenization
-        for Id in Ids:
+        if isinstance(ids, Tokenization):
+            ids = ids.tokenization
+        for Id in ids:
             if isinstance(Id, CommandToken):
                 rtn_strs.append(self._decode(current_str))
                 current_str = []
@@ -353,11 +353,11 @@ class Tokenizer(object):
         output = self.clean_up_tokenization(output)
         return output
 
-    def DecodeTokens(self, Tokens):
+    def DecodeTokens(self, tokens):
         """
         convert tokens to a string accounting for command and type tokens.
         """
-        Ids = [self.TokenToId(token) for token in Tokens]
+        Ids = [self.TokenToId(token) for token in tokens]
         return self.DecodeIds(Ids)
......
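
The renames above are purely cosmetic (Ids/Tokens become ids/tokens), but the contract of these methods is easy to miss: command tokens are passed through as objects, not looked up in the vocabulary. A hedged usage sketch, assuming a built tokenizer instance `tok` with an 'eos' command token:

eos = tok.get_command('eos')       # a CommandToken, not a plain int
tok.IdToToken(eos)                 # -> eos.token, no vocab lookup
tok.IdToToken(42)                  # -> tok.tokens[42]
tok.DecodeIds([5, 10, eos, 7])     # a CommandToken in the id list closes the
                                   # current segment before decoding continues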
+from transformers import T5Tokenizer
+
+from .glm.tokenization import Tokenization, CommandToken
+
+
+class HFTokenizer:
+    def __init__(self, model_cls, model_type_or_path=None, cache_dir=None, command_tokens=None):
+        self.text_tokenizer = model_cls.from_pretrained(model_type_or_path, cache_dir=cache_dir)
+        self.num_tokens = len(self.text_tokenizer)
+        self._command_tokens = []
+        self.command_name_map = {}
+        self.command_token_map = {}
+        self.command_id_map = {}
+
+    @property
+    def command_tokens(self):
+        return self._command_tokens
+
+    @command_tokens.setter
+    def command_tokens(self, command_tokens):
+        self._command_tokens = command_tokens
+        self.command_name_map = {tok.name: tok for tok in self.command_tokens}
+        self.command_token_map = {tok.token: tok for tok in self.command_tokens}
+        self.command_id_map = {tok.Id: tok for tok in self.command_tokens}
+
+    def get_command(self, name):
+        """get command token corresponding to `name`"""
+        return self.command_name_map[name]
+
+    def EncodeAsIds(self, text, process_fn=None):
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+        ids = self.text_tokenizer.encode(processed_text, add_special_tokens=False)
+        tokenization = Tokenization(ids, processed_text, text)
+        return tokenization
+
+    def DecodeIds(self, ids):
+        if isinstance(ids, Tokenization):
+            ids = ids.tokenization
+        return self.text_tokenizer.decode(ids)
+
+    def DecodeTokens(self, tokens):
+        return self.text_tokenizer.convert_tokens_to_string(tokens)
+
+    def IdToToken(self, Id):
+        if isinstance(Id, CommandToken):
+            return Id.token
+        return self.text_tokenizer.convert_ids_to_tokens(Id)
+
+    def TokenToId(self, token):
+        if isinstance(token, CommandToken):
+            return token.Id
+        return self.text_tokenizer.convert_tokens_to_ids(token)
+
+
+class HFT5Tokenizer(HFTokenizer):
+    def __init__(self, model_type_or_path=None, cache_dir=None):
+        super().__init__(T5Tokenizer, model_type_or_path=model_type_or_path, cache_dir=cache_dir)
+        command_tokens = [
+            CommandToken('eos', '</s>', self.TokenToId("</s>")),
+            CommandToken('pad', '<pad>', self.TokenToId("<pad>")),
+        ]
+        for i in range(100):
+            command_tokens.append(CommandToken(f'MASK{i}', f'<extra_id_{i}>', self.TokenToId(f'<extra_id_{i}>')))
+        self.command_tokens = command_tokens
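
A hedged usage sketch of the new wrapper (downloading t5-large from the Hugging Face hub is assumed to work in your environment):

tokenizer = HFT5Tokenizer("t5-large")
enc = tokenizer.EncodeAsIds("The <extra_id_0> walks in <extra_id_1> park")
ids = enc.tokenization + [tokenizer.get_command("eos").Id]
# With the t5-large vocabulary this yields [37, 32099, 10681, 16, 32098, 2447, 1],
# the same ids that were previously hard-coded in the inference script below.
print(tokenizer.DecodeIds(ids))

Note that the MASK0..MASK99 command tokens map onto T5's <extra_id_0>..<extra_id_99> sentinels, so GLM-style mask commands line up directly with T5's span-corruption format.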
@@ -10,6 +10,6 @@ MODEL_ARGS="--block-lm \
             --max-sequence-length 513 \
             --relative-attention-num-buckets 32 \
             --layernorm-epsilon 1e-6 \
-            --tokenizer-model-type roberta \
-            --tokenizer-type glm_GPT2BPETokenizer \
+            --tokenizer-type hf_T5Tokenizer \
+            --tokenizer-model-type t5-large \
             --load ${CHECKPOINT_PATH}/glm-large-en-blank"
\ No newline at end of file
@@ -48,7 +48,7 @@ def get_masks_and_position_ids_glm(seq, mask_position, context_length):
 def main(args):
     args.do_train = False
     initialize_distributed(args)
-    # tokenizer = get_tokenizer(args)
+    tokenizer = get_tokenizer(args)
     # build model
     model = T5Model(args)
     if args.fp16:
@@ -60,9 +60,16 @@ def main(args):
         torch.load("/dataset/fd5061f6/yanan/huggingface_models/t5-large/model_states.pt")["module"])
     from SwissArmyTransformer.model.encoder_decoder_model import EncoderFinalMixin
     model.eval()
-    input_ids = torch.cuda.LongTensor([[37, 32099, 10681, 16, 32098, 2447, 1]])
-    decoder_input_ids = torch.cuda.LongTensor([[32099, 5295, 1782, 32098, 8, 32097, 1]])
-    output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+    input_ids = tokenizer.EncodeAsIds("The <extra_id_0> walks in <extra_id_1> park").tokenization
+    input_ids = input_ids + [tokenizer.get_command("eos").Id]
+    input_ids = torch.cuda.LongTensor([input_ids])
+    # input_ids = torch.cuda.LongTensor([[37, 32099, 10681, 16, 32098, 2447, 1]])
+    decoder_input_ids = tokenizer.EncodeAsIds('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>').tokenization
+    decoder_input_ids = decoder_input_ids + [tokenizer.get_command("eos").Id]
+    decoder_input_ids = torch.cuda.LongTensor([decoder_input_ids])
+    # decoder_input_ids = torch.cuda.LongTensor([[32099, 5295, 1782, 32098, 8, 32097, 1]])
+    breakpoint()
+    output = model(enc_input_ids=input_ids, dec_input_ids=decoder_input_ids)
     print(output)
     end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id]
     # define function for each query
......
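
For reference, the hard-coded tensors this hunk replaces are exactly what the tokenizer produces on the canonical T5 span-corruption example, which can be cross-checked against transformers directly (t5-large vocabulary assumed):

from transformers import T5Tokenizer

t = T5Tokenizer.from_pretrained("t5-large")
# T5's encode() appends </s> (id 1) by default, matching the manual eos
# append in the new code above.
assert t.encode("The <extra_id_0> walks in <extra_id_1> park") == \
    [37, 32099, 10681, 16, 32098, 2447, 1]
assert t.encode("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>") == \
    [32099, 5295, 1782, 32098, 8, 32097, 1]

One caveat visible in the unchanged context: end_tokens still asks for tokenizer.get_command('eop'), which HFT5Tokenizer does not define (it registers only eos, pad, and MASK0..99), so that line would raise a KeyError when run with the new tokenizer.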