diff --git a/generation/object_sampling.py b/generation/object_sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7de80e79c184060420e8dfb001f6fb801611c04
--- /dev/null
+++ b/generation/object_sampling.py
@@ -0,0 +1,123 @@
+# -*- encoding: utf-8 -*-
+'''
+@File    :   object_sampling.py
+@Time    :   2021/10/08 15:43:59
+@Author  :   Ming Ding
+@Contact :   dm18@mail.tsinghua.edu.cn
+'''
+
+# here put the import lib
+import os
+import sys
+import math
+import random
+import torch
+from .sampling_strategies import BaseStrategy
+
+
+def update_mems(hiddens, mems, max_memory_length):
+    '''
+        hiddens: list (num_layers) of [batch, query_length, 2d]
+        mems: None or [num_layers, batch, memory_length, 2d]
+    '''
+    if hiddens is None:
+        return None
+    hiddens = torch.stack(hiddens)
+    memory_length = mems.shape[2] if mems is not None else 0
+    query_length = hiddens.shape[2]
+    new_memory_length = min(max_memory_length, memory_length + query_length)
+    with torch.no_grad():
+        if new_memory_length <= query_length:
+            return hiddens[:, :, -new_memory_length:]
+        else:
+            if mems.shape[1] < hiddens.shape[1]:
+                mems = mems.expand(-1, hiddens.shape[1], -1, -1)
+            return torch.cat(
+                (mems[:, :, -new_memory_length + query_length:], hiddens),
+                dim=2
+            )
+
+
+def filling_sequence_object(
+        model,
+        seq,
+        n_pads,
+        object_pads,
+        batch_size,
+        strategy=BaseStrategy(),
+        max_memory_length=100000,
+        log_attention_weights=None
+        ):
+    '''
+        seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
+    '''
+    assert len(seq.shape) == 1
+
+    # building the initial tokens, attention_mask, and position_ids
+    context_length = 0
+    while seq[context_length] >= 0:
+        context_length += 1  # [0, context_length-1] are given
+    assert context_length > 0
+    tokens = seq.unsqueeze(0)
+
+    # Unlike training, there is no pad in front of the sequence here
+    attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
+    attention_mask.tril_()
+    attention_mask.unsqueeze_(1)
+
+    position_ids = torch.zeros(1, len(seq), dtype=torch.long, device=tokens.device)
+
+    torch.arange(64 - n_pads, out=position_ids[0, n_pads:64],
+                 dtype=torch.long, device=tokens.device)  # text positions
+    torch.arange(180 - object_pads, out=position_ids[0, 64 + object_pads:64 + 180])  # object positions restart from zero
+    # breakpoint()
+    torch.arange(64 - n_pads, 64 - n_pads + len(seq) - (64 + 180),
+                 out=position_ids[0, 64 + 180:],
+                 dtype=torch.long, device=tokens.device)  # image positions continue after the text
+    # breakpoint()
+
+    tokens = tokens[..., :context_length]
+    attention_mask = attention_mask.type_as(next(model.parameters()))  # if fp16
+    # initialize generation
+    counter = context_length - 1  # Last fixed index is ``counter''
+    index = 0  # Next forward starting index, also the length of cache.
+    mems = None  # mems are the first-level citizens here, but we don't assume what is memorized.
+
+    # step-by-step generation
+    while counter < len(seq) - 1:
+        # Now, we want to generate seq[counter + 1],
+        # token[:, index: counter+1] needs forwarding.
+
+        if seq[counter + 1] >= 0:  # provided
+            tokens = torch.cat(
+                (
+                    tokens,
+                    seq[counter + 1: counter + 2].expand(tokens.shape[0], 1)
+                ), dim=1
+            )
+            counter += 1
+            continue
+
+        # forward
+        if log_attention_weights is not None:
+            model.log_attention_weights = log_attention_weights[..., index: counter + 1, :counter + 1]  # TODO memlen
+        kw_tensors = {'mems': mems} if mems is not None else {}
+        logits, *mem_kv = model(
+            tokens[:, index:],
+            position_ids[..., index: counter + 1],
+            attention_mask[..., index: counter + 1, :counter + 1],  # TODO memlen
+            **kw_tensors  # if no mems, cannot pass
+        )
+        mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
+        counter += 1
+        index = counter
+        # sampling
+        logits = logits[:, -1].expand(batch_size, -1)  # [batch size, vocab size]
+        tokens = tokens.expand(batch_size, -1)
+        tokens, mems = strategy.forward(logits, tokens, mems)
+
+    model.log_attention_weights = None
+    return tokens
\ No newline at end of file
diff --git a/inference_object.py b/inference_object.py
index 576635af756dbdf36d4a99ecd6edcda5343fda3a..46c623b90da924e9cb69d66abf0118817a96170e 100644
--- a/inference_object.py
+++ b/inference_object.py
@@ -22,6 +22,7 @@ from training import load_checkpoint, initialize_distributed, set_random_seed, p
 from tokenization import get_tokenizer
 from generation.sampling_strategies import BaseStrategy
 from generation.autoregressive_sampling import filling_sequence
+from generation.object_sampling import filling_sequence_object
 from generation.utils import timed_name, save_multiple_images, generate_continually
 
 
@@ -48,20 +49,28 @@ def main(args):
         print('raw text: ', raw_text)
         raw_text = raw_text.split(' ')
         objects = raw_text[1:]
-        seq = tokenizer.parse_query(f"[ROI1] {raw_text[0]}", img_size=args.img_size)
+        seq1 = tokenizer.parse_query(f"[ROI1] {raw_text[0]}", img_size=args.img_size)
+        seq2 = []
         for i in range(len(objects)//5):
-            seq.append(tokenizer['POS0'])
-            seq.extend(objects[i*5:i*5+4] + args.old_token_num)
-            seq.extend(tokenizer.EncodeAsIds(objects[i*5 + 4]))
-        seq.extend(tokenizer.parse_query('[BASE] [BOI1] [MASK]*1024', img_size=args.img_size))
-
+            seq2.append(tokenizer['[POS0]'])
+            seq2.extend([int(x) + args.old_token_num for x in objects[i*5:i*5+4]])
+            seq2.extend(tokenizer.EncodeAsIds(objects[i*5 + 4]))
+        object_pads = 180 - len(seq2)
+        seq2 = [tokenizer['[PAD]']] * object_pads + seq2
+        seq3 = tokenizer.parse_query('[BASE] [BOI1] [MASK]*1024', img_size=args.img_size)
+        seq = seq1 + seq2 + seq3
+        n_pads = args.new_sequence_length - 1 - len(seq)
+        seq = ([tokenizer['[PAD]']] * n_pads) + seq
+        print("len seq is ", len(seq))
         if len(seq) > 1271:
             raise ValueError('text too long.')
         # calibrate text length
-        txt_len = seq.index(tokenizer['[BASE]'])
+        front_len = seq.index(tokenizer['[BASE]'])
         log_attention_weights = torch.zeros(len(seq), len(seq),
             device=args.device, dtype=torch.half if args.fp16 else torch.float32)
-        log_attention_weights[txt_len+2:, 1:txt_len] = 1.8 if txt_len <= 10 else 1.4 # TODO args
+        # TODO text attention
+        # log_attention_weights[front_len+2:, 1:front_len] = 1.8 if front_len <= 10 else 1.4 # TODO args
+
         # generation
         seq = torch.cuda.LongTensor(seq, device=args.device)
         mbz = args.max_inference_batch_size
@@ -69,18 +78,21 @@ def main(args):
         output_list = []
         for tim in range(max(args.batch_size // mbz, 1)):
             output_list.append(
-                filling_sequence(model, seq.clone(),
-                    batch_size=min(args.batch_size, mbz),
-                    strategy=strategy,
-                    log_attention_weights=log_attention_weights
-                    )
+                filling_sequence_object(model, seq.clone(),
+                    n_pads=n_pads,
+                    object_pads=object_pads,
+                    batch_size=min(args.batch_size, mbz),
+                    strategy=strategy,
+                    log_attention_weights=log_attention_weights
+                    )
                 )
         output_tokens = torch.cat(output_list, dim=0)
         # decoding
         imgs, txts = [], []
         for seq in output_tokens:
-            txt_len = seq.index(tokenizer['[BASE]'])
-            seq = seq[txt_len:]
+            # txt_len = seq.index(tokenizer['[BASE]'])
+            # breakpoint()
+            seq = seq[-1025:]
             _, decoded_imgs = tokenizer.DecodeIds(seq.tolist())
             imgs.append(decoded_imgs[-1])  # only the last image (target)
         # save
@@ -89,7 +101,7 @@ def main(args):
             os.makedirs(full_path, exist_ok=True)
             save_multiple_images(imgs, full_path, False)
         else:
-            prefix = raw_text.replace('/', '')[:20]
+            prefix = raw_text[0].replace('/', '')[:20]
             full_path = timed_name(prefix, '.jpg', args.output_path)
             save_multiple_images(imgs, full_path, True)
 
@@ -105,6 +117,6 @@ if __name__ == "__main__":
     known, args_list = py_parser.parse_known_args()
     args = get_args(args_list)
     args = argparse.Namespace(**vars(args), **vars(known))
-
+    args.layout = [int(x) for x in args.layout.split(',')]
     with torch.no_grad():
         main(args)
\ No newline at end of file
diff --git a/model/ObjectModel.py b/model/ObjectModel.py
index 1000fe1f7594cc44c9cd65e9387e0f03d9b18c37..0adf84a35218da0acca2bc4d5d45ed52b740f773 100644
--- a/model/ObjectModel.py
+++ b/model/ObjectModel.py
@@ -24,7 +24,7 @@ class ObjectModel(BaseModel):
             ))
         self.layout = args.layout
 
-    def position_embedding_forward(self, position_ids, *other_tensors):
+    def position_embedding_forward(self, position_ids, **kw_tensors):
         # breakpoint()
         position_text = position_ids[..., :self.layout[0]]
         position_object = position_ids[..., self.layout[0]:self.layout[1]]
@@ -39,10 +39,10 @@ class ObjectModel(BaseModel):
             )
         return position_embeddings
 
-    def word_embedding_forward(self, input_ids, *other_tensors):
+    def word_embedding_forward(self, input_ids, **kw_tensors):
         return self.mixins[1].word_embeddings(input_ids)
 
-    def final_forward(self, logits, *other_tensors):
+    def final_forward(self, logits, **kw_tensors):
         logits = copy_to_model_parallel_region(logits)
         logits = F.linear(logits, self.mixins[1].word_embeddings.weight)
         return logits
@@ -50,7 +50,7 @@ class ObjectModel(BaseModel):
     @classmethod
     def add_model_specific_args(cls, parser):
         group = parser.add_argument_group('ObjectModel', 'Object model configurations')
-        group.add_argument("--layout", type=str, default='64,246,1270')
+        group.add_argument("--layout", type=str, default='64,246,1270')  # 246 = 64 + 182
         group.add_argument("--old-token-num", type=int, default=58219)
         group.add_argument("--additional-token-num", type=int, default=257)
         group.add_argument("--new-sequence-length", type=int, default=1271)  # 1089 + 180 + 2
diff --git a/model/cached_object_model.py b/model/cached_object_model.py
index 6ef67d76ef63fbc905bc9b7155846066f1b9be99..dc3f797d1dbe65261cfd4bad69e25c62f574753c 100644
--- a/model/cached_object_model.py
+++ b/model/cached_object_model.py
@@ -17,9 +17,9 @@ class CachedObjectModel(ObjectModel):
         super().__init__(args, transformer=transformer)
         self.log_attention_weights = None
 
-    def attention_forward(self, hidden_states, mask, *other_tensors, layer_id=None):
+    def attention_forward(self, hidden_states, mask, mems=None, layer_id=None, **kwargs):
         attn_module = self.transformer.layers[layer_id].attention
-        mem = other_tensors[layer_id] if len(other_tensors) > 0 else None
+        mem = mems[layer_id] if mems is not None else None
 
         mixed_raw_layer = attn_module.query_key_value(hidden_states)
         (mixed_query_layer,
@@ -48,3 +48,4 @@ class CachedObjectModel(ObjectModel):
         new_mem = mixed_raw_layer.detach()[..., -(mixed_raw_layer.shape[-1] // 3 * 2):].contiguous()
 
         return output, new_mem
+
diff --git a/pretrain_objectModel.py b/pretrain_objectModel.py
index 0428db3b72aa3cfa67e9a8f1b7d6771de5d8ea7a..c1ecd96a73c0ffc3039d674a2fa5bcf4db3f3ea5 100644
--- a/pretrain_objectModel.py
+++ b/pretrain_objectModel.py
@@ -62,14 +62,15 @@ def get_masks_and_position_ids(data,
     position_ids = torch.zeros(batch_size, seq_length, dtype=torch.long,
                                 device=data.device)
 
+    # text and image parts keep the original position ids; the object part uses new position ids that start from zero
     for i in range(batch_size):
         torch.arange(64 - n_pads[i], out=position_ids[i, n_pads[i]:64],
-            dtype=torch.long, device=data.device)
-        torch.arange(180-object_pads[i], out=position_ids[i, 64+object_pads:64+180])
+            dtype=torch.long, device=data.device)  # text
+        torch.arange(180-object_pads[i], out=position_ids[i, 64+object_pads:64+180])  # object
         # breakpoint()
         torch.arange(64 - n_pads[i], 64 - n_pads[i] + seq_length - (64+180),
             out=position_ids[i, 64+180:],
-            dtype=torch.long, device=data.device)
+            dtype=torch.long, device=data.device)  # image
 
     return attention_mask, loss_mask, position_ids
 
@@ -97,7 +98,6 @@ def get_batch(data_iterator, args, timers):
     labels = tokens_[:, 1:].contiguous()
     loss_mask = loss_mask[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()
-
     attention_mask = None
 
     # Get the masks and postition ids.
@@ -112,7 +112,6 @@ def get_batch(data_iterator, args, timers):
     # Convert
     if args.fp16:
         attention_mask = attention_mask.half()
-
     return tokens, labels, loss_mask, attention_mask, position_ids
 
 
@@ -137,50 +136,51 @@ def forward_step(data_iterator, model, args, timers):
 
 def create_dataset_function(path, args):
     tokenizer = get_tokenizer()
-    layout = [100,164,1188]
+    layout = [100,164,1188]  # layout used in the .bin data file
     names = get_names()
     tokens = []
     for name in names:
         tokens.append(tokenizer.EncodeAsIds(name))
-    # 4 + 4 + 1 一个object需要9个token 20个需要 20*9 = 180个
+
+    # 4 coordinates + 4 text tokens + 1 [POS0]: each object takes 9 tokens; at most 20 objects, so 20 * 9 = 180 tokens
+    # layout of one sample:
+    # first 64      : [PAD] * n + [ROI] + text
+    # 64 ~ 64+180   : [PAD] * n + ([POS0] + xmin + ymin + xmax + ymax + class) * m
+    # 64+180 ~ end  : [BASE] [BOI] imageTokens [EOI]
     def process_fn(row):
         row = row.astype(np.int64)
         codes = [row[layout[1]:layout[2]]]
 
         text = row[layout[0]:layout[1]]
         text = text[text > 0][:63] # [ROI]
         object_tokens = []
-        # print(row[:100])
-        for i in range(20):
+
+        for i in range(20):  # at most 20 objects
            object = row[i * 5: (i+1) * 5]
             if object[0] == -1:
                 break
-            # print("object", object)
             object[2] += object[0]
             object[3] += object[1]
             object_tokens.append(tokenizer['[POS0]'])
-            object_tokens.extend([object[j] + args.old_token_num for j in range(4)])
+            object_tokens.extend([object[j] + args.old_token_num for j in range(4)])  # coordinate token id = original vocab size + coordinate value
             object_tokens.extend(tokens[object[4]])
+
         object_tokens = np.array(object_tokens)
-        # print(object_tokens)
-        object_pad = 180 - object_tokens.shape[-1]
+        object_pad = 180 - object_tokens.shape[-1]  # pad the object part to length 180
         object_tokens = np.concatenate([
             np.array([tokenizer['[PAD]']] * object_pad, dtype=np.int64),
             object_tokens
             ], axis = 0)
-        # 180
         text_object = np.concatenate([text, np.array(object_tokens, dtype=np.int64)], axis=0)
-        # print(len(text), len(text_object))
-        # print(len(codes[0]))
         merged = TextCodeTemplate(text_object, codes[0], tokenizer)
-        # print(len(merged), len(text))
-        n_pad = args.new_sequence_length - len(merged)
+
+        n_pad = args.new_sequence_length - len(merged)  # pad the whole sequence to the maximum length
         parts = [
             np.array([tokenizer['[PAD]']] * n_pad, dtype=np.int64),
             merged
             ]
         ret = np.concatenate(parts, axis=0)
         return {'text': ret,
-            'loss_mask': np.array([0] * n_pad + [1] * (len(text) + 1) + [0] * object_pad + [1] * (182 - object_pad + 1025)),
+            'loss_mask': np.array([0] * n_pad + [1] * (len(text) + 1) + [0] * object_pad + [1] * (182 - object_pad + 1025)),  # loss mask is 0 over both [PAD] segments
             'object_pad':object_pad,
             'n_pad':n_pad
             }
diff --git a/scripts/text2image_object.sh b/scripts/text2image_object.sh
index 296b9daab9a19eddc1a35c68bcab9fa0141abe2b..2b5429962aa7648714d376e659171fe32f168dec 100755
--- a/scripts/text2image_object.sh
+++ b/scripts/text2image_object.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
-CHECKPOINT_PATH=checkpoints/finetune2-object-test10-24-12-19
+CHECKPOINT_PATH=checkpoints/finetune2-object-test10-24-12-29
 NLAYERS=48
 NHIDDEN=2560
 NATT=40