diff --git a/arguments.py b/arguments.py
index 7344533a58f37189ec2d76ec5fbd6e8ee1f93876..f87d809bd6229201ae25f97c82b76a9a604f7821 100755
--- a/arguments.py
+++ b/arguments.py
@@ -157,27 +157,15 @@ def add_text_generate_args(parser):
     group.add_argument("--temperature", type=float, default=1.0)
     group.add_argument("--top_p", type=float, default=0.0)
     group.add_argument("--top_k", type=int, default=0)
-    # group.add_argument("--out-seq-length", type=int, default=256)
-    group.add_argument("--generation-task", type=str,
-                       default='text2image',
-                       choices=['text2image',
-                                'image2text',
-                                'super-resolution',
-                                'low-level super-resolution',
-                                'post-selection',
-                                'raw',
-                                'cuda-2d generation'
-                                ],
-                       help='what type of inference task to use')
+    group.add_argument("--out-seq-length", type=int, default=256)
     group.add_argument('--input-source', type=str, default='interactive',
                        help='what input mode to use, interactive or path')
     group.add_argument('--output-path', type=str, default='./samples',
                        help='path to place the generated samples')
-    group.add_argument('--debug', action='store_true',
-                       help='Debug will merge all outputs.')
     group.add_argument('--with-id', action='store_true',
                        help='If each line is prepended with an id.')
     group.add_argument('--max-inference-batch-size', type=int, default=12)
+    group.add_argument('--device', type=int, default=0)
     return parser
@@ -218,7 +206,6 @@ def add_generation_api_args(parser):
     group.add_argument('--input_rec_path', default='input/')
     group.add_argument('--check_mode', default='code')
     group.add_argument('--time_interval', default=10)
-    group.add_argument('--device', default=None)
     return parser
diff --git a/generation/__init__.py b/generation/__init__.py
index 90c43c6e006fbf1ad1e6228ef38df90e7bf7e525..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100755
--- a/generation/__init__.py
+++ b/generation/__init__.py
@@ -1,3 +0,0 @@
-from .sampling import get_batch, filling_sequence, add_interlacing_beam_marks, inverse_prompt_score
-from .magnify import magnify
-from .cuda_2d_sampling import filling_sequence_cuda_2d
\ No newline at end of file
diff --git a/generation/autoregressive_sampling.py b/generation/autoregressive_sampling.py
index f1a047ccbea954fb3bad78be2e8f8c0889f470fa..b29964f0a2fbd77945347a26a3cbfead95de7a8c 100644
--- a/generation/autoregressive_sampling.py
+++ b/generation/autoregressive_sampling.py
@@ -37,6 +37,8 @@ def update_mems(hiddens, mems, max_memory_length):
         if new_memory_length <= query_length:
             new_mems.append(hiddens[i][:, -new_memory_length:])
         else:
+            if mems[i].shape[0] < hiddens[i].shape[0]:
+                mems[i] = mems[i].expand(hiddens[i].shape[0], *mems[i].shape[1:])
             new_mems.append(torch.cat((mems[i][:, -new_memory_length+query_length:], hiddens[i]), dim=1))
     return new_mems
@@ -45,8 +47,9 @@
 def filling_sequence(
         model,
         seq,
         batch_size,
+        strategy=BaseStrategy(),
         max_memory_length=100000,
-        strategy=BaseStrategy()
+        log_attention_weights=None
         ):
     '''
         seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
@@ -60,12 +63,12 @@ def filling_sequence(
     assert context_length > 0
     tokens, attention_mask, position_ids = get_masks_and_position_ids(seq)
     tokens = tokens[..., :context_length]
-
+    attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
     # initialize generation
     counter = context_length - 1 # Last fixed index is ``counter''
     index = 0 # Next forward starting index, also the length of cache.
     mems = [] # mems are the first-level citizens here, but we don't assume what is memorized.
-
+
     # step-by-step generation
     while counter < len(seq) - 1:
         # Now, we want to generate seq[counter + 1],
@@ -82,18 +85,20 @@ def filling_sequence(
             continue
         # forward
+        model.log_attention_weights = log_attention_weights[..., index: counter+1, :counter+1] # TODO memlen
         logits, *mem_kv = model(
             tokens[:, index:],
             position_ids[..., index: counter+1],
-            attention_mask[..., index: counter+1, :counter+1], # TODO mem
+            attention_mask[..., index: counter+1, :counter+1], # TODO memlen
             *mems
         )
         mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
         counter += 1
         index = counter
-
         # sampling
         logits = logits[:, -1].expand(batch_size, -1) # [batch size, vocab size]
+        tokens = tokens.expand(batch_size, -1)
         tokens, mems = strategy.forward(logits, tokens, mems)
+    model.log_attention_weights = None
     return tokens
\ No newline at end of file
diff --git a/generation/cuda2d_sampling.py b/generation/cuda2d_sampling.py
index bd755524ebcde00180058d111bfff13b448a9f21..20ffa38242c66421d28b8cdb84abb53373d51ec9 100644
--- a/generation/cuda2d_sampling.py
+++ b/generation/cuda2d_sampling.py
@@ -12,7 +12,7 @@ import sys
 import math
 import random
 import torch
-from .sampling_strategies import BaseStrategy
+from .sampling_strategies import IterativeEntfilterStrategy

 def filling_sequence(
         model,
@@ -20,11 +20,15 @@ def filling_sequence(
         seq1,
         warmup_steps=3,
         block_hw=(4, 4),
-        strategy=BaseStrategy(topk=10)
+        strategy=IterativeEntfilterStrategy(topk=10)
         ):
     '''
         seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1]
-        4095 {layout[2]} final_token
+        4095 {layout[2]} final_token.
+
+        Attention:
+        The sampling temperatures change over the steps; temporarily we hard-code them here.
+        The temperature in the strategy is not used.
     '''
     assert hasattr(model, 'layout')
     layout = model.layout
@@ -46,7 +50,7 @@ def filling_sequence(
     assert seq.shape[1] == layout[-1] + 1

     # build initial tokens, attention_mask, and position_ids
-    tokens = seq[:, :-1].clone()
+    tokens = seq.clone()
     attention_mask = torch.ones(layout[1], layout[1]).tril().to(device)
     attention_mask[n_pad:, :n_pad] = 0
     position_ids = torch.cat((
@@ -54,104 +58,32 @@ def filling_sequence(
         torch.arange(0, layout[1] - n_pad),
         torch.arange(0, layout[2]-layout[1]))).to(device)

-    # iterative refining
+    # prepare for iteration
+    unfixed = (tokens < 0)
+    unfixed[:, -4096] = True
     ll, rr = block_hw
+    edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4)
     num_steps = warmup_steps + ll + rr - 2
-    for step_cnt in range(num_steps):
-        logits, *_dump = model(tokens, position_ids, attention_mask)
-
-        # warmup
-        real_topk = 10
-        warmup_steps = 3
-        iterative_step= warmup_steps + 6
+    # iterative refining
+    for step_cnt in range(1, num_steps+1):
+        logits, *_dump = model(tokens[:,:-1], position_ids, attention_mask)
         if step_cnt <= warmup_steps:
             real_temp = 0.1
-        elif step_cnt == warmup_steps + 1:
-            real_temp = 0.55
-        elif step_cnt > warmup_steps + 1:
-            real_temp = 0.45
-        # if 5 < step_cnt:
-        #     real_topk = 200
-        # sampling
-        for invalid_slice in invalid_slices: # forbide to generate other tokens
-            logits[..., invalid_slice] = -float('Inf')
-        assert args.top_k > 0
-
-        # probs0 = F.softmax(logits/real_temp, dim=-1)
-        topraw = (torch.topk(logits, 5, dim=-1)[0]).softmax(dim=-1)
-        ent = -(topraw * topraw.log()).sum(dim=-1)
-        # topsum = topraw.sum(dim=-1)
-        if step_cnt > warmup_steps:
-            # import pdb;pdb.set_trace()
-            real_temp2 = torch.tensor([[[real_temp]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > 1.3).unsqueeze(-1) + 0.6
-            # import pdb;pdb.set_trace()
+            tokens = strategy.forward(logits, tokens, real_temp)
         else:
-            real_temp2 = real_temp
-        # import pdb;pdb.set_trace()
-        probs = F.softmax(logits/real_temp2, dim=-1)
-        tk_value, tk_idx = torch.topk(probs, real_topk, dim=-1)
-        prev = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
-        edge_idx = tk_idx[:, :, -1:]
-        edge_value = tk_value[:, :, -1:]
-        edge_mask = probs.gather(dim=-1, index=prev) < edge_value
-        prev[edge_mask] = edge_idx[edge_mask]
-        prev.squeeze_(-1)
-        # tk_probs = (tk_value / real_temp).softmax(dim=-1).view(-1, tk_value.shape[-1])
-        # prev = torch.multinomial(tk_probs, num_samples=1).view(*(tk_value.shape[:2]),1)
-        # prev = torch.gather(tk_idx, dim=-1, index=prev).squeeze(-1)
-        # update unfixed
-        choice = 1
-        if choice == 0 and warmup_steps < step_cnt:
-            mprob = probs.max(dim=-1)[0].view(*(tk_value.shape[:2]))
-            # import pdb;pdb.set_trace()
-            dprob = mprob[:, 1:] < mprob[:, args.layout[1]:].topk(300, dim=-1, largest=False)[0][:,-1].unsqueeze(-1).expand_as(mprob[:, 1:])
-
-            new_fixed = unfixed.clone()
-            moved_new_fixed = new_fixed[:, 2:]
-            moved_new_fixed &= dprob
-            moved_new_fixed[:, 1:] &= dprob[:, :-1].logical_not() | unfixed[:, 2:-1].logical_not()
-            moved_new_fixed[:, 2:] &= dprob[:, :-2].logical_not() | unfixed[:, 2:-2].logical_not()
-            # moved_new_fixed[:, 3:] &= dprob[:, :-3].logical_not() | unfixed[:, 2:-3].logical_not()
-            moved_new_fixed[:, 64:] &= dprob[:, :-64].logical_not() | unfixed[:, 2:-64].logical_not()
-            moved_new_fixed[:, 65:] &= dprob[:, :-65].logical_not() | unfixed[:, 2:-65].logical_not()
-            # moved_new_fixed[:, 66:] &= dprob[:, :-66].logical_not() | unfixed[:, 2:-66].logical_not()
-        elif choice == 1 and warmup_steps < step_cnt:
-            new_fixed = unfixed & False
-            ll, rr = 4, 4
+            real_temp = 1.05
+            new_tokens = strategy.forward(
+                logits, tokens, real_temp,
+                entfilter=1.3,
+                filter_topk=5,
+                temperature2=0.6
+            )
+            tokens[unfixed] = new_tokens[unfixed]
+        # fixed tokens (update unfixed)
         for x in range(min(ll, step_cnt - warmup_steps)):
             y = step_cnt - warmup_steps - x - 1
             if y < rr:
                 print(x,y)
-                new_fixed[..., -4096:].view(batch_size, 64//ll, ll, 64//rr, rr)[:, :, x, :, y] = True
-            new_fixed &= unfixed
-        else:
-            new_fixed = unfixed & False # TODO
-            new_fixed[:, -1] = True
-
-        # with open(f'bed{step_cnt}.txt', 'w') as fout:
-        #     for i, prob in enumerate(topraw[0, -4096:]):
-        #         s = ' '.join([str(x) for x in prob.tolist()])
-        #         fout.write(f'{i} {s}\n')
-
-        unfixed &= new_fixed.logical_not()
-        # update seq and tokens
-        seq[new_fixed] = prev[new_fixed[:, 1:]]
-        tokens = seq[:, :-1].clone()
-        tokens[:,1:][unfixed[:, 1:-1]] = prev[:, :-1][unfixed[:, 1:-1]]
-
-        if step_cnt == iterative_step:
-            seq[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step
-            n_unfixed = unfixed.sum(dim=-1).tolist()
-            print(f'Exit with {n_unfixed} unfixed tokens.')
-            break
-        if args.debug:
-            from torchvision.utils import save_image
-            seqt = seq.clone()
-            seqt[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step
-            imgs.extend([tokenizer.img_tokenizer.DecodeIds(s[-4096:]) for s in seqt])
-    if args.debug:
-        imgs = torch.cat(imgs, dim=0)
-        save_image(imgs, f'steps{device}.jpg', normalize=True)
-    model.module.transformer.max_memory_length = args.max_memory_length
-
-    return seq
\ No newline at end of file
+                unfixed[..., -(layout[-1] - layout[-2]):].view(
+                    batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, x, :, y] = False
+    return tokens
\ No newline at end of file
diff --git a/generation/cuda_2d_sampling.py b/generation/cuda_2d_sampling.py
index 21e8dcc2707ea067a6e99b72e988347f15ad8395..7cd04402897e113fcc4931aebb006e1c21fb7294 100644
--- a/generation/cuda_2d_sampling.py
+++ b/generation/cuda_2d_sampling.py
@@ -116,19 +116,20 @@ def filling_sequence_cuda_2d(
         # update unfixed
         choice = 1
         if choice == 0 and warmup_steps < step_cnt:
-            mprob = probs.max(dim=-1)[0].view(*(tk_value.shape[:2]))
-            # import pdb;pdb.set_trace()
-            dprob = mprob[:, 1:] < mprob[:, args.layout[1]:].topk(300, dim=-1, largest=False)[0][:,-1].unsqueeze(-1).expand_as(mprob[:, 1:])
+            # mprob = probs.max(dim=-1)[0].view(*(tk_value.shape[:2]))
+            # # import pdb;pdb.set_trace()
+            # dprob = mprob[:, 1:] < mprob[:, args.layout[1]:].topk(300, dim=-1, largest=False)[0][:,-1].unsqueeze(-1).expand_as(mprob[:, 1:])

-            new_fixed = unfixed.clone()
-            moved_new_fixed = new_fixed[:, 2:]
-            moved_new_fixed &= dprob
-            moved_new_fixed[:, 1:] &= dprob[:, :-1].logical_not() | unfixed[:, 2:-1].logical_not()
-            moved_new_fixed[:, 2:] &= dprob[:, :-2].logical_not() | unfixed[:, 2:-2].logical_not()
-            # moved_new_fixed[:, 3:] &= dprob[:, :-3].logical_not() | unfixed[:, 2:-3].logical_not()
-            moved_new_fixed[:, 64:] &= dprob[:, :-64].logical_not() | unfixed[:, 2:-64].logical_not()
-            moved_new_fixed[:, 65:] &= dprob[:, :-65].logical_not() | unfixed[:, 2:-65].logical_not()
-            # moved_new_fixed[:, 66:] &= dprob[:, :-66].logical_not() | unfixed[:, 2:-66].logical_not()
+            # new_fixed = unfixed.clone()
+            # moved_new_fixed = new_fixed[:, 2:]
+            # moved_new_fixed &= dprob
+            # moved_new_fixed[:, 1:] &= dprob[:, :-1].logical_not() | unfixed[:, 2:-1].logical_not()
+            # moved_new_fixed[:, 2:] &= dprob[:, :-2].logical_not() | unfixed[:, 2:-2].logical_not()
+            # # moved_new_fixed[:, 3:] &= dprob[:, :-3].logical_not() | unfixed[:, 2:-3].logical_not()
+            # moved_new_fixed[:, 64:] &= dprob[:, :-64].logical_not() | unfixed[:, 2:-64].logical_not()
+            # moved_new_fixed[:, 65:] &= dprob[:, :-65].logical_not() | unfixed[:, 2:-65].logical_not()
+            # # moved_new_fixed[:, 66:] &= dprob[:, :-66].logical_not() | unfixed[:, 2:-66].logical_not()
+            pass
         elif choice == 1 and warmup_steps < step_cnt:
             new_fixed = unfixed & False
             ll, rr = 4, 4
diff --git a/generation/sampling_strategies/__init__.py b/generation/sampling_strategies/__init__.py
index f1096032cb3c88f2d3ba041d1974fe28a88bf43a..2f71e09703c38106088167808d3be758ae8c9b24 100644
--- a/generation/sampling_strategies/__init__.py
+++ b/generation/sampling_strategies/__init__.py
@@ -1 +1,2 @@
-from .base_strategy import BaseStrategy
\ No newline at end of file
+from .base_strategy import BaseStrategy
+from .iterative_entfilter_strategy import IterativeEntfilterStrategy
\ No newline at end of file
diff --git a/generation/sampling_strategies/base_strategy.py b/generation/sampling_strategies/base_strategy.py
index 99a0a74e6e54b6972866d0fb664c079cac75a170..e46a8ca4e505c2891bae2599ebe10746cc54d876 100644
--- a/generation/sampling_strategies/base_strategy.py
+++ b/generation/sampling_strategies/base_strategy.py
@@ -20,27 +20,20 @@ def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')):
     return logits

 class BaseStrategy:
-    def __init__(self, invalid_slices=[], temperature=1., topk=200, debias=False):
+    def __init__(self, invalid_slices=[], temperature=1., topk=200, eps=1e-4):
         self.invalid_slices = invalid_slices
         self.temperature = temperature
         self.topk = topk
-        self.debias = debias
+        self.eps = eps

     def forward(self, logits, tokens, mems, temperature=None):
         if temperature is None:
             temperature = self.temperature
         logits = logits / temperature
         for invalid_slice in self.invalid_slices:
-            logits[..., invalid_slice] = -float('Inf')
-        if self.debias:
-            probs = F.softmax(logits, dim=-1)
-            tk_value, tk_idx = torch.topk(probs, self.topk, dim=-1)
-            pred = torch.multinomial(probs, num_samples=1)
-            for j in range(0, pred.shape[0]):
-                if probs[j, pred[j,-1]] < tk_value[j, -1]:
-                    pred[j, -1] = tk_idx[j, torch.randint(tk_idx.shape[-1]-100, tk_idx.shape[-1], (1,))] # 100 is the last N as outlier, which is chosen casually
-        else:
-            logits = top_k_logits_(logits)
-            probs = F.softmax(logits, dim=-1)
-            pred = torch.multinomial(probs, num_samples=1)
+            logits[..., invalid_slice] = -65504
+
+        logits = top_k_logits_(logits, self.topk)
+        probs = F.softmax(logits.float(), dim=-1) # float is essential, due to a bug in PyTorch
+        pred = torch.multinomial(probs, num_samples=1)
         tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
         return tokens, mems
diff --git a/generation/sampling_strategies/iterative_entfilter_strategy.py b/generation/sampling_strategies/iterative_entfilter_strategy.py
new file mode 100644
index 0000000000000000000000000000000000000000..84196d0c8fae8db48ed515ff265dea9e22136409
--- /dev/null
+++ b/generation/sampling_strategies/iterative_entfilter_strategy.py
@@ -0,0 +1,55 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   iterative_entfilter_strategy.py
+@Time    :   2021/10/09 14:32:29
+@Author  :   Ming Ding
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+
+import torch
+import torch.nn.functional as F
+
+def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')):
+    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+    logits[indices_to_remove] = filter_value
+    return logits
+
+class IterativeEntfilterStrategy:
+    def __init__(self, invalid_slices=[], temperature=1., topk=10):
+        self.invalid_slices = invalid_slices
+        self.temperature = temperature
+        self.topk = topk
+
+    def forward(self, logits, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
+        # In the iterative strategy, logits are of shape [batch_size, seq_length, vocab_size]
+        if temperature is None:
+            temperature = self.temperature
+        # check entropy filter
+        if entfilter is not None:
+            assert temperature2 is not None
+            topraw = (torch.topk(logits, filter_topk, dim=-1)[0]).softmax(dim=-1)
+            ent = -(topraw * topraw.log()).sum(dim=-1) # [batch_size, seq_length]
+            temperature = torch.tensor([[[temperature - temperature2]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > entfilter).unsqueeze(-1) + temperature2
+        logits = logits / temperature
+        for invalid_slice in self.invalid_slices:
+            logits[..., invalid_slice] = -float('Inf')
+
+        # debiased topk
+        probs = F.softmax(logits, dim=-1)
+        tk_value, tk_idx = torch.topk(probs, self.topk, dim=-1)
+        pred = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
+        edge_idx = tk_idx[:, :, -1:]
+        edge_value = tk_value[:, :, -1:]
+        edge_mask = probs.gather(dim=-1, index=pred) < edge_value
+        pred[edge_mask] = edge_idx[edge_mask] # replace outliers with the "filter_topk"-th token
+        pred.squeeze_(-1) # [batch_size, seq_length]
+
+        assert tokens.shape[1] == pred.shape[1] + 1
+        tokens = torch.cat((tokens[:, :1], pred), dim=1)
+        return tokens
\ No newline at end of file
diff --git a/generation/utils.py b/generation/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bad073413b6c3efa6f71b3e46eeacb883bd504b0
--- /dev/null
+++ b/generation/utils.py
@@ -0,0 +1,78 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   utils.py
+@Time    :   2021/10/09 17:18:26
+@Author  :   Ming Ding
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+import time
+import stat
+from datetime import datetime
+from torchvision.utils import save_image
+import torch.distributed as dist
+
+
+def timed_name(prefix, suffix=None, path=None):
+    return os.path.join(
+        path,
+        f"{prefix}-{datetime.now().strftime('%m-%d-%H-%M-%S')}{suffix}"
+    )
+
+def save_multiple_images(imgs, path, debug=True):
+    # imgs: list of tensor images
+    if debug:
+        imgs = torch.cat(imgs, dim=0)
+        print("\nSave to: ", path, flush=True)
+        save_image(imgs, path, normalize=True)
+    else:
+        print("\nSave to: ", path, flush=True)
+        for i in range(len(imgs)):
+            save_image(imgs[i], os.path.join(path, f'{i}.jpg'), normalize=True)
+            os.chmod(os.path.join(path,f'{i}.jpg'), stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU)
+        save_image(torch.cat(imgs, dim=0), os.path.join(path,f'concat.jpg'), normalize=True)
+        os.chmod(os.path.join(path,f'concat.jpg'), stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU)
+
+def generate_continually(func, input_source='interactive'):
+    if input_source == 'interactive':
+        while True:
+            raw_text = input("\nPlease Input Query (stop to exit) >>> ")
+            raw_text = raw_text.strip()
+            if not raw_text:
+                print('Query should not be empty!')
+                continue
+            if raw_text == "stop":
+                return
+            try:
+                start_time = time.time()
+                func(raw_text)
+                print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
+            except (ValueError, FileNotFoundError) as e:
+                print(e)
+                continue
+    else:
+        with open(input_source, 'r') as fin:
+            inputs = fin.readlines()
+        err_linenos = []
+        for line_no, raw_text in enumerate(inputs):
+            if line_no % dist.get_world_size() != dist.get_rank():
+                continue
+            rk = dist.get_rank()
+            print(f'Working on No. {line_no} on {rk}... ')
+            raw_text = raw_text.strip()
+            if len(raw_text) == 0:
+                continue
+            try:
+                start_time = time.time()
+                func(raw_text)
+                print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
+            except (ValueError, FileNotFoundError) as e:
+                err_linenos.append(line_no)
+                continue
+        print(err_linenos)
diff --git a/inference_cogview.py b/inference_cogview.py
new file mode 100644
index 0000000000000000000000000000000000000000..56bdd96ead78f155fce08552e6c582facbeffbdc
--- /dev/null
+++ b/inference_cogview.py
@@ -0,0 +1,97 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   inference_cogview.py
+@Time    :   2021/10/09 19:41:58
+@Author  :   Ming Ding
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+import argparse
+
+from arguments import get_args
+from model.cached_autoregressive_model import CachedAutoregressiveModel
+from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
+from tokenization import get_tokenizer
+from generation.sampling_strategies import BaseStrategy
+from generation.autoregressive_sampling import filling_sequence
+from generation.utils import timed_name, save_multiple_images, generate_continually
+
+def main(args):
+    initialize_distributed(args)
+    tokenizer = prepare_tokenizer(args)
+    # build model
+    model = CachedAutoregressiveModel(args)
+    if args.fp16:
+        model = model.half()
+    model = model.to(args.device)
+    load_checkpoint(model, args)
+    set_random_seed(args.seed)
+
+    # define function for each query
+    query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024' if not args.full_query else '{}'
+    invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
+    strategy = BaseStrategy(invalid_slices,
+                            temperature=args.temperature, topk=args.top_k)
+
+    def process(raw_text):
+        if args.with_id:
+            query_id, raw_text = raw_text.split()
+        print('raw text: ', raw_text)
+        text = query_template.format(raw_text)
+        seq = tokenizer.parse_query(text, img_size=args.img_size)
+        if len(seq) > 1088:
+            raise ValueError('text too long.')
+        # calibrate text length
+        txt_len = seq.index(tokenizer['[BASE]'])
+        log_attention_weights = torch.zeros(len(seq), len(seq),
+            device=args.device, dtype=torch.half if args.fp16 else torch.float32)
+        log_attention_weights[txt_len+2:, 1:txt_len] = 1.8 if txt_len <= 10 else 1.4 # TODO args
+        # generation
+        seq = torch.cuda.LongTensor(seq, device=args.device)
+        mbz = args.max_inference_batch_size
+        assert args.batch_size < mbz or args.batch_size % mbz == 0
+        output_list = []
+        for tim in range(max(args.batch_size // mbz, 1)):
+            output_list.append(
+                filling_sequence(model, seq.clone(),
+                    batch_size=min(args.batch_size, mbz),
+                    strategy=strategy,
+                    log_attention_weights=log_attention_weights
+                )
+            )
+        output_tokens = torch.cat(output_list, dim=0)
+        # decoding
+        imgs, txts = [], []
+        for seq in output_tokens:
+            decoded_txts, decoded_imgs = tokenizer.DecodeIds(seq.tolist())
+            imgs.append(decoded_imgs[-1]) # only the last image (target)
+        # save
+        if args.with_id:
+            full_path = os.path.join(args.output_path, query_id)
+            os.makedirs(full_path, exist_ok=True)
+            save_multiple_images(imgs, full_path, False)
+        else:
+            prefix = raw_text.replace('/', '')[:20]
+            full_path = timed_name(prefix, '.jpg', args.output_path)
+            save_multiple_images(imgs, full_path, True)
+
+    os.makedirs(args.output_path, exist_ok=True)
+    generate_continually(process, args.input_source)
+
+if __name__ == "__main__":
+    py_parser = argparse.ArgumentParser(add_help=False)
+    py_parser.add_argument('--full-query', action='store_true')
+    py_parser.add_argument('--img-size', type=int, default=256)
+
+    known, args_list = py_parser.parse_known_args()
+    args = get_args(args_list)
+    args = argparse.Namespace(**vars(args), **vars(known))
+
+    with torch.no_grad():
+        main(args)
\ No newline at end of file
diff --git a/model/cached_autoregressive_model.py b/model/cached_autoregressive_model.py
index 225ff3e6f9ed6d83a34676a3918a30e4a16ba70f..1d19fb82569103aed762d25463d50b014a760c2b 100755
--- a/model/cached_autoregressive_model.py
+++ b/model/cached_autoregressive_model.py
@@ -37,14 +37,15 @@ class CachedAutoregressiveModel(BaseModel):
             mixed_value_layer = torch.cat((memv, mixed_value_layer), dim=1)

         # same as training
-        query_layer = self._transpose_for_scores(mixed_query_layer)
-        key_layer = self._transpose_for_scores(mixed_key_layer)
-        value_layer = self._transpose_for_scores(mixed_value_layer)
-        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn=None, log_attention_weights=self.log_attention_weights)
+        query_layer = attn_module._transpose_for_scores(mixed_query_layer)
+        key_layer = attn_module._transpose_for_scores(mixed_key_layer)
+        value_layer = attn_module._transpose_for_scores(mixed_value_layer)
+        context_layer = standard_attention(query_layer, key_layer, value_layer, mask, None, log_attention_weights=self.log_attention_weights)
+
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+        new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
         context_layer = context_layer.view(*new_context_layer_shape)
-        output = self.dense(context_layer)
+        output = attn_module.dense(context_layer)

         # new mem this layer
         new_mem = mixed_raw_layer.detach()[..., -(mixed_raw_layer.shape[-1] // 3 * 2):].contiguous()
diff --git a/scripts/text2image_cogview.sh b/scripts/text2image_cogview.sh
new file mode 100755
index 0000000000000000000000000000000000000000..4fc46365914d359466343eb0185c1c147d4c7155
--- /dev/null
+++ b/scripts/text2image_cogview.sh
@@ -0,0 +1,41 @@
+#!/bin/bash
+
+CHECKPOINT_PATH=pretrained/cogview/cogview-base
+NLAYERS=48
+NHIDDEN=2560
+NATT=40
+MAXSEQLEN=1089
+MASTER_PORT=$(shuf -n 1 -i 10000-65535)
+MPSIZE=1
+
+#SAMPLING ARGS
+TEMP=1.03
+TOPK=200
+
+script_path=$(realpath $0)
+script_dir=$(dirname $script_path)
+
+MASTER_PORT=${MASTER_PORT} python inference_cogview.py \
+       --tokenizer-type cogview \
+       --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \
+       --mode inference \
+       --distributed-backend nccl \
+       --max-sequence-length 1089 \
+       --sandwich-ln \
+       --fp16 \
+       --model-parallel-size $MPSIZE \
+       --num-layers $NLAYERS \
+       --hidden-size $NHIDDEN \
+       --load $CHECKPOINT_PATH \
+       --num-attention-heads $NATT \
+       --temperature $TEMP \
+       --top_k $TOPK \
+       --sandwich-ln \
+       --input-source ./input.txt \
+       --output-path samples_text2image \
+       --batch-size 8 \
+       --max-inference-batch-size 8 \
+       --device 0 \
+       $@
+
+
diff --git a/training/__init__.py b/training/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc0337ffc828400323c95991e8177fdcf2ff7201
--- /dev/null
+++ b/training/__init__.py
@@ -0,0 +1,2 @@
+from .deepspeed_training import initialize_distributed, set_random_seed, prepare_tokenizer
+from .model_io import load_checkpoint
\ No newline at end of file
diff --git a/training/deepspeed_training.py b/training/deepspeed_training.py
index cff164c1c2c5512f4b7f9361022f104a7ea9c747..80d0e422fac82ddb72d3a878bc8f22caccf7163d 100644
--- a/training/deepspeed_training.py
+++ b/training/deepspeed_training.py
@@ -59,14 +59,13 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio
     else:
         args.experiment_name = args.experiment_name + datetime.now().strftime("%m-%d-%H-%M")

-    # Pytorch distributed.
+    # PyTorch distributed. Must come before setting the seed.
     initialize_distributed(args)
     set_random_seed(args.seed)  # Random seeds for reproducability.
-    # init tokenizer
-    tokenizer = get_tokenizer(args)
+    prepare_tokenizer(args) # args.vocab_size is set.
     # Data stuff.
-    train_data, val_data, test_data, args.vocab_size = get_train_val_test_data(args, hooks['create_dataset_function'])
+    train_data, val_data, test_data = make_loaders(args, hooks['create_dataset_function'])

     # Model, optimizer, and learning rate.
     model, optimizer = setup_model_and_optimizer(args, model_cls)
@@ -514,72 +513,39 @@ def initialize_distributed(args):

     # Optional DeepSpeed Activation Checkpointing Features
     if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_activation_checkpointing:
-        set_deepspeed_activation_checkpointing(args)
+        set_deepspeed_activation_checkpointing(args) # TODO manual model-parallel seed


 def set_random_seed(seed):
     """Set random seed for reproducability."""
-    if seed is not None and seed > 0:
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
-    mpu.model_parallel_cuda_manual_seed(seed)
-
-
-def get_train_val_test_data(args, create_dataset_function):
-    """Load the data on rank zero and boradcast number of tokens to all GPUS."""
-
-    (train_data, val_data, test_data) = (None, None, None)
-
-    # Data loader only on rank 0 of each model parallel group.
-    if mpu.get_model_parallel_rank() == 0:
-        train_data, val_data, test_data = make_loaders(args, create_dataset_function)
-        num_tokens = get_tokenizer().num_tokens
-
-        before = num_tokens
-        after = before
-        multiple = args.make_vocab_size_divisible_by * \
-            mpu.get_model_parallel_world_size()
-        while (after % multiple) != 0:
-            after += 1
-        print_rank_0('> padded vocab (size: {}) with {} dummy '
-                     'tokens (new size: {})'.format(
-                         before, after - before, after))
-        token_counts = torch.cuda.LongTensor(
-            [after, int(args.do_train), int(args.do_valid), int(args.do_test)])
-    else:
-        token_counts = torch.cuda.LongTensor([0, 0, 0, 0])
-    # Broadcast num tokens.
-    torch.distributed.broadcast(token_counts,
-                                mpu.get_model_parallel_src_rank(),
-                                group=mpu.get_model_parallel_group())
-    num_tokens = token_counts[0].item()
-    args.do_train = token_counts[1].item()
-    args.do_valid = token_counts[2].item()
-    args.do_test = token_counts[3].item()
-
-    return train_data, val_data, test_data, num_tokens
-
-def see_memory_usage(message, force=False):
-    if not force:
-        return
-    dist.barrier()
-    if dist.get_rank() == 0:
-        print(message)
-        print("Memory Allocated ", torch.cuda.memory_allocated()/(1024*1024*1024), "GigaBytes")
-        print("Max Memory Allocated ", torch.cuda.max_memory_allocated()/(1024*1024*1024), "GigaBytes")
-        print("Cache Allocated ", torch.cuda.memory_cached()/(1024*1024*1024), "GigaBytes")
-        print("Max cache Allocated ", torch.cuda.max_memory_cached()/(1024*1024*1024), "GigaBytes")
-        print(" ")
-
-def seed_torch(seed=1029):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
-    torch.backends.cudnn.benchmark = False
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.enabled = False
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.enabled = False
+    torch.backends.cuda.matmul.allow_tf32 = False
+    if hasattr(mpu, 'model_parallel_cuda_manual_seed'):
+        mpu.model_parallel_cuda_manual_seed(seed)
+
+
+def prepare_tokenizer(args):
+    tokenizer = get_tokenizer(args)
+    num_tokens = tokenizer.num_tokens
+    before = num_tokens
+    after = before
+    multiple = args.make_vocab_size_divisible_by * \
+        mpu.get_model_parallel_world_size()
+    while (after % multiple) != 0:
+        after += 1
+    print_rank_0('> padded vocab (size: {}) with {} dummy '
+                 'tokens (new size: {})'.format(
+                     before, after - before, after))
+    args.vocab_size = after
+    print("prepare tokenizer done", flush=True)
+    return tokenizer
+
diff --git a/training/model_io.py b/training/model_io.py
index ee2e8f1be426a0dc4a215d161e583246c387c583..ce434fdfb63396aab4fdcfa3f2b761625e7011f6 100644
--- a/training/model_io.py
+++ b/training/model_io.py
@@ -121,7 +121,7 @@ def load_checkpoint(model, args):
                 torch.distributed.get_rank(), checkpoint_name))
     sd = torch.load(checkpoint_name, map_location='cpu')

-    assert not args.do_train or args.deepspeed
+    assert not hasattr(args, 'do_train') or not args.do_train or args.deepspeed
     if args.deepspeed:
         module = model.module
     else: # inference without deepspeed
@@ -136,8 +136,10 @@ def load_checkpoint(model, args):
             raise ValueError(f'Missing keys for inference: {missing_keys}.')
         else: # new params
             assert all(name.find('mixins')>=0 for name in missing_keys)
+            assert args.mode == 'finetune'
             module.reinit() # initialize mixins
-            model.optimizer.refresh_fp32_params() # restore fp32 weights from module
+            if args.mode != 'inference':
+                model.optimizer.refresh_fp32_params() # restore fp32 weights from module

     # Iterations.
     if args.mode == 'finetune':