Commit 154fc361 authored by Ming Ding

test cogview1 generation, pass

parent 7c5a12da
Showing with 376 additions and 218 deletions
......@@ -157,27 +157,15 @@ def add_text_generate_args(parser):
group.add_argument("--temperature", type=float, default=1.0)
group.add_argument("--top_p", type=float, default=0.0)
group.add_argument("--top_k", type=int, default=0)
# group.add_argument("--out-seq-length", type=int, default=256)
group.add_argument("--generation-task", type=str,
default='text2image',
choices=['text2image',
'image2text',
'super-resolution',
'low-level super-resolution',
'post-selection',
'raw',
'cuda-2d generation'
],
help='what type of inference task to use')
group.add_argument("--out-seq-length", type=int, default=256)
group.add_argument('--input-source', type=str, default='interactive',
help='what input mode to use, interactive or path')
group.add_argument('--output-path', type=str, default='./samples',
help='path to place the generated samples')
group.add_argument('--debug', action='store_true',
help='Debug will merge all outputs.')
group.add_argument('--with-id', action='store_true',
help='If each line is prepended with an id.')
group.add_argument('--max-inference-batch-size', type=int, default=12)
group.add_argument('--device', type=int, default=0)
return parser
......@@ -218,7 +206,6 @@ def add_generation_api_args(parser):
group.add_argument('--input_rec_path', default='input/')
group.add_argument('--check_mode', default='code')
group.add_argument('--time_interval', default=10)
group.add_argument('--device', default=None)
return parser
......
from .sampling import get_batch, filling_sequence, add_interlacing_beam_marks, inverse_prompt_score
from .magnify import magnify
from .cuda_2d_sampling import filling_sequence_cuda_2d
\ No newline at end of file
......@@ -37,6 +37,8 @@ def update_mems(hiddens, mems, max_memory_length):
if new_memory_length <= query_length:
new_mems.append(hiddens[i][:, -new_memory_length:])
else:
if mems[i].shape[0] < hiddens[i].shape[0]:
mems[i] = mems[i].expand(hiddens[i].shape[0], *mems[i].shape[1:])
new_mems.append(torch.cat((mems[i][:, -new_memory_length+query_length:], hiddens[i]), dim=1))
return new_mems
......@@ -45,8 +47,9 @@ def filling_sequence(
model,
seq,
batch_size,
strategy=BaseStrategy(),
max_memory_length=100000,
strategy=BaseStrategy()
log_attention_weights=None
):
'''
seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
......@@ -60,12 +63,12 @@ def filling_sequence(
assert context_length > 0
tokens, attention_mask, position_ids = get_masks_and_position_ids(seq)
tokens = tokens[..., :context_length]
attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
# initialize generation
counter = context_length - 1 # Last fixed index is ``counter''
index = 0 # Next forward starting index, also the length of cache.
mems = [] # mems are the first-level citizens here, but we don't assume what is memorized.
# step-by-step generation
while counter < len(seq) - 1:
# Now, we want to generate seq[counter + 1],
......@@ -82,18 +85,20 @@ def filling_sequence(
continue
# forward
model.log_attention_weights = log_attention_weights[..., index: counter+1, :counter+1] # TODO memlen
logits, *mem_kv = model(
tokens[:, index:],
position_ids[..., index: counter+1],
attention_mask[..., index: counter+1, :counter+1], # TODO mem
attention_mask[..., index: counter+1, :counter+1], # TODO memlen
*mems
)
mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
counter += 1
index = counter
# sampling
logits = logits[:, -1].expand(batch_size, -1) # [batch size, vocab size]
tokens = tokens.expand(batch_size, -1)
tokens, mems = strategy.forward(logits, tokens, mems)
model.log_attention_weights = None
return tokens
\ No newline at end of file
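For reference, a minimal usage sketch of the updated filling_sequence signature (illustrative only, not part of the commit; it assumes a CachedAutoregressiveModel has already been built and loaded, as done in inference_cogview.py below):

# Hedged sketch: `model` is assumed to be an already-loaded CachedAutoregressiveModel.
import torch
from generation.sampling_strategies import BaseStrategy
from generation.autoregressive_sampling import filling_sequence

strategy = BaseStrategy(temperature=1.0, topk=200)
seq = torch.cuda.LongTensor([2, 3, 5] + [-1] * 16)   # -1 marks positions to be generated
output_tokens = filling_sequence(
    model, seq, batch_size=4,
    strategy=strategy,
    log_attention_weights=None
)   # [batch_size, len(seq)] token ids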
......@@ -12,7 +12,7 @@ import sys
import math
import random
import torch
from .sampling_strategies import BaseStrategy
from .sampling_strategies import IterativeEntfilterStrategy
def filling_sequence(
model,
......@@ -20,11 +20,15 @@ def filling_sequence(
seq1,
warmup_steps=3,
block_hw=(4, 4),
strategy=BaseStrategy(topk=10)
strategy=IterativeEntfilterStrategy(topk=10)
):
'''
seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1]
4095 {layout[2]} final_token
4095 {layout[2]} final_token.
Attention:
The sampling temperature changes across refinement steps; for now we hard-code the schedule here.
The temperature stored in the strategy is not used.
'''
assert hasattr(model, 'layout')
layout = model.layout
......@@ -46,7 +50,7 @@ def filling_sequence(
assert seq.shape[1] == layout[-1] + 1
# build initial tokens, attention_mask, and position_ids
tokens = seq[:, :-1].clone()
tokens = seq.clone()
attention_mask = torch.ones(layout[1], layout[1]).tril().to(device)
attention_mask[n_pad:, :n_pad] = 0
position_ids = torch.cat((
......@@ -54,104 +58,32 @@ def filling_sequence(
torch.arange(0, layout[1] - n_pad),
torch.arange(0, layout[2]-layout[1]))).to(device)
# iterative refining
# prepare for iteration
unfixed = (tokens < 0)
unfixed[:, -4096] = True
ll, rr = block_hw
edge_len = int(math.sqrt(layout[-1] - layout[-2]) + 1e-4)
num_steps = warmup_steps + ll + rr - 2
for step_cnt in range(num_steps):
logits, *_dump = model(tokens, position_ids, attention_mask)
# warmup
real_topk = 10
warmup_steps = 3
iterative_step= warmup_steps + 6
# iterative refining
for step_cnt in range(1, num_steps+1):
logits, *_dump = model(tokens[:,:-1], position_ids, attention_mask)
if step_cnt <= warmup_steps:
real_temp = 0.1
elif step_cnt == warmup_steps + 1:
real_temp = 0.55
elif step_cnt > warmup_steps + 1:
real_temp = 0.45
# if 5 < step_cnt:
# real_topk = 200
# sampling
for invalid_slice in invalid_slices: # forbid generating tokens outside the valid range
logits[..., invalid_slice] = -float('Inf')
assert args.top_k > 0
# probs0 = F.softmax(logits/real_temp, dim=-1)
topraw = (torch.topk(logits, 5, dim=-1)[0]).softmax(dim=-1)
ent = -(topraw * topraw.log()).sum(dim=-1)
# topsum = topraw.sum(dim=-1)
if step_cnt > warmup_steps:
# import pdb;pdb.set_trace()
real_temp2 = torch.tensor([[[real_temp]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > 1.3).unsqueeze(-1) + 0.6
# import pdb;pdb.set_trace()
tokens = strategy.forward(logits, tokens, real_temp)
else:
real_temp2 = real_temp
# import pdb;pdb.set_trace()
probs = F.softmax(logits/real_temp2, dim=-1)
tk_value, tk_idx = torch.topk(probs, real_topk, dim=-1)
prev = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
edge_idx = tk_idx[:, :, -1:]
edge_value = tk_value[:, :, -1:]
edge_mask = probs.gather(dim=-1, index=prev) < edge_value
prev[edge_mask] = edge_idx[edge_mask]
prev.squeeze_(-1)
# tk_probs = (tk_value / real_temp).softmax(dim=-1).view(-1, tk_value.shape[-1])
# prev = torch.multinomial(tk_probs, num_samples=1).view(*(tk_value.shape[:2]),1)
# prev = torch.gather(tk_idx, dim=-1, index=prev).squeeze(-1)
# update unfixed
choice = 1
if choice == 0 and warmup_steps < step_cnt:
mprob = probs.max(dim=-1)[0].view(*(tk_value.shape[:2]))
# import pdb;pdb.set_trace()
dprob = mprob[:, 1:] < mprob[:, args.layout[1]:].topk(300, dim=-1, largest=False)[0][:,-1].unsqueeze(-1).expand_as(mprob[:, 1:])
new_fixed = unfixed.clone()
moved_new_fixed = new_fixed[:, 2:]
moved_new_fixed &= dprob
moved_new_fixed[:, 1:] &= dprob[:, :-1].logical_not() | unfixed[:, 2:-1].logical_not()
moved_new_fixed[:, 2:] &= dprob[:, :-2].logical_not() | unfixed[:, 2:-2].logical_not()
# moved_new_fixed[:, 3:] &= dprob[:, :-3].logical_not() | unfixed[:, 2:-3].logical_not()
moved_new_fixed[:, 64:] &= dprob[:, :-64].logical_not() | unfixed[:, 2:-64].logical_not()
moved_new_fixed[:, 65:] &= dprob[:, :-65].logical_not() | unfixed[:, 2:-65].logical_not()
# moved_new_fixed[:, 66:] &= dprob[:, :-66].logical_not() | unfixed[:, 2:-66].logical_not()
elif choice == 1 and warmup_steps < step_cnt:
new_fixed = unfixed & False
ll, rr = 4, 4
real_temp = 1.05
new_tokens = strategy.forward(
logits, tokens, real_temp,
entfilter=1.3,
filter_topk=5,
temperature2=0.6
)
tokens[unfixed] = new_tokens[unfixed]
# fixed tokens (update unfixed)
for x in range(min(ll, step_cnt - warmup_steps)):
y = step_cnt - warmup_steps - x - 1
if y < rr:
print(x,y)
new_fixed[..., -4096:].view(batch_size, 64//ll, ll, 64//rr, rr)[:, :, x, :, y] = True
new_fixed &= unfixed
else:
new_fixed = unfixed & False # TODO
new_fixed[:, -1] = True
# with open(f'bed{step_cnt}.txt', 'w') as fout:
# for i, prob in enumerate(topraw[0, -4096:]):
# s = ' '.join([str(x) for x in prob.tolist()])
# fout.write(f'{i} {s}\n')
unfixed &= new_fixed.logical_not()
# update seq and tokens
seq[new_fixed] = prev[new_fixed[:, 1:]]
tokens = seq[:, :-1].clone()
tokens[:,1:][unfixed[:, 1:-1]] = prev[:, :-1][unfixed[:, 1:-1]]
if step_cnt == iterative_step:
seq[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step
n_unfixed = unfixed.sum(dim=-1).tolist()
print(f'Exit with {n_unfixed} unfixed tokens.')
break
if args.debug:
from torchvision.utils import save_image
seqt = seq.clone()
seqt[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step
imgs.extend([tokenizer.img_tokenizer.DecodeIds(s[-4096:]) for s in seqt])
if args.debug:
imgs = torch.cat(imgs, dim=0)
save_image(imgs, f'steps{device}.jpg', normalize=True)
model.module.transformer.max_memory_length = args.max_memory_length
return seq
\ No newline at end of file
unfixed[..., -(layout[-1] - layout[-2]):].view(
batch_size, edge_len//ll, ll, edge_len//rr, rr)[:, :, x, :, y] = False
return tokens
\ No newline at end of file
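As an aside, the anti-diagonal block schedule hard-coded above can be visualized with a small stand-alone sketch (assuming the same 4x4 blocks and 3 warmup steps used in this diff):

# Illustrative sketch only: prints which intra-block offsets (x, y) become fixed
# at each refinement step, mirroring the loop in the choice == 1 branch above.
warmup_steps, ll, rr = 3, 4, 4
num_steps = warmup_steps + ll + rr - 2
for step_cnt in range(1, num_steps + 1):
    fixed_now = []
    for x in range(min(ll, step_cnt - warmup_steps)):
        y = step_cnt - warmup_steps - x - 1
        if y < rr:
            fixed_now.append((x, y))
    print(step_cnt, fixed_now)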
......@@ -116,19 +116,20 @@ def filling_sequence_cuda_2d(
# update unfixed
choice = 1
if choice == 0 and warmup_steps < step_cnt:
mprob = probs.max(dim=-1)[0].view(*(tk_value.shape[:2]))
# import pdb;pdb.set_trace()
dprob = mprob[:, 1:] < mprob[:, args.layout[1]:].topk(300, dim=-1, largest=False)[0][:,-1].unsqueeze(-1).expand_as(mprob[:, 1:])
# mprob = probs.max(dim=-1)[0].view(*(tk_value.shape[:2]))
# # import pdb;pdb.set_trace()
# dprob = mprob[:, 1:] < mprob[:, args.layout[1]:].topk(300, dim=-1, largest=False)[0][:,-1].unsqueeze(-1).expand_as(mprob[:, 1:])
new_fixed = unfixed.clone()
moved_new_fixed = new_fixed[:, 2:]
moved_new_fixed &= dprob
moved_new_fixed[:, 1:] &= dprob[:, :-1].logical_not() | unfixed[:, 2:-1].logical_not()
moved_new_fixed[:, 2:] &= dprob[:, :-2].logical_not() | unfixed[:, 2:-2].logical_not()
# moved_new_fixed[:, 3:] &= dprob[:, :-3].logical_not() | unfixed[:, 2:-3].logical_not()
moved_new_fixed[:, 64:] &= dprob[:, :-64].logical_not() | unfixed[:, 2:-64].logical_not()
moved_new_fixed[:, 65:] &= dprob[:, :-65].logical_not() | unfixed[:, 2:-65].logical_not()
# moved_new_fixed[:, 66:] &= dprob[:, :-66].logical_not() | unfixed[:, 2:-66].logical_not()
# new_fixed = unfixed.clone()
# moved_new_fixed = new_fixed[:, 2:]
# moved_new_fixed &= dprob
# moved_new_fixed[:, 1:] &= dprob[:, :-1].logical_not() | unfixed[:, 2:-1].logical_not()
# moved_new_fixed[:, 2:] &= dprob[:, :-2].logical_not() | unfixed[:, 2:-2].logical_not()
# # moved_new_fixed[:, 3:] &= dprob[:, :-3].logical_not() | unfixed[:, 2:-3].logical_not()
# moved_new_fixed[:, 64:] &= dprob[:, :-64].logical_not() | unfixed[:, 2:-64].logical_not()
# moved_new_fixed[:, 65:] &= dprob[:, :-65].logical_not() | unfixed[:, 2:-65].logical_not()
# # moved_new_fixed[:, 66:] &= dprob[:, :-66].logical_not() | unfixed[:, 2:-66].logical_not()
pass
elif choice == 1 and warmup_steps < step_cnt:
new_fixed = unfixed & False
ll, rr = 4, 4
......
from .base_strategy import BaseStrategy
\ No newline at end of file
from .base_strategy import BaseStrategy
from .iterative_entfilter_strategy import IterativeEntfilterStrategy
\ No newline at end of file
......@@ -20,27 +20,20 @@ def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')):
return logits
class BaseStrategy:
def __init__(self, invalid_slices=[], temperature=1., topk=200, debias=False):
def __init__(self, invalid_slices=[], temperature=1., topk=200, eps=1e-4):
self.invalid_slices = invalid_slices
self.temperature = temperature
self.topk = topk
self.debias = debias
self.eps = eps
def forward(self, logits, tokens, mems, temperature=None):
if temperature is None:
temperature = self.temperature
logits = logits / temperature
for invalid_slice in self.invalid_slices:
logits[..., invalid_slice] = -float('Inf')
if self.debias:
probs = F.softmax(logits, dim=-1)
tk_value, tk_idx = torch.topk(probs, self.topk, dim=-1)
pred = torch.multinomial(probs, num_samples=1)
for j in range(0, pred.shape[0]):
if probs[j, pred[j,-1]] < tk_value[j, -1]:
pred[j, -1] = tk_idx[j, torch.randint(tk_idx.shape[-1]-100, tk_idx.shape[-1], (1,))] # the last 100 entries are treated as outliers; 100 is chosen arbitrarily
else:
logits = top_k_logits_(logits)
probs = F.softmax(logits, dim=-1)
pred = torch.multinomial(probs, num_samples=1)
logits[..., invalid_slice] = -65504
logits = top_k_logits_(logits, self.topk)
probs = F.softmax(logits.float(), dim=-1) # casting to float is essential, due to a bug in PyTorch
pred = torch.multinomial(probs, num_samples=1)
tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
return tokens, mems
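A minimal stand-alone sketch of one BaseStrategy sampling step on toy logits (illustrative; the shapes, vocabulary size, and invalid_slices below are assumptions, not values from the commit):

# Hedged sketch: -65504 in the strategy above is the most negative finite fp16 value,
# used to mask forbidden tokens before top-k sampling.
import torch
from generation.sampling_strategies import BaseStrategy

strategy = BaseStrategy(invalid_slices=[slice(90, None)], temperature=1.0, topk=5)
logits = torch.randn(2, 100)                  # toy [batch, vocab] logits for the last position
tokens = torch.zeros(2, 8, dtype=torch.long)  # toy running sequences
tokens, mems = strategy.forward(logits, tokens, mems=[])
print(tokens.shape)                           # torch.Size([2, 9]): one sampled token appended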
# -*- encoding: utf-8 -*-
'''
@File : iterative_entfilter_strategy.py
@Time : 2021/10/09 14:32:29
@Author : Ming Ding
@Contact : dm18@mail.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import math
import random
import torch
import torch.nn.functional as F
def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')):
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
return logits
class IterativeEntfilterStrategy:
def __init__(self, invalid_slices=[], temperature=1., topk=10):
self.invalid_slices = invalid_slices
self.temperature = temperature
self.topk = topk
def forward(self, logits, tokens, temperature=None, entfilter=None, filter_topk=5, temperature2=None):
# In the iterative strategy, logits are of shape [batch_size, seq_length, vocab_size]
if temperature is None:
temperature = self.temperature
# check entropy filter
if entfilter is not None:
assert temperature2 is not None
topraw = (torch.topk(logits, filter_topk, dim=-1)[0]).softmax(dim=-1)
ent = -(topraw * topraw.log()).sum(dim=-1) # [batch_size, seq_length]
temperature = torch.tensor([[[temperature - temperature2]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > entfilter).unsqueeze(-1) + temperature2
logits = logits / temperature
for invalid_slice in self.invalid_slices:
logits[..., invalid_slice] = -float('Inf')
# debiased topk
probs = F.softmax(logits, dim=-1)
tk_value, tk_idx = torch.topk(probs, self.topk, dim=-1)
pred = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1)
edge_idx = tk_idx[:, :, -1:]
edge_value = tk_value[:, :, -1:]
edge_mask = probs.gather(dim=-1, index=pred) < edge_value
pred[edge_mask] = edge_idx[edge_mask] # replace out-of-top-k samples with the k-th (self.topk) most likely token
pred.squeeze_(-1) # [batch_size, seq_length]
assert tokens.shape[1] == pred.shape[1] + 1
tokens = torch.cat((tokens[:, :1], pred), dim=1)
return tokens
\ No newline at end of file
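The entropy filter added by this strategy can be illustrated in isolation (a sketch with toy values; the thresholds mirror the hard-coded 1.3 / 0.6 used by the cuda-2d sampler above):

# Hedged sketch: positions whose top-5 distribution is flat (high entropy) keep the
# higher temperature; confident positions fall back to temperature2.
import torch

logits = torch.randn(2, 6, 100)                       # toy [batch, seq, vocab] logits
temperature, temperature2, entfilter, filter_topk = 1.05, 0.6, 1.3, 5

topraw = torch.topk(logits, filter_topk, dim=-1)[0].softmax(dim=-1)
ent = -(topraw * topraw.log()).sum(dim=-1)            # per-position entropy, [batch, seq]
temp = torch.tensor([[[temperature - temperature2]]]).expand(*logits.shape[:2], 1) \
       * (ent > entfilter).unsqueeze(-1) + temperature2
print(temp.squeeze(-1))                               # per-position sampling temperature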
# -*- encoding: utf-8 -*-
'''
@File : utils.py
@Time : 2021/10/09 17:18:26
@Author : Ming Ding
@Contact : dm18@mail.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import math
import random
import torch
import time
import stat
from datetime import datetime
from torchvision.utils import save_image
import torch.distributed as dist
def timed_name(prefix, suffix=None, path=None):
return os.path.join(
path,
f"{prefix}-{datetime.now().strftime('%m-%d-%H-%M-%S')}{suffix}"
)
def save_multiple_images(imgs, path, debug=True):
# imgs: list of tensor images
if debug:
imgs = torch.cat(imgs, dim=0)
print("\nSave to: ", path, flush=True)
save_image(imgs, path, normalize=True)
else:
print("\nSave to: ", path, flush=True)
for i in range(len(imgs)):
save_image(imgs[i], os.path.join(path, f'{i}.jpg'), normalize=True)
os.chmod(os.path.join(path,f'{i}.jpg'), stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU)
save_image(torch.cat(imgs, dim=0), os.path.join(path,f'concat.jpg'), normalize=True)
os.chmod(os.path.join(path,f'concat.jpg'), stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU)
def generate_continually(func, input_source='interactive'):
if input_source == 'interactive':
while True:
raw_text = input("\nPlease Input Query (stop to exit) >>> ")
raw_text = raw_text.strip()
if not raw_text:
print('Query should not be empty!')
continue
if raw_text == "stop":
return
try:
start_time = time.time()
func(raw_text)
print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
except (ValueError, FileNotFoundError) as e:
print(e)
continue
else:
with open(input_source, 'r') as fin:
inputs = fin.readlines()
err_linenos = []
for line_no, raw_text in enumerate(inputs):
if line_no % dist.get_world_size() != dist.get_rank():
continue
rk = dist.get_rank()
print(f'Working on No. {line_no} on {rk}... ')
raw_text = raw_text.strip()
if len(raw_text) == 0:
continue
try:
start_time = time.time()
func(raw_text)
print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
except (ValueError, FileNotFoundError) as e:
err_linenos.append(line_no)
continue
print(err_linenos)
# -*- encoding: utf-8 -*-
'''
@File : inference_cogview.py
@Time : 2021/10/09 19:41:58
@Author : Ming Ding
@Contact : dm18@mail.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import math
import random
import torch
import argparse
from arguments import get_args
from model.cached_autoregressive_model import CachedAutoregressiveModel
from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
from tokenization import get_tokenizer
from generation.sampling_strategies import BaseStrategy
from generation.autoregressive_sampling import filling_sequence
from generation.utils import timed_name, save_multiple_images, generate_continually
def main(args):
initialize_distributed(args)
tokenizer = prepare_tokenizer(args)
# build model
model = CachedAutoregressiveModel(args)
if args.fp16:
model = model.half()
model = model.to(args.device)
load_checkpoint(model, args)
set_random_seed(args.seed)
# define function for each query
query_template = '[ROI1] {} [BASE] [BOI1] [MASK]*1024' if not args.full_query else '{}'
invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
strategy = BaseStrategy(invalid_slices,
temperature=args.temperature, topk=args.top_k)
def process(raw_text):
if args.with_id:
query_id, raw_text = raw_text.split()
print('raw text: ', raw_text)
text = query_template.format(raw_text)
seq = tokenizer.parse_query(text, img_size=args.img_size)
if len(seq) > 1088:
raise ValueError('text too long.')
# calibrate text length
txt_len = seq.index(tokenizer['[BASE]'])
log_attention_weights = torch.zeros(len(seq), len(seq),
device=args.device, dtype=torch.half if args.fp16 else torch.float32)
log_attention_weights[txt_len+2:, 1:txt_len] = 1.8 if txt_len <= 10 else 1.4 # TODO args
# generation
seq = torch.cuda.LongTensor(seq, device=args.device)
mbz = args.max_inference_batch_size
assert args.batch_size < mbz or args.batch_size % mbz == 0
output_list = []
for tim in range(max(args.batch_size // mbz, 1)):
output_list.append(
filling_sequence(model, seq.clone(),
batch_size=min(args.batch_size, mbz),
strategy=strategy,
log_attention_weights=log_attention_weights
)
)
output_tokens = torch.cat(output_list, dim=0)
# decoding
imgs, txts = [], []
for seq in output_tokens:
decoded_txts, decoded_imgs = tokenizer.DecodeIds(seq.tolist())
imgs.append(decoded_imgs[-1]) # only the last image (target)
# save
if args.with_id:
full_path = os.path.join(args.output_path, query_id)
os.makedirs(full_path, exist_ok=True)
save_multiple_images(imgs, full_path, False)
else:
prefix = raw_text.replace('/', '')[:20]
full_path = timed_name(prefix, '.jpg', args.output_path)
save_multiple_images(imgs, full_path, True)
os.makedirs(args.output_path, exist_ok=True)
generate_continually(process, args.input_source)
if __name__ == "__main__":
py_parser = argparse.ArgumentParser(add_help=False)
py_parser.add_argument('--full-query', action='store_true')
py_parser.add_argument('--img-size', type=int, default=256)
known, args_list = py_parser.parse_known_args()
args = get_args(args_list)
args = argparse.Namespace(**vars(args), **vars(known))
with torch.no_grad():
main(args)
\ No newline at end of file
......@@ -37,14 +37,15 @@ class CachedAutoregressiveModel(BaseModel):
mixed_value_layer = torch.cat((memv, mixed_value_layer), dim=1)
# same as training
query_layer = self._transpose_for_scores(mixed_query_layer)
key_layer = self._transpose_for_scores(mixed_key_layer)
value_layer = self._transpose_for_scores(mixed_value_layer)
context_layer = standard_attention(query_layer, key_layer, value_layer, mask, dropout_fn=None, log_attention_weights=self.log_attention_weights)
query_layer = attn_module._transpose_for_scores(mixed_query_layer)
key_layer = attn_module._transpose_for_scores(mixed_key_layer)
value_layer = attn_module._transpose_for_scores(mixed_value_layer)
context_layer = standard_attention(query_layer, key_layer, value_layer, mask, None, log_attention_weights=self.log_attention_weights)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.dense(context_layer)
output = attn_module.dense(context_layer)
# new mem this layer
new_mem = mixed_raw_layer.detach()[..., -(mixed_raw_layer.shape[-1] // 3 * 2):].contiguous()
......
#!/bin/bash
CHECKPOINT_PATH=pretrained/cogview/cogview-base
NLAYERS=48
NHIDDEN=2560
NATT=40
MAXSEQLEN=1089
MASTER_PORT=$(shuf -n 1 -i 10000-65535)
MPSIZE=1
#SAMPLING ARGS
TEMP=1.03
TOPK=200
script_path=$(realpath $0)
script_dir=$(dirname $script_path)
MASTER_PORT=${MASTER_PORT} python inference_cogview.py \
--tokenizer-type cogview \
--img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \
--mode inference \
--distributed-backend nccl \
--max-sequence-length 1089 \
--sandwich-ln \
--fp16 \
--model-parallel-size $MPSIZE \
--num-layers $NLAYERS \
--hidden-size $NHIDDEN \
--load $CHECKPOINT_PATH \
--num-attention-heads $NATT \
--temperature $TEMP \
--top_k $TOPK \
--sandwich-ln \
--input-source ./input.txt \
--output-path samples_text2image \
--batch-size 8 \
--max-inference-batch-size 8 \
--device 0 \
$@
from .deepspeed_training import initialize_distributed, set_random_seed, prepare_tokenizer
from .model_io import load_checkpoint
\ No newline at end of file
......@@ -59,14 +59,13 @@ def training_main(args, model_cls, forward_step_function, create_dataset_functio
else:
args.experiment_name = args.experiment_name + datetime.now().strftime("%m-%d-%H-%M")
# Pytorch distributed.
# PyTorch distributed; must be initialized before seeding
initialize_distributed(args)
set_random_seed(args.seed) # Random seeds for reproducibility.
# init tokenizer
tokenizer = get_tokenizer(args)
prepare_tokenizer(args) # args.vocab_size is set.
# Data stuff.
train_data, val_data, test_data, args.vocab_size = get_train_val_test_data(args, hooks['create_dataset_function'])
train_data, val_data, test_data = make_loaders(args, hooks['create_dataset_function'])
# Model, optimizer, and learning rate.
model, optimizer = setup_model_and_optimizer(args, model_cls)
......@@ -514,72 +513,39 @@ def initialize_distributed(args):
# Optional DeepSpeed Activation Checkpointing Features
if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_activation_checkpointing:
set_deepspeed_activation_checkpointing(args)
set_deepspeed_activation_checkpointing(args) # TODO manual model-parallel seed
def set_random_seed(seed):
"""Set random seed for reproducability."""
if seed is not None and seed > 0:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
mpu.model_parallel_cuda_manual_seed(seed)
def get_train_val_test_data(args, create_dataset_function):
"""Load the data on rank zero and boradcast number of tokens to all GPUS."""
(train_data, val_data, test_data) = (None, None, None)
# Data loader only on rank 0 of each model parallel group.
if mpu.get_model_parallel_rank() == 0:
train_data, val_data, test_data = make_loaders(args, create_dataset_function)
num_tokens = get_tokenizer().num_tokens
before = num_tokens
after = before
multiple = args.make_vocab_size_divisible_by * \
mpu.get_model_parallel_world_size()
while (after % multiple) != 0:
after += 1
print_rank_0('> padded vocab (size: {}) with {} dummy '
'tokens (new size: {})'.format(
before, after - before, after))
token_counts = torch.cuda.LongTensor(
[after, int(args.do_train), int(args.do_valid), int(args.do_test)])
else:
token_counts = torch.cuda.LongTensor([0, 0, 0, 0])
# Broadcast num tokens.
torch.distributed.broadcast(token_counts,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
num_tokens = token_counts[0].item()
args.do_train = token_counts[1].item()
args.do_valid = token_counts[2].item()
args.do_test = token_counts[3].item()
return train_data, val_data, test_data, num_tokens
def see_memory_usage(message, force=False):
if not force:
return
dist.barrier()
if dist.get_rank() == 0:
print(message)
print("Memory Allocated ", torch.cuda.memory_allocated()/(1024*1024*1024), "GigaBytes")
print("Max Memory Allocated ", torch.cuda.max_memory_allocated()/(1024*1024*1024), "GigaBytes")
print("Cache Allocated ", torch.cuda.memory_cached()/(1024*1024*1024), "GigaBytes")
print("Max cache Allocated ", torch.cuda.max_memory_cached()/(1024*1024*1024), "GigaBytes")
print(" ")
def seed_torch(seed=1029):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.enabled = False
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.enabled = False
torch.backends.cuda.matmul.allow_tf32 = False
if hasattr(mpu, 'model_parallel_cuda_manual_seed'):
mpu.model_parallel_cuda_manual_seed(seed)
def prepare_tokenizer(args):
tokenizer = get_tokenizer(args)
num_tokens = tokenizer.num_tokens
before = num_tokens
after = before
multiple = args.make_vocab_size_divisible_by * \
mpu.get_model_parallel_world_size()
while (after % multiple) != 0:
after += 1
print_rank_0('> padded vocab (size: {}) with {} dummy '
'tokens (new size: {})'.format(
before, after - before, after))
args.vocab_size = after
print("prepare tokenizer done", flush=True)
return tokenizer
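The vocabulary padding in prepare_tokenizer amounts to rounding up to the next multiple; a worked sketch with hypothetical numbers:

# Hedged sketch: the token count and multiple below are illustrative, not the real
# CogView values; the computation is equivalent to the while-loop above.
import math

before = 58219            # hypothetical tokenizer.num_tokens
multiple = 128 * 1        # make_vocab_size_divisible_by * model_parallel_world_size (assumed)
after = math.ceil(before / multiple) * multiple
print(before, after, after - before)   # 58219 58240 21 dummy tokens added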
......@@ -121,7 +121,7 @@ def load_checkpoint(model, args):
torch.distributed.get_rank(), checkpoint_name))
sd = torch.load(checkpoint_name, map_location='cpu')
assert not args.do_train or args.deepspeed
assert not hasattr(args, 'do_train') or not args.do_train or args.deepspeed
if args.deepspeed:
module = model.module
else: # inference without deepspeed
......@@ -136,8 +136,10 @@ def load_checkpoint(model, args):
raise ValueError(f'Missing keys for inference: {missing_keys}.')
else: # new params
assert all(name.find('mixins')>=0 for name in missing_keys)
assert args.mode == 'finetune'
module.reinit() # initialize mixins
model.optimizer.refresh_fp32_params() # restore fp32 weights from module
if args.mode != 'inference':
model.optimizer.refresh_fp32_params() # restore fp32 weights from module
# Iterations.
if args.mode == 'finetune':
......