Skip to content
Snippets Groups Projects
Unverified commit 52b96a84, authored by duzx16 and committed by GitHub
Browse files

Merge pull request #3 from THUDM/tokenizer

Refactor tokenization
parents 028cfe61 cb9ad4ce
No related branches found
No related tags found
No related merge requests found
...@@ -15,6 +15,7 @@ import torch ...@@ -15,6 +15,7 @@ import torch
from SwissArmyTransformer.training.utils import print_rank_0 from SwissArmyTransformer.training.utils import print_rank_0
def _export_vocab_size_to_args(args, original_num_tokens): def _export_vocab_size_to_args(args, original_num_tokens):
tokenizer = get_tokenizer(args) tokenizer = get_tokenizer(args)
num_tokens = original_num_tokens num_tokens = original_num_tokens
...@@ -32,6 +33,7 @@ def _export_vocab_size_to_args(args, original_num_tokens): ...@@ -32,6 +33,7 @@ def _export_vocab_size_to_args(args, original_num_tokens):
print_rank_0("prepare tokenizer done") print_rank_0("prepare tokenizer done")
return tokenizer return tokenizer
def get_tokenizer(args=None, outer_tokenizer=None): def get_tokenizer(args=None, outer_tokenizer=None):
''' '''
If you're using outer_tokenizer, call `get_tokenizer(args, outer_tokenizer)` If you're using outer_tokenizer, call `get_tokenizer(args, outer_tokenizer)`
...@@ -53,7 +55,7 @@ def get_tokenizer(args=None, outer_tokenizer=None): ...@@ -53,7 +55,7 @@ def get_tokenizer(args=None, outer_tokenizer=None):
) )
elif args.tokenizer_type.startswith('glm_'): elif args.tokenizer_type.startswith('glm_'):
kwargs = {"add_block_symbols": True, "add_task_mask": args.task_mask, kwargs = {"add_block_symbols": True, "add_task_mask": args.task_mask,
"add_decoder_mask": False} "add_decoder_mask": False}
if args.tokenizer_type == "glm_GPT2BPETokenizer": if args.tokenizer_type == "glm_GPT2BPETokenizer":
from .glm import GPT2BPETokenizer from .glm import GPT2BPETokenizer
get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs) get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs)
......
...@@ -3,6 +3,11 @@ from https://github.com/openai/gpt-2/, changed for chinese ...@@ -3,6 +3,11 @@ from https://github.com/openai/gpt-2/, changed for chinese
""" """
import json import json
import os import os
import csv
import nltk
import random
from nltk import tokenize as nltk_tokenize
import sentencepiece as spm import sentencepiece as spm
""" """
...@@ -22,129 +27,72 @@ python setup.py install ...@@ -22,129 +27,72 @@ python setup.py install
PRETRAINED_MODEL_FILE = os.path.join(os.path.dirname(os.path.dirname(__file__)), PRETRAINED_MODEL_FILE = os.path.join(os.path.dirname(os.path.dirname(__file__)),
'embed_assets', 'chinese_sentencepiece/cog-pretrain.model') 'embed_assets', 'chinese_sentencepiece/cog-pretrain.model')
def get_pairs(word): class SentencePieceTokenizer:
pairs = set() """Trains and uses sentencepiece for text tokenization"""
prev_char = word[0]
for char in word[1:]: def __init__(self, model_path=None, **kwargs):
pairs.add((prev_char, char)) self.spm_model = model_path
prev_char = char self._tokens = []
return pairs self._vocab = {}
self.sp, self.vocab_size = None, 0
self.load_spm_model()
class Encoder:
def __init__(self, encoder, bpe_merges): @classmethod
self.encoder = encoder def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
self.decoder = {v: k for k, v in self.encoder.items()} if pretrained_model_name_or_path in ['glm-large', 'glm-10b']:
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) return cls(model_path=PRETRAINED_MODEL_FILE)
self.cache = {} else:
self.max_len = 0 return cls(model_path=pretrained_model_name_or_path)
def bpe(self, token): def __len__(self):
if token in self.cache: return self.num_text_tokens
return self.cache[token]
word = tuple(token) def load_spm_model(self):
pairs = get_pairs(word) """load sentencepiece model and parse vocab"""
if not pairs: if not os.path.exists(self.spm_model) and not self.spm_model.endswith('.model'):
return token self.spm_model = self.spm_model + '.model'
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
return [self.encoder.get(token, 1) for token in self.tokenize(text)]
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
return text
def tokenize(self, text):
bpe_tokens = []
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(text).split(' '))
return bpe_tokens
def convert_tokens_to_ids(self, tokens):
return [self.encoder.get(token, 1) for token in tokens]
class Encoder_SP:
def __init__(self, model_path):
self.sp = spm.SentencePieceProcessor() self.sp = spm.SentencePieceProcessor()
self.sp.Load(model_path) self.sp.Load(self.spm_model)
self.vocab_size = self.num_text_tokens = len(self.sp)
self._tokens = [self.IdToToken(t) for t in range(self.vocab_size)]
self._vocab = {t: i for i, t in enumerate(self._tokens)}
@property
def tokens(self):
return self._tokens
@property
def vocab(self):
return self._vocab
@staticmethod
def exists(model_path):
if model_path is None:
return False
# check if path exists
dne = not os.path.exists(model_path)
# check if path.model exists
if dne and not model_path.endswith('.model'):
dne = not os.path.exists(model_path + '.model')
return not dne
def encode(self, text): def encode(self, text):
""" """convert text to sentencepiece Ids"""
text="...." tokens = self.sp.EncodeAsIds(text)
""" return tokens
return self.sp.EncodeAsIds(text)
def decode(self, tokens):
"""
tokens=[x1,x2,...]
"""
text = [int(token) for token in tokens]
# print(text)
return self.sp.DecodeIds(text)
def tokenize(self, text):
return self.sp.EncodeAsPieces(text)
def convert_tokens_to_ids(self, tokens):
return [self.sp.PieceToId(token) for token in tokens]
def convert_token_to_id(self, token):
return self.sp.PieceToId(token)
def convert_id_to_token(self, idx): def IdToToken(self, Id):
return self.sp.IdToPiece(idx) """convert Id to sentencepiece token"""
return self.sp.IdToPiece(Id)
def TokenToId(self, token):
"""convert sentencepiece token to Id"""
return self.sp.PieceToId(token)
def get_encoder(encoder_file, bpe_file): def decode(self, Ids):
# 以下是为了同一个函数入兼容sentencepiece """converts ids to a text string"""
filepath, filename = os.path.split(encoder_file) return self.sp.DecodeIds(Ids)
shotname, extension = os.path.splitext(filename)
if (".model" == extension) and (bpe_file == ""):
return Encoder_SP(encoder_file)
else:
with open(encoder_file, 'r', encoding="utf-8") as f:
encoder = json.load(f)
with open(bpe_file, 'r', encoding="utf-8") as f:
bpe_data = f.read()
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
return Encoder(
encoder=encoder,
bpe_merges=bpe_merges,
)
def from_pretrained(): def from_pretrained():
return get_encoder(PRETRAINED_MODEL_FILE, "") return SentencePieceTokenizer(model_path=PRETRAINED_MODEL_FILE)
\ No newline at end of file
This diff is collapsed.
...@@ -168,6 +168,14 @@ class GPT2Tokenizer(object): ...@@ -168,6 +168,14 @@ class GPT2Tokenizer(object):
self.special_tokens_decoder = {} self.special_tokens_decoder = {}
self.set_special_tokens(special_tokens) self.set_special_tokens(special_tokens)
@property
def tokens(self):
return self.decoder
@property
def vocab(self):
return self.encoder
def __len__(self): def __len__(self):
return len(self.encoder) + len(self.special_tokens) return len(self.encoder) + len(self.special_tokens)
...@@ -309,4 +317,4 @@ class GPT2Tokenizer(object): ...@@ -309,4 +317,4 @@ class GPT2Tokenizer(object):
writer.write(token + u'\n') writer.write(token + u'\n')
index += 1 index += 1
return vocab_file, merge_file, special_tokens_file return vocab_file, merge_file, special_tokens_file
\ No newline at end of file
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment