diff --git a/tokenization/__init__.py b/tokenization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6b481dc7ec60f9aed1acae5963726187cce81bb
--- /dev/null
+++ b/tokenization/__init__.py
@@ -0,0 +1,37 @@
+# -*- encoding: utf-8 -*-
+'''
+@File : __init__.py
+@Time : 2021/10/06 17:58:04
+@Author : Ming Ding
+@Contact : dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+
+def get_tokenizer(args=None):
+    if not hasattr(get_tokenizer, 'tokenizer'):
+        # the first time to load the tokenizer; args must be provided
+        assert args is not None, 'the first call to get_tokenizer() must pass args'
+        if args.tokenizer_type == 'cogview':
+            from .cogview import UnifiedTokenizer
+            get_tokenizer.tokenizer = UnifiedTokenizer(
+                args.img_tokenizer_path,
+                device=torch.cuda.current_device()
+            )
+        elif args.tokenizer_type == 'glm':
+            pass # placeholder: the GLM tokenizer is not wired up yet
+        else:
+            assert args.vocab_size > 0
+            get_tokenizer.tokenizer = FakeTokenizer(args.vocab_size)
+    return get_tokenizer.tokenizer
+
+class FakeTokenizer(object):
+    def __init__(self, num_tokens):
+        self.num_tokens = num_tokens
+    def __len__(self):
+        return self.num_tokens
\ No newline at end of file
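get_tokenizer caches the tokenizer as a function attribute, so only the first call needs args; later calls can omit it. A minimal usage sketch (the checkpoint path is hypothetical, and the cogview branch requires a CUDA device):

    import argparse
    from tokenization import get_tokenizer

    args = argparse.Namespace(
        tokenizer_type='cogview',
        img_tokenizer_path='pretrained/vqvae/vqvae.pt',  # hypothetical checkpoint path
    )
    tokenizer = get_tokenizer(args)  # first call: builds and caches the tokenizer
    tokenizer = get_tokenizer()      # later calls: return the cached instance
    print(len(tokenizer))            # total vocabulary: image + text + command tokens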
diff --git a/tokenization/cogview/__init__.py b/tokenization/cogview/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4df817a4fc7d2d474303795a668bf6f7e9c67699
--- /dev/null
+++ b/tokenization/cogview/__init__.py
@@ -0,0 +1,16 @@
+# -*- encoding: utf-8 -*-
+'''
+@File : __init__.py
+@Time : 2021/10/06 18:21:15
+@Author : Ming Ding
+@Contact : dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+
+from .unified_tokenizer import UnifiedTokenizer
+from .templates import *
\ No newline at end of file
diff --git a/tokenization/cogview/sp_tokenizer.py b/tokenization/cogview/sp_tokenizer.py
new file mode 100755
index 0000000000000000000000000000000000000000..01ac73b248d34642c043048476b597a6b26a6a6d
--- /dev/null
+++ b/tokenization/cogview/sp_tokenizer.py
@@ -0,0 +1,150 @@
+"""
+from https://github.com/openai/gpt-2/, changed for Chinese
+"""
+import json
+import os
+import sentencepiece as spm
+
+"""
+SentencePiece is an unsupervised text tokenizer and detokenizer mainly for Neural Network-based text generation
+systems where the vocabulary size is predetermined prior to the neural model training. SentencePiece implements
+subword units (e.g., byte-pair-encoding (BPE) [Sennrich et al.] and unigram language model [Kudo]) with the
+extension of direct training from raw sentences. SentencePiece allows us to make a purely end-to-end
+system that does not depend on language-specific pre/postprocessing.
+
+https://github.com/google/sentencepiece
+
+pip install sentencepiece
+
+or git clone https://github.com/google/sentencepiece.git
+python setup.py install
+
+"""
+PRETRAINED_MODEL_FILE = "pretrained/chinese_sentencepiece/cog-pretrain.model"
+
+
+def get_pairs(word):
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+class Encoder:
+    def __init__(self, encoder, bpe_merges):
+        self.encoder = encoder
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        self.cache = {}
+        self.max_len = 0
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+        if not pairs:
+            return token
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except ValueError: # `first` does not occur in word[i:]
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        return [self.encoder.get(token, 1) for token in self.tokenize(text)]
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        return text
+
+    def tokenize(self, text):
+        bpe_tokens = []
+        bpe_tokens.extend(bpe_token for bpe_token in self.bpe(text).split(' '))
+        return bpe_tokens
+
+    def convert_tokens_to_ids(self, tokens):
+        return [self.encoder.get(token, 1) for token in tokens]
+
+
+class Encoder_SP:
+    def __init__(self, model_path):
+        self.sp = spm.SentencePieceProcessor()
+        self.sp.Load(model_path)
+        self.num_tokens = self.sp.vocab_size()
+
+    def encode(self, text):
+        """
+        text="...."
+        """
+        return self.sp.EncodeAsIds(text)
+
+    def decode(self, tokens):
+        """
+        tokens=[x1, x2, ...]
+        """
+        text = [int(token) for token in tokens]
+        return self.sp.DecodeIds(text)
+
+    def tokenize(self, text):
+        return self.sp.EncodeAsPieces(text)
+
+    def convert_tokens_to_ids(self, tokens):
+        return [self.sp.PieceToId(token) for token in tokens]
+
+    def convert_token_to_id(self, token):
+        return self.sp.PieceToId(token)
+
+    def convert_id_to_token(self, idx):
+        return self.sp.IdToPiece(idx)
+
+
+def get_encoder(encoder_file, bpe_file):
+    # the following lets one entry point stay compatible with both
+    # sentencepiece models and GPT-2-style encoder/bpe file pairs
+    filepath, filename = os.path.split(encoder_file)
+    shortname, extension = os.path.splitext(filename)
+
+    if (".model" == extension) and (bpe_file == ""):
+        return Encoder_SP(encoder_file)
+    else:
+        with open(encoder_file, 'r', encoding="utf-8") as f:
+            encoder = json.load(f)
+        with open(bpe_file, 'r', encoding="utf-8") as f:
+            bpe_data = f.read()
+        bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
+        return Encoder(
+            encoder=encoder,
+            bpe_merges=bpe_merges,
+        )
+
+
+def from_pretrained():
+    return get_encoder(PRETRAINED_MODEL_FILE, "")
\ No newline at end of file
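The text side can be exercised on its own. A minimal sketch, assuming the pretrained cog-pretrain.model listed above is present (these are raw text ids, before UnifiedTokenizer shifts them past the image vocabulary):

    from tokenization.cogview.sp_tokenizer import from_pretrained

    sp = from_pretrained()                  # loads the sentencepiece model via get_encoder
    ids = sp.encode('一只可爱的小猫')         # list of ids in [0, sp.num_tokens)
    pieces = sp.tokenize('一只可爱的小猫')    # the corresponding subword pieces
    print(sp.decode(ids))                   # detokenizes back to text for normalized input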
+ """ + text = [int(token) for token in tokens] + return self.sp.DecodeIds(text) + + def tokenize(self, text): + return self.sp.EncodeAsPieces(text) + + def convert_tokens_to_ids(self, tokens): + return [self.sp.PieceToId(token) for token in tokens] + + def convert_token_to_id(self, token): + return self.sp.PieceToId(token) + + def convert_id_to_token(self, idx): + return self.sp.IdToPiece(idx) + + +def get_encoder(encoder_file, bpe_file): + # 以下是为了åŒä¸€ä¸ªå‡½æ•°å…¥å…¼å®¹sentencepiece + filepath, filename = os.path.split(encoder_file) + shotname, extension = os.path.splitext(filename) + + if (".model" == extension) and (bpe_file == ""): + return Encoder_SP(encoder_file) + else: + with open(encoder_file, 'r', encoding="utf-8") as f: + encoder = json.load(f) + with open(bpe_file, 'r', encoding="utf-8") as f: + bpe_data = f.read() + bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] + return Encoder( + encoder=encoder, + bpe_merges=bpe_merges, + ) + + +def from_pretrained(): + return get_encoder(PRETRAINED_MODEL_FILE, "") \ No newline at end of file diff --git a/tokenization/cogview/templates.py b/tokenization/cogview/templates.py new file mode 100755 index 0000000000000000000000000000000000000000..4d665785dcdb0d011abf490fbb8725f76f6467ac --- /dev/null +++ b/tokenization/cogview/templates.py @@ -0,0 +1,77 @@ +# -*- encoding: utf-8 -*- +''' +@File : templates.py +@Time : 2021/01/11 22:28:57 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +from tqdm import tqdm + +import numpy as np +import torch +import torch.nn.functional as F + +def concat_codes(*codes): + is_numpy = is_tensor = False + for code in codes: + if isinstance(code, np.ndarray): + is_numpy = True + if isinstance(code, torch.Tensor): + is_tensor = True + device = code.device + if is_tensor: + return torch.cat( + [ + torch.tensor(code, device=device) + for code in codes + ] + ) + elif is_numpy: + return np.concatenate( + [ + np.array(code) + for code in codes + ], + axis=0 + ) + else: + ret = [] + for code in codes: + ret = ret + code + return ret + +def TextCodeTemplate(text, code, tokenizer): + if isinstance(text, str): + text_ids = [tokenizer['[ROI1]']] + tokenizer(text) + else: + text_ids = np.concatenate( + ( + np.array([tokenizer['[ROI1]']]), + text, + ), + axis=0 + ) + code = tokenizer.wrap_code(code) + return concat_codes(text_ids, code) + +def Code2CodeTemplate(text, code0, code1, tokenizer): + text_ids = tokenizer.parse_query(text) if isinstance(text, str) else text + code0 = tokenizer.wrap_code(code0) + code1 = tokenizer.wrap_code(code1, idx=2) + return concat_codes(text_ids, code0, code1) + +def PureTextTemplate(text, tokenizer): + return tokenizer(text) + [tokenizer['[SEP]']] + + + + + + + diff --git a/tokenization/cogview/unified_tokenizer.py b/tokenization/cogview/unified_tokenizer.py new file mode 100755 index 0000000000000000000000000000000000000000..bf66d9a7cb4dbb0e43582ea81b341a001d7651f2 --- /dev/null +++ b/tokenization/cogview/unified_tokenizer.py @@ -0,0 +1,193 @@ +# -*- encoding: utf-8 -*- +''' +@File : unified_tokenizer.py +@Time : 2021/01/11 16:36:33 +@Author : Ming Ding +@Contact : dm18@mails.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +from tqdm import tqdm + +import numpy as np +import torch +import torch.nn.functional as F + +from .sp_tokenizer import from_pretrained +from .vqvae_tokenizer import VQVAETokenizer, sqrt_int 
diff --git a/tokenization/cogview/unified_tokenizer.py b/tokenization/cogview/unified_tokenizer.py
new file mode 100755
index 0000000000000000000000000000000000000000..bf66d9a7cb4dbb0e43582ea81b341a001d7651f2
--- /dev/null
+++ b/tokenization/cogview/unified_tokenizer.py
@@ -0,0 +1,193 @@
+# -*- encoding: utf-8 -*-
+'''
+@File : unified_tokenizer.py
+@Time : 2021/01/11 16:36:33
+@Author : Ming Ding
+@Contact : dm18@mails.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+from tqdm import tqdm
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from .sp_tokenizer import from_pretrained
+from .vqvae_tokenizer import VQVAETokenizer, sqrt_int
+
+class UnifiedTokenizer(object):
+    def __init__(self, img_tokenizer_path, device):
+        self.device = device
+        self.img_tokenizer = VQVAETokenizer(model_path=img_tokenizer_path, device=self.device)
+        self.txt_tokenizer = from_pretrained()
+        self.num_tokens = self.img_tokenizer.num_tokens + self.txt_tokenizer.num_tokens
+        self.raw_command_tokens = [
+            ('[PAD]', 0),
+            ('[BOI1]', 1), # Begin
+            ('[BOI2]', 2),
+            ('[BOI3]', 3),
+            ('[EOI1]', 4), # End
+            ('[EOI2]', 5),
+            ('[EOI3]', 6),
+            ('[ROI1]', 7), # Reference
+            ('[ROI2]', 8),
+            ('[ROI3]', 9),
+            ('[SEP]', 10),
+            ('[MASK]', 11),
+            ('[CLS]', 12),
+            ('[ENC]', 13),
+            ('[TINY]', 14), # 8 * 8
+            ('[SMALL]', 15), # 16 * 16
+            ('[BASE]', 16), # 32 * 32
+            ('[BIG]', 17), # 64 * 64
+            ('[POS0]', 18), # 58210
+            ('[POS1]', 19),
+            ('[POS2]', 20),
+            ('[POS3]', 21),
+            ('[POS4]', 22),
+            ('[POS5]', 23),
+            ('[POS6]', 24),
+            ('[POS7]', 25),
+            ('[POS8]', 26)
+            # Please leave the ``size tokens'' at the back of command tokens
+        ]
+        self.command_tokens = {
+            k: v + self.num_tokens
+            for k, v in self.raw_command_tokens
+        }
+        self.num_tokens += len(self.raw_command_tokens)
+
+    def __getitem__(self, command_token):
+        return self.command_tokens[command_token]
+
+    def __len__(self):
+        """total number of tokens"""
+        return self.num_tokens
+
+    def __call__(self, inputs, process_fn=None):
+        """Run preprocessing and encode `inputs` as ids.
+        `inputs` must not contain command tokens."""
+        if isinstance(inputs, torch.Tensor): # image
+            if len(inputs.shape) == 3:
+                inputs = inputs.unsqueeze(0)
+            return self.img_tokenizer.EncodeAsIds(inputs)
+        return self.EncodeAsIds(inputs, process_fn=process_fn)
+
+    def EncodeAsIds(self, text, process_fn=None):
+        processed_text = text
+        if process_fn is not None:
+            processed_text = process_fn(processed_text)
+        ids = self.txt_tokenizer.encode(processed_text)
+        return [x + self.img_tokenizer.num_tokens for x in ids]
+
+    def DecodeIds(self, ids):
+        ret, img_buffer, txt_buffer, ret_imgs = [], [], [], []
+        try:
+            for x in ids:
+                if self.num_tokens - len(self.raw_command_tokens) <= x:
+                    # command tokens
+                    token = self.raw_command_tokens[x - (self.num_tokens - len(self.raw_command_tokens))][0]
+                    if token.startswith('[EOI') and len(img_buffer) > 0:
+                        # dump image
+                        ret_imgs.append(self.img_tokenizer.DecodeIds(img_buffer))
+                        img_buffer = []
+                    if len(txt_buffer) > 0:
+                        # dump text
+                        ret.append(self.txt_tokenizer.decode(txt_buffer))
+                        txt_buffer = []
+                    ret.append(token)
+                elif x < self.img_tokenizer.num_tokens:
+                    img_buffer.append(x)
+                else:
+                    txt_buffer.append(x - self.img_tokenizer.num_tokens)
+
+            if len(img_buffer) > 0:
+                # dump image
+                ret_imgs.append(self.img_tokenizer.DecodeIds(img_buffer))
+                img_buffer = []
+            if len(txt_buffer) > 0:
+                # dump text
+                ret.append(self.txt_tokenizer.decode(txt_buffer))
+                txt_buffer = []
+        except ValueError:
+            print('Value error in tokenization, skipping...')
+        return ret, ret_imgs
+
+    def wrap_code(self, code, idx=1):
+        s = sqrt_int(len(code))
+        prefix = {8:'[TINY]', 16:'[SMALL]', 32:'[BASE]', 64:'[BIG]'}[s]
+        boi = {1:'[BOI1]', 2: '[BOI2]', 3:'[BOI3]'}[idx]
+        eoi = {1:'[EOI1]', 2: '[EOI2]', 3:'[EOI3]'}[idx]
+
+        if isinstance(code, list):
+            return [self.command_tokens[prefix], self.command_tokens[boi]] + \
+                code + [self.command_tokens[eoi]]
+        elif isinstance(code, np.ndarray):
+            return np.concatenate(
+                (
+                    np.array([self.command_tokens[prefix], self.command_tokens[boi]]),
+                    code,
+                    np.array([self.command_tokens[eoi]])
+                ),
+                axis=0
+            )
+        elif isinstance(code, torch.Tensor):
+            return torch.cat(
+                (
+                    torch.tensor([self.command_tokens[prefix], self.command_tokens[boi]], device=code.device),
+                    code,
+                    torch.tensor([self.command_tokens[eoi]], device=code.device) # was np.array, which torch.cat cannot concatenate
+                )
+            )
+        else:
+            raise ValueError('code must be a list, np.ndarray or torch.Tensor, got %s' % type(code))
+
+    def parse_query(self, query, img_size=256):
+        text_buffer = []
+        ret = []
+        for part in query.split(' '):
+            if part in self.command_tokens:
+                if len(text_buffer) > 0:
+                    # dump text ids
+                    ret.extend(self.EncodeAsIds(' '.join(text_buffer)))
+                    text_buffer = []
+                if part == '[MASK]':
+                    ret.append(-1)
+                else:
+                    ret.append(self.command_tokens[part])
+            elif part.startswith('[MASK]*'): # [MASK]*N: expand to N mask positions
+                c = int(part[7:])
+                assert c > 0
+                if len(text_buffer) > 0:
+                    # dump text ids
+                    ret.extend(self.EncodeAsIds(' '.join(text_buffer)))
+                    text_buffer = []
+                ret.extend([-1] * c)
+            elif part.startswith('[Image'): # [ImageN]path: keep the first N codes of the image at path
+                c = part[6:]
+                assert len(c) > 0
+                num_codes, img_path = c.split(']')
+                if num_codes == '':
+                    num_codes = 1024
+                else:
+                    num_codes = int(num_codes)
+
+                raw_img = self.img_tokenizer.read_img(img_path, img_size=img_size)
+                img_codes = self.img_tokenizer.EncodeAsIds(raw_img) # [1, 32*32]
+                img_codes[0, num_codes:] = -1
+                img_codes = img_codes[0].tolist()
+                ret.extend(img_codes)
+            else:
+                text_buffer.append(part)
+
+        if len(text_buffer) > 0:
+            # dump text ids
+            ret.extend(self.EncodeAsIds(' '.join(text_buffer)))
+            text_buffer = []
+        return ret
+
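parse_query turns a space-separated query mixing text, command tokens, and image references into ids, with -1 marking positions to be generated. A rough sketch (the checkpoint and image paths are hypothetical):

    tok = UnifiedTokenizer('pretrained/vqvae/vqvae.pt', device='cuda')
    # text-to-image style query: encode the text, leave 1024 positions to fill
    ids = tok.parse_query('[ROI1] 一只可爱的小猫 [BASE] [BOI1] [MASK]*1024 [EOI1]')
    # image completion: keep the first 512 codes of an existing image, mask the rest
    ids2 = tok.parse_query('[BASE] [BOI1] [Image512]cat.png [EOI1]', img_size=256)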
diff --git a/tokenization/cogview/vqvae_tokenizer.py b/tokenization/cogview/vqvae_tokenizer.py
new file mode 100755
index 0000000000000000000000000000000000000000..56ee251126fdcb901d4b88cea114342b1dfccdb7
--- /dev/null
+++ b/tokenization/cogview/vqvae_tokenizer.py
@@ -0,0 +1,86 @@
+# -*- encoding: utf-8 -*-
+'''
+@File : vqvae_tokenizer.py
+@Time : 2021/01/11 17:57:43
+@Author : Ming Ding
+@Contact : dm18@mails.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+from tqdm import tqdm
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+from vqvae import new_model, img2code, code2img
+from torchvision import transforms
+from PIL import Image
+
+def is_exp2(x):
+    t = math.log2(x)
+    return abs(t - int(t)) < 1e-4
+def sqrt_int(x):
+    r = int(math.sqrt(x) + 1e-4)
+    assert r * r == x
+    return r
+
+class VQVAETokenizer(object):
+    def __init__(self,
+            model_path,
+            device='cuda'
+        ):
+        ckpt = torch.load(model_path, map_location=torch.device(device))
+
+        model = new_model()
+
+        if list(ckpt.keys())[0].startswith('module.'): # strip the 'module.' prefix left by DataParallel
+            ckpt = {k[7:]: v for k, v in ckpt.items()}
+
+        model.load_state_dict(ckpt)
+        model = model.to(device)
+        model.eval()
+
+        self.model = model
+        self.device = device
+        self.image_tokens = model.quantize_t.n_embed
+        self.num_tokens = model.quantize_t.n_embed
+        self.tr_normalize = transforms.Normalize([0.79093, 0.76271, 0.75340], [0.30379, 0.32279, 0.32800])
+
+    def __len__(self):
+        return self.num_tokens
+
+    def EncodeAsIds(self, img, add_normalization=False):
+        assert len(img.shape) == 4 # [b, c, h, w]
+        if add_normalization:
+            img = self.tr_normalize(img)
+        return img2code(self.model, img)
+
+    def DecodeIds(self, code, shape=None):
+        if shape is None:
+            if isinstance(code, list):
+                code = torch.tensor(code, device=self.device)
+            s = sqrt_int(len(code.view(-1)))
+            assert s * s == len(code.view(-1))
+            shape = (1, s, s)
+        code = code.view(shape)
+        out = code2img(self.model, code)
+        return out
+
+    def read_img(self, path, img_size=256):
+        tr = transforms.Compose([
+            transforms.Resize(img_size),
+            transforms.CenterCrop(img_size),
+            transforms.ToTensor(),
+        ])
+        img = tr(Image.open(path))
+        if img.shape[0] == 4: # RGBA: drop the alpha channel
+            img = img[:-1]
+        img = self.tr_normalize(img)
+        img = img.unsqueeze(0).float().to(self.device) # size [1, 3, h, w]
+        return img
\ No newline at end of file
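End to end, the pieces compose as below. A rough sketch, assuming a CUDA device, the hypothetical checkpoint path from the first example, and a local cat.png; the value range of the decoded image depends on code2img:

    from tokenization import get_tokenizer

    tokenizer = get_tokenizer()                         # args already passed on the first call
    img = tokenizer.img_tokenizer.read_img('cat.png')   # [1, 3, 256, 256], normalized
    codes = tokenizer(img)                              # [1, 1024] ids in the image range
    text_ids = tokenizer('一只可爱的小猫')                # text ids, shifted past the image vocabulary
    seq = text_ids + tokenizer.wrap_code(codes[0].tolist())
    texts, imgs = tokenizer.DecodeIds(seq)              # texts: decoded strings + command token names
    print(texts)                                        # imgs[0] is the reconstructed image tensor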