From 7c5a12da328f6bd44ff98cdc9c2aa392db516fe6 Mon Sep 17 00:00:00 2001 From: Ming Ding <dm_thu@qq.com> Date: Fri, 8 Oct 2021 18:15:08 +0000 Subject: [PATCH] tmp finish naive ar sampling --- generation/__init__.py | 1 - generation/autoregressive_sampling.py | 99 +++++++++++ generation/cuda2d_sampling.py | 157 ++++++++++++++++++ generation/local_sampling.py | 147 ---------------- generation/sampling_strategies/__init__.py | 1 + .../sampling_strategies/base_strategy.py | 46 +++++ model/cached_autoregressive_model.py | 3 +- model/mixins.py | 4 +- mpu/transformer.py | 5 +- pretrain_cogview2.py | 2 +- scripts/finetune_into_cogview2.sh | 59 +++++++ training/model_io.py | 2 +- 12 files changed, 370 insertions(+), 156 deletions(-) create mode 100644 generation/autoregressive_sampling.py create mode 100644 generation/cuda2d_sampling.py delete mode 100644 generation/local_sampling.py create mode 100644 generation/sampling_strategies/__init__.py create mode 100644 generation/sampling_strategies/base_strategy.py create mode 100755 scripts/finetune_into_cogview2.sh diff --git a/generation/__init__.py b/generation/__init__.py index 94f6ea1..90c43c6 100755 --- a/generation/__init__.py +++ b/generation/__init__.py @@ -1,4 +1,3 @@ from .sampling import get_batch, filling_sequence, add_interlacing_beam_marks, inverse_prompt_score from .magnify import magnify -from .local_sampling import filling_sequence_local from .cuda_2d_sampling import filling_sequence_cuda_2d \ No newline at end of file diff --git a/generation/autoregressive_sampling.py b/generation/autoregressive_sampling.py new file mode 100644 index 0000000..f1a047c --- /dev/null +++ b/generation/autoregressive_sampling.py @@ -0,0 +1,99 @@ +# -*- encoding: utf-8 -*- +''' +@File : autoregressive_sampling.py +@Time : 2021/10/08 15:43:59 +@Author : Ming Ding +@Contact : dm18@mail.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +import torch +from .sampling_strategies import BaseStrategy + +def get_masks_and_position_ids(seq): + tokens = seq.unsqueeze(0) + + attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device) + attention_mask.tril_() + attention_mask.unsqueeze_(1) + + position_ids = torch.arange(len(seq), dtype=torch.long, device=tokens.device) + position_ids = position_ids.unsqueeze(0) + return tokens, attention_mask, position_ids + +def update_mems(hiddens, mems, max_memory_length): + if hiddens is None: + return [] + memory_length = mems[0].size(1) if mems else 0 + query_length = hiddens[0].size(1) + new_memory_length = min(max_memory_length, memory_length + query_length) + new_mems = [] + with torch.no_grad(): + for i in range(len(hiddens)): + if new_memory_length <= query_length: + new_mems.append(hiddens[i][:, -new_memory_length:]) + else: + new_mems.append(torch.cat((mems[i][:, -new_memory_length+query_length:], hiddens[i]), dim=1)) + return new_mems + + +def filling_sequence( + model, + seq, + batch_size, + max_memory_length=100000, + strategy=BaseStrategy() + ): + ''' + seq: [2, 3, 5, ..., -1(to be generated), -1, ...] 
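+        -1 entries are sampled autoregressively through `strategy`;
+        entries >= 0 are kept and forwarded as given context.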
+ ''' + assert len(seq.shape) == 1 + + # building the initial tokens, attention_mask, and position_ids + context_length = 0 + while seq[context_length] >= 0: + context_length += 1 # [0, context_length-1] are given + assert context_length > 0 + tokens, attention_mask, position_ids = get_masks_and_position_ids(seq) + tokens = tokens[..., :context_length] + + # initialize generation + counter = context_length - 1 # Last fixed index is ``counter'' + index = 0 # Next forward starting index, also the length of cache. + mems = [] # mems are the first-level citizens here, but we don't assume what is memorized. + + # step-by-step generation + while counter < len(seq) - 1: + # Now, we want to generate seq[counter + 1], + # token[:, index: counter+1] needs forwarding. + + if seq[counter + 1] >= 0: # provided + tokens = torch.cat( + ( + tokens, + seq[counter+1: counter+2].expand(tokens.shape[0], 1) + ), dim=1 + ) + counter += 1 + continue + + # forward + logits, *mem_kv = model( + tokens[:, index:], + position_ids[..., index: counter+1], + attention_mask[..., index: counter+1, :counter+1], # TODO mem + *mems + ) + mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length) + counter += 1 + index = counter + + # sampling + logits = logits[:, -1].expand(batch_size, -1) # [batch size, vocab size] + tokens, mems = strategy.forward(logits, tokens, mems) + + return tokens \ No newline at end of file diff --git a/generation/cuda2d_sampling.py b/generation/cuda2d_sampling.py new file mode 100644 index 0000000..bd75552 --- /dev/null +++ b/generation/cuda2d_sampling.py @@ -0,0 +1,157 @@ +# -*- encoding: utf-8 -*- +''' +@File : cuda2d_sampling.py +@Time : 2021/10/09 00:46:04 +@Author : Ming Ding +@Contact : dm18@mail.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +import torch +from .sampling_strategies import BaseStrategy + +def filling_sequence( + model, + seq0, + seq1, + warmup_steps=3, + block_hw=(4, 4), + strategy=BaseStrategy(topk=10) + ): + ''' + seq: [PAD]... [ROI1] text ... [BOI1] {layout[0]} 1024 {layout[1]} [EOI1] + 4095 {layout[2]} final_token + ''' + assert hasattr(model, 'layout') + layout = model.layout + assert len(seq0.shape) == 2 and len(seq1.shape) == 2 \ + and seq0.shape[0] == seq1.shape[0] + assert len(layout) == 3 + assert seq1.shape[1] == layout[-1] - layout[-2] + assert (seq1 >= 0).all() and (seq0 >= 0).all() + device = seq0.device + + # concat and pad sequences + batch_size = seq0.shape[0] + n_pad = layout[1] + 1 - len(seq0) # +1 for [EOI1] + assert n_pad > 0, "You should truncate long input before filling." 
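+    # Left-pad seq0 so the concatenated sequence reaches length layout[-1] + 1 (checked below);
+    # seq1 supplies the layout[2] - layout[1] image tokens to be iteratively refined.
+    # Note: len(seq0) of a 2-D tensor is its batch size; seq0.shape[1] is presumably intended here.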
+ seq = torch.cat(( + torch.tensor([0]*n_pad, device=device, dtype=seq0.dtype) + .unsqueeze(0).expand(batch_size, n_pad), + seq0, seq1), dim=1) # [b, layout[-1]+1] + assert seq.shape[1] == layout[-1] + 1 + + # build initial tokens, attention_mask, and position_ids + tokens = seq[:, :-1].clone() + attention_mask = torch.ones(layout[1], layout[1]).tril().to(device) + attention_mask[n_pad:, :n_pad] = 0 + position_ids = torch.cat(( + torch.zeros(n_pad, dtype=torch.long), + torch.arange(0, layout[1] - n_pad), + torch.arange(0, layout[2]-layout[1]))).to(device) + + # iterative refining + ll, rr = block_hw + num_steps = warmup_steps + ll + rr - 2 + for step_cnt in range(num_steps): + logits, *_dump = model(tokens, position_ids, attention_mask) + + # warmup + real_topk = 10 + warmup_steps = 3 + iterative_step= warmup_steps + 6 + if step_cnt <= warmup_steps: + real_temp = 0.1 + elif step_cnt == warmup_steps + 1: + real_temp = 0.55 + elif step_cnt > warmup_steps + 1: + real_temp = 0.45 + # if 5 < step_cnt: + # real_topk = 200 + # sampling + for invalid_slice in invalid_slices: # forbide to generate other tokens + logits[..., invalid_slice] = -float('Inf') + assert args.top_k > 0 + + # probs0 = F.softmax(logits/real_temp, dim=-1) + topraw = (torch.topk(logits, 5, dim=-1)[0]).softmax(dim=-1) + ent = -(topraw * topraw.log()).sum(dim=-1) + # topsum = topraw.sum(dim=-1) + if step_cnt > warmup_steps: + # import pdb;pdb.set_trace() + real_temp2 = torch.tensor([[[real_temp]]], device=logits.device).expand(*logits.shape[:2], 1) * (ent > 1.3).unsqueeze(-1) + 0.6 + # import pdb;pdb.set_trace() + else: + real_temp2 = real_temp + # import pdb;pdb.set_trace() + probs = F.softmax(logits/real_temp2, dim=-1) + tk_value, tk_idx = torch.topk(probs, real_topk, dim=-1) + prev = torch.multinomial(probs.view(-1, logits.shape[-1]), num_samples=1).view(*logits.shape[:2], 1) + edge_idx = tk_idx[:, :, -1:] + edge_value = tk_value[:, :, -1:] + edge_mask = probs.gather(dim=-1, index=prev) < edge_value + prev[edge_mask] = edge_idx[edge_mask] + prev.squeeze_(-1) + # tk_probs = (tk_value / real_temp).softmax(dim=-1).view(-1, tk_value.shape[-1]) + # prev = torch.multinomial(tk_probs, num_samples=1).view(*(tk_value.shape[:2]),1) + # prev = torch.gather(tk_idx, dim=-1, index=prev).squeeze(-1) + # update unfixed + choice = 1 + if choice == 0 and warmup_steps < step_cnt: + mprob = probs.max(dim=-1)[0].view(*(tk_value.shape[:2])) + # import pdb;pdb.set_trace() + dprob = mprob[:, 1:] < mprob[:, args.layout[1]:].topk(300, dim=-1, largest=False)[0][:,-1].unsqueeze(-1).expand_as(mprob[:, 1:]) + + new_fixed = unfixed.clone() + moved_new_fixed = new_fixed[:, 2:] + moved_new_fixed &= dprob + moved_new_fixed[:, 1:] &= dprob[:, :-1].logical_not() | unfixed[:, 2:-1].logical_not() + moved_new_fixed[:, 2:] &= dprob[:, :-2].logical_not() | unfixed[:, 2:-2].logical_not() + # moved_new_fixed[:, 3:] &= dprob[:, :-3].logical_not() | unfixed[:, 2:-3].logical_not() + moved_new_fixed[:, 64:] &= dprob[:, :-64].logical_not() | unfixed[:, 2:-64].logical_not() + moved_new_fixed[:, 65:] &= dprob[:, :-65].logical_not() | unfixed[:, 2:-65].logical_not() + # moved_new_fixed[:, 66:] &= dprob[:, :-66].logical_not() | unfixed[:, 2:-66].logical_not() + elif choice == 1 and warmup_steps < step_cnt: + new_fixed = unfixed & False + ll, rr = 4, 4 + for x in range(min(ll, step_cnt - warmup_steps)): + y = step_cnt - warmup_steps - x - 1 + if y < rr: + print(x,y) + new_fixed[..., -4096:].view(batch_size, 64//ll, ll, 64//rr, rr)[:, :, x, :, y] = True + new_fixed &= unfixed + 
else: + new_fixed = unfixed & False # TODO + new_fixed[:, -1] = True + + # with open(f'bed{step_cnt}.txt', 'w') as fout: + # for i, prob in enumerate(topraw[0, -4096:]): + # s = ' '.join([str(x) for x in prob.tolist()]) + # fout.write(f'{i} {s}\n') + + unfixed &= new_fixed.logical_not() + # update seq and tokens + seq[new_fixed] = prev[new_fixed[:, 1:]] + tokens = seq[:, :-1].clone() + tokens[:,1:][unfixed[:, 1:-1]] = prev[:, :-1][unfixed[:, 1:-1]] + + if step_cnt == iterative_step: + seq[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step + n_unfixed = unfixed.sum(dim=-1).tolist() + print(f'Exit with {n_unfixed} unfixed tokens.') + break + if args.debug: + from torchvision.utils import save_image + seqt = seq.clone() + seqt[:, :-1][unfixed[:, :-1]] = tokens[unfixed[:, :-1]] # if reach iterative_step + imgs.extend([tokenizer.img_tokenizer.DecodeIds(s[-4096:]) for s in seqt]) + if args.debug: + imgs = torch.cat(imgs, dim=0) + save_image(imgs, f'steps{device}.jpg', normalize=True) + model.module.transformer.max_memory_length = args.max_memory_length + + return seq \ No newline at end of file diff --git a/generation/local_sampling.py b/generation/local_sampling.py deleted file mode 100644 index 9fba71a..0000000 --- a/generation/local_sampling.py +++ /dev/null @@ -1,147 +0,0 @@ -from .sampling import * -import math -import sys -from copy import deepcopy - -def make_local_mask(sparse_config): - layout = sparse_config.layout - k1, k2 = sparse_config.kernel_size, sparse_config.kernel_size2 - k1h = k1*2-1 - h = w = int(math.sqrt(layout[2] - layout[1]) + 1e-3) - m = torch.zeros(layout[-1]+1, layout[-1], dtype=torch.bool, device='cuda') - for i in range(layout[1]): - m[i, :i] = True - m[layout[1]:, :layout[0]] = True - for i in tqdm(range(layout[1], layout[2])): - # m[i, layout[1]:i] = True - x = (i - layout[1]) // w - y = (i - layout[1]) % w - lx = max(0, x - k1h // 2) - ly = max(0, y - k1 // 2) - rx = min(h-1, x + k1h // 2) - ry = min(w-1, y + k1 // 2) - m[i, layout[1]:layout[2]].view(h, w)[lx:x, ly:ry+1] = True - m[i, layout[1]:layout[2]].view(h, w)[x, ly:y+1] = True - m[i, i] = False - for i in tqdm(range(layout[2], layout[3])): - x = (i - layout[2]) // (2*w) - y = (i - layout[2]) % (2*w) - lx = max(0, x - k1h // 2) - ly = max(0, y - k1 // 2) - rx = min(2*h-1, x + k1h // 2) - ry = min(2*w-1, y + k1 // 2) - m[i, layout[2]:layout[3]].view(h*2, w*2)[lx:x, ly:ry+1] = True - m[i, layout[2]:layout[3]].view(h*2, w*2)[x, ly:y+1] = True - x = x // 2 - y = y // 2 - lx = max(0, x - k2 // 2) - ly = max(0, y - k2 // 2) - rx = min(h-1, x + k2 // 2) - ry = min(w-1, y + k2 // 2) - m[i, layout[1]:layout[2]].view(h, w)[lx:rx+1, ly:ry+1] = True - m[i, i] = False - return m.unsqueeze(-1) - -def update_mems_local(hiddens, mems, start, end, mem_bag, mask): - if isinstance(hiddens, list): - hiddens = torch.stack(hiddens) - mem_bag[:, :, start:end] = hiddens.to('cuda') - # first level - del mems - mems = mem_bag.masked_select(mask[end]).view(*mem_bag.shape[:2], -1, mem_bag.shape[3]).to(hiddens.device) - return mems - - -def filling_sequence_local( - model, - seq, - args, - mems=None, - invalid_slices=[], - **kwargs): - ''' - seq: [2, 3, 5, ..., -1(to be generated), -N (N beams), -1] - context_length: first non(-1)s - ''' - loss_sum, loss_n = 0, 0.0001 - tokenizer = get_tokenizer() - device = seq.device - assert len(seq.shape) == 1 - assert args.sparse_config.sparse_type == 'standard' - sparse_config = deepcopy(args.sparse_config) - # with open('tmp_save.bin', 'rb') as fout: - # seq = 
torch.load(fout)[1] - # seq = torch.cat((seq, torch.tensor([-1], device=seq.device))) - # import pdb; pdb.set_trace() - # sparse_config.layout[0] = seq.tolist().index(-1) - sparse_config.layout[0] = seq.tolist().index(tokenizer['[BASE]']) - n_pad = sparse_config.layout[1] - sparse_config.layout[0] - assert n_pad > 0 # TODO long trunc - seq = torch.cat((seq[:sparse_config.layout[0]], torch.tensor([tokenizer['[POS8]']]* n_pad, device=seq.device, dtype=seq.dtype), seq[sparse_config.layout[0]:])) - out_seq_length = len(seq) - # building the initial tokens, attention_mask, and position_ids - context_length = sparse_config.layout[1] + 1 - - tokens, attention_mask, position_ids = get_batch(seq[:context_length], device, args) - tokens = tokens.expand(-min(seq), *tokens.shape[1:]) - - counter = context_length - 1 # == len(tokens) - 1 - index = 0 # len(mems) - if mems is None: - mems = [] - mem_bag = torch.zeros(args.num_layers, tokens.shape[0], out_seq_length-1, 2*args.hidden_size, device='cuda') - local_mask = make_local_mask(sparse_config) - - while counter < (out_seq_length - 1): - if counter % 100 == 0: - print(counter, loss_sum / loss_n, file=sys.stderr) - # Now, we want to generate seq[counter + 1] - # token[:, index: counter+1] are just added. - # import pdb;pdb.set_trace() - - if index == 0: # first - logits, *qkv = model(tokens, position_ids, attention_mask, *mems) - mems = update_mems_local(qkv, mems, index, counter+1, mem_bag, local_mask) - index = counter - else: - assert tokens.shape[1] == counter + 1 - position_ids = torch.arange(index, counter + 1, dtype=torch.long, device=tokens.device).unsqueeze(0) - logits, *qkv = model(tokens[:, index: ], - position_ids, - 0, # rebuild in transformers (sep version) - *mems - ) - mems = update_mems_local(qkv, mems, index, counter+1, mem_bag, local_mask) - index = counter - counter += 1 - index += 1 - - if seq[counter] >= 0: # provided - tokens = torch.cat((tokens, seq[counter: counter+1].expand(tokens.shape[0], 1)), dim=1) - loss_n +=1 - loss_this = F.log_softmax(logits, dim=-1)[:, -1, seq[counter]].mean() - print(counter-64, loss_this.item()) - loss_sum -= loss_this - continue - - logits = logits[:, -1] # [batch size, vocab size] - temp = args.temperature - logits /= temp - for invalid_slice in invalid_slices: # forbide to generate other tokens - logits[..., invalid_slice] = -float('Inf') - logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p) - log_probs = F.softmax(logits, dim=-1) - - # expand beams - prev = torch.multinomial(log_probs, num_samples=1) - tokens = torch.cat((tokens, prev.view(tokens.shape[0], 1)), dim=1) - - output_tokens_list = tokens.view(tokens.shape[0], -1).contiguous() - output_tokens_list = torch.cat( - ( - output_tokens_list[:, :sparse_config.layout[0]], - output_tokens_list[:, sparse_config.layout[1]+1:sparse_config.layout[2]+1], - torch.tensor([[tokenizer['[EOI1]']]], dtype=tokens.dtype, device=tokens.device).expand(output_tokens_list.shape[0], 1), - output_tokens_list[:, sparse_config.layout[2]+1:] - ), dim=1) - return output_tokens_list \ No newline at end of file diff --git a/generation/sampling_strategies/__init__.py b/generation/sampling_strategies/__init__.py new file mode 100644 index 0000000..f109603 --- /dev/null +++ b/generation/sampling_strategies/__init__.py @@ -0,0 +1 @@ +from .base_strategy import BaseStrategy \ No newline at end of file diff --git a/generation/sampling_strategies/base_strategy.py b/generation/sampling_strategies/base_strategy.py new file mode 100644 index 0000000..99a0a74 --- 
/dev/null +++ b/generation/sampling_strategies/base_strategy.py @@ -0,0 +1,46 @@ +# -*- encoding: utf-8 -*- +''' +@File : base_strategy.py +@Time : 2021/10/08 22:22:42 +@Author : Ming Ding +@Contact : dm18@mail.tsinghua.edu.cn +''' + +# here put the import lib +import os +import sys +import math +import random +import torch +import torch.nn.functional as F + +def top_k_logits_(logits, top_k=0, filter_value=-float('Inf')): + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + return logits + +class BaseStrategy: + def __init__(self, invalid_slices=[], temperature=1., topk=200, debias=False): + self.invalid_slices = invalid_slices + self.temperature = temperature + self.topk = topk + self.debias = debias + def forward(self, logits, tokens, mems, temperature=None): + if temperature is None: + temperature = self.temperature + logits = logits / temperature + for invalid_slice in self.invalid_slices: + logits[..., invalid_slice] = -float('Inf') + if self.debias: + probs = F.softmax(logits, dim=-1) + tk_value, tk_idx = torch.topk(probs, self.topk, dim=-1) + pred = torch.multinomial(probs, num_samples=1) + for j in range(0, pred.shape[0]): + if probs[j, pred[j,-1]] < tk_value[j, -1]: + pred[j, -1] = tk_idx[j, torch.randint(tk_idx.shape[-1]-100, tk_idx.shape[-1], (1,))] # 100 is the last N as outlier, which is chosen casually + else: + logits = top_k_logits_(logits) + probs = F.softmax(logits, dim=-1) + pred = torch.multinomial(probs, num_samples=1) + tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1) + return tokens, mems diff --git a/model/cached_autoregressive_model.py b/model/cached_autoregressive_model.py index ec2dd60..225ff3e 100755 --- a/model/cached_autoregressive_model.py +++ b/model/cached_autoregressive_model.py @@ -31,7 +31,8 @@ class CachedAutoregressiveModel(BaseModel): mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3) if mem is not None: # the first time, mem is None - memk, memv = split_tensor_along_last_dim(mem, 2) + b = mixed_key_layer.shape[0] # might change batch_size + memk, memv = split_tensor_along_last_dim(mem.expand(b, -1, -1), 2) mixed_key_layer = torch.cat((memk, mixed_key_layer), dim=1) mixed_value_layer = torch.cat((memv, mixed_value_layer), dim=1) diff --git a/model/mixins.py b/model/mixins.py index 0d304cf..78befb8 100644 --- a/model/mixins.py +++ b/model/mixins.py @@ -34,9 +34,9 @@ class PositionEmbeddingMixin(BaseMixin): torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std) def reinit(self, transformer, *pre_mixins): old_weights = transformer.position_embeddings.weight.data[self.reinit_slice] - old_len, hidden_size = old_weights.shape[0] + old_len, hidden_size = old_weights.shape assert hidden_size == self.position_embeddings.weight.shape[-1] - self.position_embeddings_plus.weight.data.view(-1, old_len, hidden_size).copy_(old_weights) + self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights) class AttentionMixin(BaseMixin): def __init__(self, num_layers, diff --git a/mpu/transformer.py b/mpu/transformer.py index 099dcb5..61f658e 100755 --- a/mpu/transformer.py +++ b/mpu/transformer.py @@ -1,4 +1,4 @@ -# coding=utf- +# coding=utf-8 # rewritten, Copyright (c) 2021, Ming Ding. All rights reserved. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 
# @@ -26,7 +26,6 @@ from .initialize import get_model_parallel_world_size from .layers import ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding from .mappings import gather_from_model_parallel_region, copy_to_model_parallel_region -import deepspeed from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint, get_cuda_rng_tracker from .utils import divide, sqrt, scaled_init_method, unscaled_init_method, gelu @@ -248,7 +247,7 @@ class BaseTransformerLayer(torch.nn.Module): # Layer norm post the self attention. layernorm_output = self.post_attention_layernorm(layernorm_input) # MLP. - mlp_output = self.mlp(layernorm_output, *other_tensors) + mlp_output = self.mlp(layernorm_output, *other_tensors) # Fourth LayerNorm if self.sandwich_ln: diff --git a/pretrain_cogview2.py b/pretrain_cogview2.py index 032fade..5fdb431 100755 --- a/pretrain_cogview2.py +++ b/pretrain_cogview2.py @@ -1,6 +1,6 @@ # -*- encoding: utf-8 -*- ''' -@File : pretrain_gpt2.py +@File : pretrain_cogview2.py @Time : 2021/10/06 00:58:32 @Author : Ming Ding @Contact : dm18@mail.tsinghua.edu.cn diff --git a/scripts/finetune_into_cogview2.sh b/scripts/finetune_into_cogview2.sh new file mode 100755 index 0000000..3bdf5ea --- /dev/null +++ b/scripts/finetune_into_cogview2.sh @@ -0,0 +1,59 @@ +#! /bin/bash + +# Change for multinode config + +NUM_WORKERS=1 +NUM_GPUS_PER_WORKER=8 +MP_SIZE=1 + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) +main_dir=$(dirname $script_dir) + +OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2" +HOST_FILE_PATH="hostfile" +HOST_FILE_PATH="hostfile_single" + +full_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/merge.bin" +small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/zijian/zijian.bin.part_0.cogdata" + +config_json="$script_dir/ds_config_zero.json" +gpt_options=" \ + --experiment-name finetune-cogview2-test \ + --tokenizer-type cogview \ + --img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \ + --model-parallel-size ${MP_SIZE} \ + --mode finetune \ + --num-layers 48 \ + --hidden-size 2560 \ + --num-attention-heads 40 \ + --train-iters 200000 \ + --resume-dataloader \ + --train-data ${full_data} \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr-decay-style cosine \ + --warmup .1 \ + --checkpoint-activations \ + --max-sequence-length 1089 \ + --sandwich-ln \ + --fp16 \ + --save-interval 2000 \ + --eval-interval 1000 \ + --save $main_dir/checkpoints \ + --load pretrained/cogview/cogview-base +" + # --load pretrained/cogview/cogview-base + + +gpt_options="${gpt_options} + --deepspeed \ + --deepspeed_config ${config_json} \ +" + + +run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_cogview2.py $@ ${gpt_options}" +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/training/model_io.py b/training/model_io.py index f7527f9..ee2e8f1 100644 --- a/training/model_io.py +++ b/training/model_io.py @@ -135,7 +135,7 @@ def load_checkpoint(model, args): if not args.do_train: raise ValueError(f'Missing keys for inference: {missing_keys}.') else: # new params - assert all(name.find('mixins')>0 for name in missing_keys) + assert all(name.find('mixins')>=0 for name in missing_keys) module.reinit() # initialize mixins model.optimizer.refresh_fp32_params() # restore fp32 weights from module -- GitLab
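Usage note (reviewer sketch, not part of the patch): a minimal example of driving the new sampling entry point added in generation/autoregressive_sampling.py. The `model` and `context_tokens` objects are assumed to exist (hypothetical names for a loaded CachedAutoregressiveModel and a 1-D LongTensor of prompt token ids); only filling_sequence and BaseStrategy come from this commit, and batch_size is kept at 1 for simplicity.

    import torch
    from generation.autoregressive_sampling import filling_sequence
    from generation.sampling_strategies import BaseStrategy

    # Hypothetical setup: `model` is a loaded CachedAutoregressiveModel,
    # `context_tokens` a 1-D LongTensor of prompt token ids.
    total_len = 1089  # matches --max-sequence-length in the finetune script above
    seq = torch.full((total_len,), -1, dtype=torch.long, device=context_tokens.device)
    seq[:len(context_tokens)] = context_tokens  # -1 marks positions still to be generated

    strategy = BaseStrategy(temperature=1.0, topk=200)
    output = filling_sequence(model, seq, batch_size=1, strategy=strategy)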