From 142df57c84a5fd1931a75421eb71fbdd9ffcd318 Mon Sep 17 00:00:00 2001 From: Ming Ding <dm_thu@qq.com> Date: Wed, 6 Oct 2021 13:36:00 +0000 Subject: [PATCH] tmp save 4 --- data_utils/__init__.py | 6 +- data_utils/configure_data.py | 70 ++++------- data_utils/datasets.py | 79 +----------- data_utils/samplers.py | 18 +-- data_utils/sp_tokenizer.py | 150 ---------------------- data_utils/templates.py | 83 ------------- data_utils/unified_tokenizer.py | 212 -------------------------------- data_utils/vqvae_tokenizer.py | 86 ------------- mpu/initialize.py | 2 +- pretrain_gpt2.py | 27 +++- training/deepspeed_training.py | 17 +-- 11 files changed, 71 insertions(+), 679 deletions(-) delete mode 100755 data_utils/sp_tokenizer.py delete mode 100755 data_utils/templates.py delete mode 100755 data_utils/unified_tokenizer.py delete mode 100755 data_utils/vqvae_tokenizer.py diff --git a/data_utils/__init__.py b/data_utils/__init__.py index 509841d..394d3ea 100755 --- a/data_utils/__init__.py +++ b/data_utils/__init__.py @@ -8,7 +8,5 @@ # here put the import lib -from .unified_tokenizer import get_tokenizer - -from .templates import * -from .configure_data import make_loaders, detect_new_datasets \ No newline at end of file +from .configure_data import make_loaders +from .datasets import * \ No newline at end of file diff --git a/data_utils/configure_data.py b/data_utils/configure_data.py index 08d6a59..6556673 100755 --- a/data_utils/configure_data.py +++ b/data_utils/configure_data.py @@ -11,16 +11,13 @@ import os import sys import math import random -from tqdm import tqdm import copy import numpy as np import torch -import torch.nn.functional as F from bisect import bisect_right +from functools import partial -from .unified_tokenizer import get_tokenizer -from .datasets import get_dataset_by_type from torch.utils import data from .samplers import DistributedBatchSampler @@ -36,46 +33,37 @@ def make_data_loader(dataset, batch_size, num_iters, args): sampler = torch.utils.data.SequentialSampler(dataset) drop_last = distributed # the GPUs in the same model parallel group receive the same data - if distributed: + if distributed: # TODO reformat this, but it is not urgent gradient_accumulation_steps = getattr(args, 'gradient_accumulation_steps', 1) batch_sampler = DistributedBatchSampler(sampler, - batch_size, - drop_last, - rank, - world_size, - gradient_accumulation_steps=gradient_accumulation_steps) + batch_size, + drop_last, + rank, + world_size, + gradient_accumulation_steps=gradient_accumulation_steps) else: batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last) data_loader = torch.utils.data.DataLoader(dataset, - batch_sampler=batch_sampler, - num_workers=args.num_workers, - pin_memory=True) + batch_sampler=batch_sampler, + num_workers=args.num_workers, + pin_memory=True) return data_loader -def make_dataset(dataset_type, path, split, args, **kwargs): +def make_dataset_full(dataset_type, path, split, args, create_dataset_function, **kwargs): """function to create datasets+tokenizers for common options""" print('make dataset ...', path) if split is None: split = [1.] assert isinstance(path, list) - # TODO other dsclass, e.g. 
odps - # ds = [get_dataset_by_type(dataset_type, p, args) for p in path] - # dataset object can be copied N times + ds = [] for p in path: - d = get_dataset_by_type(dataset_type, p, args) - if p.find('t2i') >= 0: - ds.extend([d] * 4) - print(f'Enlarge {p} 4 times...') - elif p.find('i2t') >= 0: - ds.extend([d] * 2) - print(f'Enlarge {p} 2 times...') - else: - ds.append(d) + d = create_dataset_function(p, args) + ds.append(d) ds = RandomMappingDataset(ConcatDataset(ds)) @@ -84,8 +72,15 @@ def make_dataset(dataset_type, path, split, args, **kwargs): # FIXME this will merge valid set and train set. return ds -def make_loaders(args): - """makes training/val/test""" +def make_loaders(args, create_dataset_function): + """makes training/val/test + Args: + args.train_data, args.valid_data, args.test_data: str. Paths to the dataset. + args.split: str. format: "8,1,1". how to split train_data. + args.dataset_type: use to create the right datasets. + """ + make_dataset = partial(make_dataset_full, + create_dataset_function=create_dataset_function) world_size = torch.distributed.get_world_size( group=mpu.get_data_parallel_group()) @@ -290,22 +285,3 @@ class RandomMappingDataset(data.Dataset): rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)]) index = rng.randint(len(self.wrapped_data)) return self.wrapped_data[index] - -def detect_new_datasets(args): - if args.new_dataset_path is None: - return None - if not os.path.exists(args.new_dataset_path): - print('Warning: new_dataset_path not exists... skip detection.') - return None - current_datasets = [str(os.path.abspath(path)) for path in args.train_data] - - found = [] - for _p in os.listdir(args.new_dataset_path): - p = os.path.join(args.new_dataset_path, _p) - if (str(p).endswith('lmdb') or str(p).endswith('bin')) and not str(os.path.abspath(p)) in current_datasets: - found.append(p) - if len(found) == 0: - return None - else: - args.train_data = args.train_data + found - return make_loaders(args) diff --git a/data_utils/datasets.py b/data_utils/datasets.py index d12d8f7..11836ad 100755 --- a/data_utils/datasets.py +++ b/data_utils/datasets.py @@ -11,26 +11,18 @@ import os import sys import math import random -from tqdm import tqdm import logging import numpy as np -import torch -import torch.nn.functional as F -from torchvision import datasets, transforms import pickle -from collections import namedtuple from torch.utils.data import Dataset -import lmdb -from .unified_tokenizer import get_tokenizer -from .templates import TextCodeTemplate logger = logging.getLogger(__name__) - +import lmdb class LMDBDataset(Dataset): def __init__(self, path, process_fn): self.env = lmdb.open( @@ -55,9 +47,7 @@ class LMDBDataset(Dataset): with self.env.begin(write=False) as txn: key = str(idx).encode('utf-8') - row = pickle.loads(txn.get(key)) - return self.process_fn(row) class BinaryDataset(Dataset): @@ -80,70 +70,3 @@ class BinaryDataset(Dataset): def __getitem__(self, index): return self.process_fn(self.bin[index]) -def get_dataset_by_type(dataset_type, path: str, args, DS_CLASS=LMDBDataset): - kwargs_to_dataset = {} - - tokenizer = get_tokenizer() - if args.layout[-1] > args.max_position_embeddings: - ml = args.layout[-1] - else: - ml = args.max_position_embeddings - - def pad_to_len(ret): - if len(ret) < ml: # pad - return np.concatenate((ret, - np.array([tokenizer['[PAD]']] * (ml - len(ret)))), - axis=0), len(ret) - else: - if len(ret) > ml: - logger.warning('Out of max len, truncated.') - return ret[:ml], ml - - if dataset_type == 
'TokenizedDataset': - # already tokenized when saved - def process_fn(row): - ret, attention_mask_sep = pad_to_len(row.flatten()) - return {'text': ret, - 'loss_mask': np.array([1] * attention_mask_sep + [0] * (len(ret) - attention_mask_sep)) - } - - elif dataset_type == 'TextCodeDataset': - def process_fn(row): - text, code = row[0], row[1].flatten() - ret = TextCodeTemplate(text, code) - ret, attention_mask_sep = pad_to_len(ret) - return {'text': ret, - 'loss_mask': np.array([1] * attention_mask_sep + [0] * (len(ret) - attention_mask_sep)) - } - - elif dataset_type == 'CompactBinaryDataset': - layout = [64, 64+16**2, 64+16**2+32**2, 64+64**2+16**2+32**2] # FIXME - DS_CLASS = BinaryDataset - kwargs_to_dataset['length_per_sample'] = layout[-1] - def process_fn(row): - row = row.astype(np.int64) - - codes = [row[layout[i-1]:layout[i]] for i in range(1, len(layout))] - - text = row[:layout[0]] - text = text[text>0][:layout[0] - 3] # [CLS] [BASE] [ROI1] - n_pad = layout[0]-3-len(text) - parts = [ - np.array([tokenizer['[PAD]']] * n_pad, dtype=np.int64), - TextCodeTemplate(text, codes[1]), # FIXME - *codes[2:] # FIXME - ] - ret = np.concatenate(parts, axis=0) - return {'text': ret, - 'loss_mask': np.array([0] * (n_pad+1) + [1] * (len(ret) - n_pad - 1)) # don't predict [CLS] - } - elif dataset_type == 'BinaryDataset': - DS_CLASS = BinaryDataset - def process_fn(row): - loss_mask = (row >= 0).astype(np.int64) - return {'text': row.astype(np.int64), - 'loss_mask': loss_mask - } - - return DS_CLASS(path, process_fn, **kwargs_to_dataset) - diff --git a/data_utils/samplers.py b/data_utils/samplers.py index fa8474e..90aac94 100755 --- a/data_utils/samplers.py +++ b/data_utils/samplers.py @@ -22,7 +22,7 @@ from torch.utils import data import numpy as np class RandomSampler(data.sampler.Sampler): - r""" + r""" Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler, but this class lets the user set an epoch like DistributedSampler Samples elements randomly. If without replacement, then sample from a shuffled dataset. @@ -139,14 +139,14 @@ class DistributedBatchSampler(data.sampler.BatchSampler): self.sampler.wrap_around -= (self.batch_size) self.wrap_around += (len(batch)) self.wrap_around %= self.batch_size - if isinstance(self.sampler, TransposedSampler): - for i, idx in enumerate(self.data_iterator(self.sampler, wrap_around=True)): - if i == 0: - continue - batch.append(idx) - new_batch_len = len(batch) - if len(batch) == self.batch_size: - break + # if isinstance(self.sampler, TransposedSampler): + # for i, idx in enumerate(self.data_iterator(self.sampler, wrap_around=True)): + # if i == 0: + # continue + # batch.append(idx) + # new_batch_len = len(batch) + # if len(batch) == self.batch_size: + # break yield self._batch(batch) if self.wrap_last: self.sampler.wrap_around += self.batch_size diff --git a/data_utils/sp_tokenizer.py b/data_utils/sp_tokenizer.py deleted file mode 100755 index 01ac73b..0000000 --- a/data_utils/sp_tokenizer.py +++ /dev/null @@ -1,150 +0,0 @@ -""" -from https://github.com/openai/gpt-2/, changed for chinese -""" -import json -import os -import sentencepiece as spm - -""" -SentencePiece is an unsupervised text tokenizer and detokenizer mainly for Neural Network-based text generation -systems where the vocabulary size is predetermined prior to the neural model training. 
SentencePiece implements -subword units (e.g., byte-pair-encoding (BPE) [Sennrich et al.]) and unigram language model [Kudo.]) with the -extension of direct training from raw sentences. SentencePiece allows us to make a purely end-to-end -system that does not depend on language-specific pre/postprocessing. -https://github.com/google/sentencepiece - -pip install sentencepiece - -or git clone https://github.com/google/sentencepiece.git -python setup.py install - -""" -PRETRAINED_MODEL_FILE = "pretrained/chinese_sentencepiece/cog-pretrain.model" - - -def get_pairs(word): - pairs = set() - prev_char = word[0] - for char in word[1:]: - pairs.add((prev_char, char)) - prev_char = char - return pairs - - -class Encoder: - def __init__(self, encoder, bpe_merges): - self.encoder = encoder - self.decoder = {v: k for k, v in self.encoder.items()} - self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) - self.cache = {} - self.max_len = 0 - - def bpe(self, token): - if token in self.cache: - return self.cache[token] - word = tuple(token) - pairs = get_pairs(word) - if not pairs: - return token - - while True: - bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) - if bigram not in self.bpe_ranks: - break - first, second = bigram - new_word = [] - i = 0 - while i < len(word): - try: - j = word.index(first, i) - new_word.extend(word[i:j]) - i = j - except: - new_word.extend(word[i:]) - break - - if word[i] == first and i < len(word) - 1 and word[i + 1] == second: - new_word.append(first + second) - i += 2 - else: - new_word.append(word[i]) - i += 1 - new_word = tuple(new_word) - word = new_word - if len(word) == 1: - break - else: - pairs = get_pairs(word) - word = ' '.join(word) - self.cache[token] = word - return word - - def encode(self, text): - return [self.encoder.get(token, 1) for token in self.tokenize(text)] - - def decode(self, tokens): - text = ''.join([self.decoder[token] for token in tokens]) - return text - - def tokenize(self, text): - bpe_tokens = [] - bpe_tokens.extend(bpe_token for bpe_token in self.bpe(text).split(' ')) - return bpe_tokens - - def convert_tokens_to_ids(self, tokens): - return [self.encoder.get(token, 1) for token in tokens] - - -class Encoder_SP: - def __init__(self, model_path): - self.sp = spm.SentencePieceProcessor() - self.sp.Load(model_path) - self.num_tokens = self.sp.vocab_size() - - def encode(self, text): - """ - text="...." - """ - return self.sp.EncodeAsIds(text) - - def decode(self, tokens): - """ - tokens=[x1,x2,...] 
-        """
-        text = [int(token) for token in tokens]
-        return self.sp.DecodeIds(text)
-
-    def tokenize(self, text):
-        return self.sp.EncodeAsPieces(text)
-
-    def convert_tokens_to_ids(self, tokens):
-        return [self.sp.PieceToId(token) for token in tokens]
-
-    def convert_token_to_id(self, token):
-        return self.sp.PieceToId(token)
-
-    def convert_id_to_token(self, idx):
-        return self.sp.IdToPiece(idx)
-
-
-def get_encoder(encoder_file, bpe_file):
-    # the following keeps a single entry function compatible with sentencepiece
-    filepath, filename = os.path.split(encoder_file)
-    shotname, extension = os.path.splitext(filename)
-
-    if (".model" == extension) and (bpe_file == ""):
-        return Encoder_SP(encoder_file)
-    else:
-        with open(encoder_file, 'r', encoding="utf-8") as f:
-            encoder = json.load(f)
-        with open(bpe_file, 'r', encoding="utf-8") as f:
-            bpe_data = f.read()
-        bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
-        return Encoder(
-            encoder=encoder,
-            bpe_merges=bpe_merges,
-        )
-
-
-def from_pretrained():
-    return get_encoder(PRETRAINED_MODEL_FILE, "")
\ No newline at end of file
diff --git a/data_utils/templates.py b/data_utils/templates.py
deleted file mode 100755
index d5d4f99..0000000
--- a/data_utils/templates.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# -*- encoding: utf-8 -*-
-'''
-@File    :   templates.py
-@Time    :   2021/01/11 22:28:57
-@Author  :   Ming Ding 
-@Contact :   dm18@mails.tsinghua.edu.cn
-'''
-
-# here put the import lib
-import os
-import sys
-import math
-import random
-from tqdm import tqdm
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-
-from .unified_tokenizer import get_tokenizer
-from .vqvae_tokenizer import sqrt_int
-
-def concat_codes(*codes):
-    is_numpy = is_tensor = False
-    for code in codes:
-        if isinstance(code, np.ndarray):
-            is_numpy = True
-        if isinstance(code, torch.Tensor):
-            is_tensor = True
-            device = code.device
-    if is_tensor:
-        return torch.cat(
-            [
-                torch.tensor(code, device=device)
-                for code in codes
-            ]
-        )
-    elif is_numpy:
-        return np.concatenate(
-            [
-                np.array(code)
-                for code in codes
-            ],
-            axis=0
-        )
-    else:
-        ret = []
-        for code in codes:
-            ret = ret + code
-        return ret
-
-def TextCodeTemplate(text, code):
-    tokenizer = get_tokenizer()
-    if isinstance(text, str):
-        text_ids = [tokenizer['[ROI1]']] + tokenizer(text)
-    else:
-        text_ids = np.concatenate(
-            (
-                np.array([tokenizer['[ROI1]']]),
-                text,
-            ),
-            axis=0
-        )
-    code = tokenizer.wrap_code(code)
-    return concat_codes(text_ids, code)
-
-def Code2CodeTemplate(text, code0, code1):
-    tokenizer = get_tokenizer()
-    text_ids = tokenizer.parse_query(text) if isinstance(text, str) else text
-    code0 = tokenizer.wrap_code(code0)
-    code1 = tokenizer.wrap_code(code1, idx=2)
-    return concat_codes(text_ids, code0, code1)
-
-def PureTextTemplate(text):
-    tokenizer = get_tokenizer()
-    return tokenizer(text) + [tokenizer['[SEP]']]
-
-
-
-
-
-
-
diff --git a/data_utils/unified_tokenizer.py b/data_utils/unified_tokenizer.py
deleted file mode 100755
index 1ce520b..0000000
--- a/data_utils/unified_tokenizer.py
+++ /dev/null
@@ -1,212 +0,0 @@
-# -*- encoding: utf-8 -*-
-'''
-@File    :   unified_tokenizer.py
-@Time    :   2021/01/11 16:36:33
-@Author  :   Ming Ding 
-@Contact :   dm18@mails.tsinghua.edu.cn
-'''
-
-# here put the import lib
-import os
-import sys
-import math
-import random
-from tqdm import tqdm
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-
-from .sp_tokenizer import from_pretrained
-from .vqvae_tokenizer import VQVAETokenizer, sqrt_int
-
-class UnifiedTokenizer(object):
-    
def __init__(self, img_tokenizer_path, device, img_tokenizer_num_tokens=None): - self.device = device - if img_tokenizer_path is None and img_tokenizer_num_tokens is not None: - # pretraining but only know the vocab size of VQVAE, which is developing fast - self.img_tokenizer = FakeTokenizer(img_tokenizer_num_tokens) - else: - self.img_tokenizer = VQVAETokenizer(model_path=img_tokenizer_path, device=self.device) - self.txt_tokenizer = from_pretrained() - self.num_tokens = self.img_tokenizer.num_tokens + self.txt_tokenizer.num_tokens - self.raw_command_tokens = [ - ('[PAD]', 0), - ('[BOI1]', 1), # Begin - ('[BOI2]', 2), - ('[BOI3]', 3), - ('[EOI1]', 4), # End - ('[EOI2]', 5), - ('[EOI3]', 6), - ('[ROI1]', 7), # Reference - ('[ROI2]', 8), - ('[ROI3]', 9), - ('[SEP]', 10), - ('[MASK]', 11), - ('[CLS]', 12), - ('[ENC]', 13), - ('[TINY]', 14), # 8 * 8 - ('[SMALL]', 15), # 16 * 16 - ('[BASE]', 16), # 32 * 32 - ('[BIG]', 17), # 64 * 64 - ('[POS0]', 18), # 58210 - ('[POS1]', 19), - ('[POS2]', 20), - ('[POS3]', 21), - ('[POS4]', 22), - ('[POS5]', 23), - ('[POS6]', 24), - ('[POS7]', 25), - ('[POS8]', 26) - # Please leave the ``size tokens'' at the back of command tokens - ] - self.command_tokens = { - k: v + self.num_tokens - for k, v in self.raw_command_tokens - } - self.num_tokens += len(self.raw_command_tokens) - - def __getitem__(self, command_token): - return self.command_tokens[command_token] - - def __len__(self): - """total number of tokens""" - return self.num_tokens - - def __call__(self, inputs, process_fn=None): - """run preprocessing and encode inputs as Ids - CANNOT contain command tokens""" - if isinstance(inputs, torch.Tensor): # image - if len(inputs.shape) == 3: - inputs = inputs.unsqueeze(0) - return self.img_tokenizer.EncodeAsIds(inputs) - return self.EncodeAsIds(inputs, process_fn=process_fn) - - def EncodeAsIds(self, text, process_fn=None): - processed_text = text - if process_fn is not None: - processed_text = process_fn(processed_text) - ids = self.txt_tokenizer.encode(processed_text) - return [x + self.img_tokenizer.num_tokens for x in ids] - - def DecodeIds(self, ids): - ret, img_buffer, txt_buffer, ret_imgs = [], [], [], [] - try: - for x in ids: - if self.num_tokens - len(self.raw_command_tokens) <= x: - # command tokens - token = self.raw_command_tokens[x - (self.num_tokens - len(self.raw_command_tokens))][0] - if token.startswith('[EOI') and len(img_buffer) > 0: - # dump image - ret_imgs.append(self.img_tokenizer.DecodeIds(img_buffer)) - img_buffer = [] - if len(txt_buffer) > 0: - # dump text - ret.append(self.txt_tokenizer.decode(txt_buffer)) - txt_buffer = [] - ret.append(token) - elif x < self.img_tokenizer.num_tokens: - img_buffer.append(x) - else: - txt_buffer.append(x - self.img_tokenizer.num_tokens) - - if len(img_buffer) > 0: - # dump image - ret_imgs.append(self.img_tokenizer.DecodeIds(img_buffer)) - img_buffer = [] - if len(txt_buffer) > 0: - # dump text - ret.append(self.txt_tokenizer.decode(txt_buffer)) - txt_buffer = [] - except ValueError: - print('Value error in tokenization, skipping...') - return ret, ret_imgs - - def wrap_code(self, code, idx=1): - s = sqrt_int(len(code)) - prefix = {8:'[TINY]', 16:'[SMALL]', 32:'[BASE]', 64:'[BIG]'}[s] - boi = {1:'[BOI1]', 2: '[BOI2]', 3:'[BOI3]'}[idx] - eoi = {1:'[EOI1]', 2: '[EOI2]', 3:'[EOI3]'}[idx] - - if isinstance(code, list): - return [self.command_tokens[prefix], self.command_tokens[boi]] + \ - code + [self.command_tokens[eoi]] - elif isinstance(code, np.ndarray): - return np.concatenate( - ( - 
np.array([self.command_tokens[prefix], self.command_tokens[boi]]), - code, - np.array([self.command_tokens[eoi]]) - ), - axis=0 - ) - elif isinstance(code, torch.Tensor): - return torch.cat( - ( - torch.tensor([self.command_tokens[prefix], self.command_tokens[boi]]), - code, - np.array([self.command_tokens[eoi]]) - ) - ) - else: - raise ValueError('') - - def parse_query(self, query, img_size=256): - text_buffer = [] - ret = [] - for part in query.split(' '): - if part in self.command_tokens: - if len(text_buffer) > 0: - # dump text ids - ret.extend(self.EncodeAsIds(' '.join(text_buffer))) - text_buffer = [] - if part == '[MASK]': - ret.append(-1) - else: - ret.append(self.command_tokens[part]) - elif part.startswith('[MASK]*'): # special lang *N - c = int(part[7:]) - assert c > 0 - if len(text_buffer) > 0: - # dump text ids - ret.extend(self.EncodeAsIds(' '.join(text_buffer))) - text_buffer = [] - ret.extend([-1] * c) - elif part.startswith('[Image'): # [Image*N]path - c = part[6:] - assert len(c) > 0 - num_codes, img_path = c.split(']') - if num_codes == '': - num_codes = 1024 - else: - num_codes = int(num_codes) - - raw_img = self.img_tokenizer.read_img(img_path, img_size=img_size) - img_codes = self.img_tokenizer.EncodeAsIds(raw_img) # [1, 32*32] - img_codes[0, num_codes:] = -1 - img_codes = img_codes[0].tolist() - ret.extend(img_codes) - else: - text_buffer.append(part) - - if len(text_buffer) > 0: - # dump text ids - ret.extend(self.EncodeAsIds(' '.join(text_buffer))) - text_buffer = [] - return ret - -def get_tokenizer(args=None): - if not hasattr(get_tokenizer, 'tokenizer'): - # the first time to load the tokenizer, specify img_tokenizer_path - get_tokenizer.tokenizer = UnifiedTokenizer( - args.img_tokenizer_path, - device=torch.cuda.current_device(), - img_tokenizer_num_tokens=args.img_tokenizer_num_tokens - ) - return get_tokenizer.tokenizer - -class FakeTokenizer(object): - def __init__(self, num_tokens): - self.num_tokens = num_tokens - def __len__(self): - return self.num_tokens \ No newline at end of file diff --git a/data_utils/vqvae_tokenizer.py b/data_utils/vqvae_tokenizer.py deleted file mode 100755 index 56ee251..0000000 --- a/data_utils/vqvae_tokenizer.py +++ /dev/null @@ -1,86 +0,0 @@ -# -*- encoding: utf-8 -*- -''' -@File : vqvae_tokenizer.py -@Time : 2021/01/11 17:57:43 -@Author : Ming Ding -@Contact : dm18@mails.tsinghua.edu.cn -''' - -# here put the import lib -import os -import sys -import math -import random -from tqdm import tqdm - -import numpy as np -import torch -import torch.nn.functional as F - - -from vqvae import new_model, img2code, code2img -from torchvision import transforms -from PIL import Image - -def is_exp2(x): - t = math.log2(x) - return abs(t - int(t)) < 1e-4 -def sqrt_int(x): - r = int(math.sqrt(x) + 1e-4) - assert r * r == x - return r - -class VQVAETokenizer(object): - def __init__(self, - model_path, - device='cuda' - ): - ckpt = torch.load(model_path, map_location=torch.device(device)) - - model = new_model() - - if list(ckpt.keys())[0].startswith('module.'): - ckpt = {k[7:]: v for k, v in ckpt.items()} - - model.load_state_dict(ckpt) - model = model.to(device) - model.eval() - - self.model = model - self.device = device - self.image_tokens = model.quantize_t.n_embed - self.num_tokens = model.quantize_t.n_embed - self.tr_normalize = transforms.Normalize([0.79093, 0.76271, 0.75340], [0.30379, 0.32279, 0.32800]) - - def __len__(self): - return self.num_tokens - - def EncodeAsIds(self, img, add_normalization=False): - assert len(img.shape) == 
4 # [b, c, h, w]
-        if add_normalization:
-            img = self.tr_normalize(img)
-        return img2code(self.model, img)
-
-    def DecodeIds(self, code, shape=None):
-        if shape is None:
-            if isinstance(code, list):
-                code = torch.tensor(code, device=self.device)
-            s = sqrt_int(len(code.view(-1)))
-            assert s * s == len(code.view(-1))
-            shape = (1, s, s)
-        code = code.view(shape)
-        out = code2img(self.model, code)
-        return out
-
-    def read_img(self, path, img_size=256):
-        tr = transforms.Compose([
-            transforms.Resize(img_size),
-            transforms.CenterCrop(img_size),
-            transforms.ToTensor(),
-        ])
-        img = tr(Image.open(path))
-        if img.shape[0] == 4:
-            img = img[:-1]
-        img = self.tr_normalize(img)
-        img = img.unsqueeze(0).float().to(self.device) # size [1, 3, h, w]
-        return img
\ No newline at end of file
diff --git a/mpu/initialize.py b/mpu/initialize.py
index 0a3e15a..3597c2c 100755
--- a/mpu/initialize.py
+++ b/mpu/initialize.py
@@ -36,7 +36,7 @@ def initialize_model_parallel(model_parallel_size_):
     Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
     use 2 GPUs to parallelize the model. The present function will
-    create 4 model parallel groups and 2 data parallel grous as:
+    create 4 model parallel groups and 2 data parallel groups as:
         4 model parallel groups:
             [g0, g1], [g2, g3], [g4, g5], [g6, g7]
         2 data parallel groups:
diff --git a/pretrain_gpt2.py b/pretrain_gpt2.py
index 3b8ccc4..a8a0ea2 100755
--- a/pretrain_gpt2.py
+++ b/pretrain_gpt2.py
@@ -12,11 +12,15 @@ import sys
 import math
 import random
 import torch
+import numpy as np
 
 import mpu
 from arguments import get_args
 from model.base_model import BaseModel
 from training.deepspeed_training import main
+from data_utils import BinaryDataset
+from tokenization import get_tokenizer
+from tokenization.cogview import TextCodeTemplate
 
 def get_masks_and_position_ids(data,
                             loss_mask=None,
@@ -99,6 +103,27 @@ def forward_step(data_iterator, model, args, timers):
 
     return loss, {}
 
+def create_dataset_function(path, args):
+    tokenizer = get_tokenizer()
+    layout = [64, 64+16**2, 64+16**2+32**2, 64+64**2+16**2+32**2] # FIXME
+    def process_fn(row):
+        row = row.astype(np.int64)
+        codes = [row[layout[i-1]:layout[i]] for i in range(1, len(layout))]
+
+        text = row[:layout[0]]
+        text = text[text>0][:layout[0] - 3] # [CLS] [BASE] [ROI1]
+        merged = TextCodeTemplate(text, codes[1], tokenizer)
+        n_pad = args.max_sequence_length - len(merged)
+        parts = [
+            merged,
+            np.array([tokenizer['[PAD]']] * n_pad, dtype=np.int64)
+        ]
+        ret = np.concatenate(parts, axis=0)
+        return {'text': ret,
+            'loss_mask': np.array([1]*len(merged) + [0]*n_pad)
+            }
+    return BinaryDataset(path, process_fn, length_per_sample=layout[-1])
+
 if __name__ == '__main__':
     args = get_args()
-    main(args, model_cls=BaseModel, forward_step=forward_step)
+    main(args, model_cls=BaseModel, forward_step_function=forward_step, create_dataset_function=create_dataset_function)
diff --git a/training/deepspeed_training.py b/training/deepspeed_training.py
index 95d8ac4..5519003 100644
--- a/training/deepspeed_training.py
+++ b/training/deepspeed_training.py
@@ -40,11 +40,12 @@ from data_utils import make_loaders, get_tokenizer
 
 
-def main(args, model_cls, forward_step, init_step=None):
+def main(args, model_cls, forward_step_function, create_dataset_function, init_function=None):
     """Main training program."""
     hooks = {
-        'forward_step': forward_step,
-        'init_step': init_step
+        'forward_step': forward_step_function,
+        'init_function': init_function,
+        'create_dataset_function': create_dataset_function
     }
     torch.backends.cuda.matmul.allow_tf32 = False
@@ -63,7 +64,7 @@ def main(args, model_cls, forward_step, init_step=None):
     # init tokenizer
     tokenizer = get_tokenizer(args)
     # Data stuff.
-    train_data, val_data, test_data, args.vocab_size = get_train_val_test_data(args)
+    train_data, val_data, test_data, args.vocab_size = get_train_val_test_data(args, hooks['create_dataset_function'])
 
     # Model, optimizer, and learning rate.
     model, optimizer = setup_model_and_optimizer(args, model_cls)
@@ -109,8 +110,8 @@ def main(args, model_cls, forward_step, init_step=None):
         val_data_iterator = None
 
     # init hook before training
-    if hooks['init_func'] is not None:
-        hooks['init_func'](args, model, optimizer)
+    if hooks['init_function'] is not None:
+        hooks['init_function'](args, model, optimizer)
 
     # training
     iteration = 0
@@ -525,14 +526,14 @@ def set_random_seed(seed):
         mpu.model_parallel_cuda_manual_seed(seed)
 
 
-def get_train_val_test_data(args):
+def get_train_val_test_data(args, create_dataset_function):
     """Load the data on rank zero and broadcast number of tokens to
     all GPUs."""
 
     (train_data, val_data, test_data) = (None, None, None)
 
     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_model_parallel_rank() == 0:
-        train_data, val_data, test_data = make_loaders(args)
+        train_data, val_data, test_data = make_loaders(args, create_dataset_function)
 
         num_tokens = get_tokenizer().num_tokens
         before = num_tokens
-- 
GitLab
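
Usage sketch for the refactored data pipeline above. This is not part of the patch: the file name `my_train.py`, the LMDB contents, and the stub `forward_step` are hypothetical, while `main`, `LMDBDataset`, `BaseModel`, and `get_args` appear in this repository. The contract introduced here is that `main` receives `create_dataset_function(path, args)`, which `make_loaders`/`make_dataset_full` call once per entry of `args.train_data`, and which must return a torch Dataset yielding `{'text', 'loss_mask'}` dicts.

    # my_train.py -- minimal sketch, assuming each LMDB row was saved as a
    # pickled, already-tokenized integer array (hypothetical storage format).
    import numpy as np

    from arguments import get_args
    from model.base_model import BaseModel
    from training.deepspeed_training import main
    from data_utils import LMDBDataset

    def forward_step(data_iterator, model, args, timers):
        # hypothetical stub; a real implementation computes and returns
        # (loss, metrics_dict), as forward_step in pretrain_gpt2.py does
        raise NotImplementedError

    def create_dataset_function(path, args):
        # `path` is one entry of args.train_data; return any torch Dataset
        def process_fn(row):
            ids = np.asarray(row, dtype=np.int64)
            # mirror the deleted 'BinaryDataset' branch: positions marked
            # with negative ids are excluded from the loss
            return {'text': ids, 'loss_mask': (ids >= 0).astype(np.int64)}
        return LMDBDataset(path, process_fn)

    if __name__ == '__main__':
        args = get_args()
        main(args, model_cls=BaseModel,
             forward_step_function=forward_step,
             create_dataset_function=create_dataset_function)

As in pretrain_gpt2.py, any Dataset works here; make_loaders still handles splitting per args.split, shuffling via RandomMappingDataset, and distributed batch sampling.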