Commit f8d2632a authored by Ming Ding

tmp save dev to dzx review glm

parent ebd91ef0
Showing with 1011 additions and 682 deletions
......@@ -14,7 +14,7 @@ import random
import torch
from .sampling_strategies import BaseStrategy
def get_masks_and_position_ids(seq):
def get_masks_and_position_ids_default(seq):
tokens = seq.unsqueeze(0)
attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
......@@ -36,7 +36,6 @@ def update_mems(hiddens, mems, max_memory_length):
memory_length = mems.shape[2] if mems is not None else 0
query_length = hiddens.shape[2]
new_memory_length = min(max_memory_length, memory_length + query_length)
new_mems = []
with torch.no_grad():
if new_memory_length <= query_length:
return hiddens[:, :, -new_memory_length:]
......@@ -55,10 +54,16 @@ def filling_sequence(
batch_size,
strategy=BaseStrategy(),
max_memory_length=100000,
log_attention_weights=None
log_attention_weights=None,
get_masks_and_position_ids=get_masks_and_position_ids_default,
mems=None
):
'''
seq: [2, 3, 5, ..., -1 (to be generated), -1, ...]
mems: [num_layers, batch_size, len_mems(index), mem_hidden_size]
cache; the first mems.shape[2] positions of context_tokens are assumed to be cached.
mems are first-class citizens here, but we make no assumption about what is memorized.
input mems are only used for multi-phase generation.
'''
assert len(seq.shape) == 1
......@@ -72,9 +77,7 @@ def filling_sequence(
attention_mask = attention_mask.type_as(next(model.parameters())) # if fp16
# initialize generation
counter = context_length - 1 # the last fixed index is `counter`
index = 0 # Next forward starting index, also the length of cache.
mems = None # mems are the first-level citizens here, but we don't assume what is memorized.
index = 0 if mems is None else mems.shape[2] # Next forward starting index, also the length of cache.
# step-by-step generation
while counter < len(seq) - 1:
# Now, we want to generate seq[counter + 1],
......@@ -83,7 +86,7 @@ def filling_sequence(
if seq[counter + 1] >= 0: # provided
tokens = torch.cat(
(
tokens,
tokens,
seq[counter+1: counter+2].expand(tokens.shape[0], 1)
), dim=1
)
......@@ -92,13 +95,16 @@ def filling_sequence(
# forward
if log_attention_weights is not None:
model.log_attention_weights = log_attention_weights[..., index: counter+1, :counter+1] # TODO memlen
kw_tensors = {'mems': mems} if mems is not None else {}
log_attention_weights_part = log_attention_weights[..., index: counter+1, :counter+1] # TODO memlen
else:
log_attention_weights_part = None
logits, *mem_kv = model(
tokens[:, index:],
position_ids[..., index: counter+1],
attention_mask[..., index: counter+1, :counter+1], # TODO memlen
**kw_tensors # if no mems, cannot pass
mems=mems,
log_attention_weights=log_attention_weights_part
)
mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
counter += 1
......@@ -107,6 +113,6 @@ def filling_sequence(
logits = logits[:, -1].expand(batch_size, -1) # [batch size, vocab size]
tokens = tokens.expand(batch_size, -1)
tokens, mems = strategy.forward(logits, tokens, mems)
model.log_attention_weights = None
return tokens
\ No newline at end of file
if strategy.is_done:
break
return strategy.finalize(tokens, mems)
\ No newline at end of file
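As a usage note on the new signature: filling_sequence now accepts get_masks_and_position_ids and mems as keyword arguments and returns whatever strategy.finalize yields. Below is a minimal, hypothetical driver sketch (the generate helper and its argument values are illustrative, not part of this commit; the model is assumed to already have CachedAutoregressiveMixin attached):

import torch
from generation.autoregressive_sampling import filling_sequence, get_masks_and_position_ids_default
from generation.sampling_strategies import BaseStrategy

def generate(model, context_ids, total_length, end_tokens, device='cuda'):
    # -1 marks the positions that filling_sequence will generate
    seq = torch.tensor(context_ids + [-1] * (total_length - len(context_ids)),
                       dtype=torch.long, device=device)
    strategy = BaseStrategy(temperature=0.9, top_k=40, end_tokens=end_tokens)
    tokens, mems = filling_sequence(
        model, seq,
        batch_size=1,
        strategy=strategy,
        get_masks_and_position_ids=get_masks_and_position_ids_default,
        mems=None,  # pass previously returned mems here for multi-phase generation
    )
    return tokens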
......@@ -15,7 +15,7 @@ import torch
import torch.nn.functional as F
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-65504):
# This function has been mostly taken from huggingface conversational ai code at
# https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
......@@ -69,10 +69,11 @@ class BaseStrategy:
logits = top_k_logits(logits, self.topk, self.top_p)
probs = F.softmax(logits.float(), dim=-1) # .float() is essential, due to a bug in PyTorch
pred = torch.multinomial(probs, num_samples=1)
if pred.item() in self.end_tokens:
if pred.numel() == 1 and pred.item() in self.end_tokens:
self._is_done = True
tokens = torch.cat((tokens, pred.view(tokens.shape[0], 1)), dim=1)
return tokens, mems
def finalize(self, tokens, mems):
self._is_done = False
return tokens, mems
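Two details of this strategy change are easy to miss: -65504 is the most negative finite float16 value (so filtered logits stay finite under fp16), and the strategy now carries an is_done flag that filling_sequence checks each step before calling finalize. A small, self-contained check of the fp16 bound (the example values are made up):

import torch

# -65504 is the smallest finite float16 value; using it instead of -inf
# keeps filtered logits finite, so an fp16 pipeline cannot produce NaNs.
assert torch.finfo(torch.float16).min == -65504.0

logits = torch.tensor([2.0, 1.0, 0.1], dtype=torch.float16)
logits[2] = -65504  # a filtered position
probs = torch.softmax(logits.float(), dim=-1)  # .float() as in BaseStrategy
print(probs)  # the filtered position gets (numerically) zero probability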
# -*- encoding: utf-8 -*-
'''
@File : inference_cogview.py
@Time : 2021/10/09 19:41:58
@File : inference_glm.py
@Time : 2021/10/22 19:41:58
@Author : Ming Ding
@Contact : dm18@mail.tsinghua.edu.cn
'''
# here put the import lib
from functools import partial
import os
import sys
import random
......@@ -14,163 +15,138 @@ import time
from datetime import datetime
import torch
import torch.nn.functional as F
import argparse
import stat
import mpu
from functools import partial
from arguments import get_args
from model.glm_model import GLMModel
from model.cached_autoregressive_model import CachedAutoregressiveMixin
from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
from generation.glm_sampling import filling_sequence_glm
from generation.autoregressive_sampling import filling_sequence
from generation.sampling_strategies import BeamSearchStrategy, BaseStrategy
from generation.utils import timed_name, generate_continually
def get_masks_and_position_ids_glm(seq, mask_position, context_length):
tokens = seq.unsqueeze(0)
def read_context(tokenizer, args, output=None):
terminate_runs, skip_run = 0, 0
if mpu.get_model_parallel_rank() == 0:
while True:
raw_text = input("\nContext prompt (stop to exit) >>> ")
if not raw_text:
print('Prompt should not be empty!')
continue
if raw_text == "stop":
terminate_runs = 1
break
generation_mask = '[gMASK]' if args.task_mask else '[MASK]'
if args.block_lm and 'MASK]' not in raw_text:
raw_text += ' ' + generation_mask
if output is not None:
output.write(raw_text)
context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
if args.block_lm:
context_tokens = [tokenizer.get_command('ENC').Id] + context_tokens
if not raw_text.endswith('MASK]'):
context_tokens = context_tokens + [tokenizer.get_command('eos').Id]
context_length = len(context_tokens)
if context_length >= args.max_sequence_length:
print("\nContext length", context_length,
"\nPlease give smaller context than the window length!")
continue
break
else:
context_length = 0
terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
terminate_runs = terminate_runs_tensor[0].item()
if terminate_runs == 1:
return terminate_runs, None, None, None
context_length_tensor = torch.cuda.LongTensor([context_length])
torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
context_length = context_length_tensor[0].item()
if mpu.get_model_parallel_rank() == 0:
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
else:
context_tokens_tensor = torch.cuda.LongTensor([0] * context_length)
torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
if mpu.get_model_parallel_rank() != 0:
raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist())
return terminate_runs, raw_text, context_tokens_tensor, context_length
attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
attention_mask.tril_()
attention_mask.unsqueeze_(1)
position_ids = torch.zeros(2, len(seq), device=tokens.device, dtype=torch.long)
torch.arange(0, context_length, out=position_ids[0, :context_length])
position_ids[0, context_length:] = mask_position
torch.arange(1, len(seq) - context_length + 1, out=position_ids[1, context_length:])
def get_batch(context_tokens, args):
tokens = context_tokens
tokens = tokens.view(1, -1).contiguous()
tokens = tokens.to('cuda')
# Get the masks and position ids.
if args.block_lm:
attention_mask = torch.ones(tokens.size(1), tokens.size(1), device='cuda', dtype=torch.long)
if args.fp16:
attention_mask = attention_mask.half()
position_ids = torch.arange(tokens.size(1), device='cuda', dtype=torch.long)
if not args.no_block_position:
block_position_ids = torch.zeros(tokens.size(1), device='cuda', dtype=torch.long)
position_ids = torch.stack((position_ids, block_position_ids), dim=0)
position_ids = position_ids.unsqueeze(0)
else:
raise NotImplementedError
position_ids = position_ids.unsqueeze(0)
return tokens, attention_mask, position_ids
def generate_samples(model, tokenizer, args):
model.eval()
output_path = "./samples"
if not os.path.exists(output_path):
os.makedirs(output_path)
output_path = os.path.join(output_path, f"sample-{datetime.now().strftime('%m-%d-%H-%M')}.txt")
with torch.no_grad(), open(output_path, "w") as output:
while True:
torch.distributed.barrier(group=mpu.get_model_parallel_group())
terminate_runs, raw_text, context_tokens_tensor, context_length = read_context(tokenizer, args, output)
if terminate_runs == 1:
return
start_time = time.time()
if args.block_lm:
mems = []
tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args)
mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id]
mask_positions = []
for token in mask_tokens:
mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist()
mask_positions.sort()
if args.no_block_position:
for mask_position in mask_positions:
position_ids[0, mask_position + 1:] += args.out_seq_length
_, *mems = model(tokens, position_ids, attention_mask, *mems)
for mask_position in mask_positions:
if args.no_block_position:
position = position_ids[0, mask_position].item()
else:
position = mask_position
if args.num_beams > 1:
strategy = BeamSearchStrategy(num_beams=args.num_beams, max_length=args.out_seq_length,
length_penalty=args.length_penalty, end_tokens=end_tokens,
no_repeat_ngram_size=args.no_repeat_ngram_size,
min_tgt_length=args.min_tgt_length)
else:
strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
end_tokens=end_tokens)
new_tokens, mems = filling_sequence_glm(model, tokenizer, position, strategy, args, mems=mems,
end_tokens=end_tokens)
tokens = torch.cat((tokens, new_tokens), dim=1)
output_tokens_list = tokens.view(-1).contiguous()
if mpu.get_model_parallel_rank() == 0:
os.system('clear')
print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True)
decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
trim_decode_tokens = decode_tokens
print("\nGLM:", trim_decode_tokens, flush=True)
output.write(trim_decode_tokens + "\n")
torch.distributed.barrier(group=mpu.get_model_parallel_group())
def main(args):
initialize_distributed(args)
tokenizer = prepare_tokenizer(args)
# build model
# build model
model = GLMModel(args)
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
if args.fp16:
model = model.half()
model = model.to(args.device)
load_checkpoint(model, args)
set_random_seed(args.seed)
model.eval()
generate_samples(model, tokenizer, args)
end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id]
# define function for each query
strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k,end_tokens=end_tokens)
def process(raw_text):
if args.with_id:
query_id, raw_text = raw_text.split('\t')
# add MASK
generation_mask = '[gMASK]' if args.task_mask else '[MASK]'
if 'MASK]' not in raw_text:
raw_text += ' ' + generation_mask
seq = tokenizer.EncodeAsIds(raw_text).tokenization
seq = [tokenizer.get_command('ENC').Id] + seq
if not raw_text.endswith('MASK]'):
seq = seq + [tokenizer.get_command('eos').Id]
print('raw text: ', raw_text)
if len(seq) > args.max_sequence_length:
raise ValueError('text too long.')
# find mask tokens positions
mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
mask_positions = []
context_tokens_tensor = torch.tensor(seq, dtype=torch.long, device=args.device)
for token in mask_tokens:
mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist()
mask_positions.sort()
# generation
mbz = args.max_inference_batch_size
assert args.batch_size < mbz or args.batch_size % mbz == 0
output_list = []
# call for each position
for mp_idx, mask_position in enumerate(mask_positions):
get_func = partial(get_masks_and_position_ids_glm, mask_position=mask_position, context_length=len(seq))
for tim in range(max(args.batch_size // mbz, 1)):
input_seq = torch.cuda.LongTensor(seq + [tokenizer.get_command('sop').Id] + [-1] * (args.out_seq_length-len(seq)-1), device=args.device)
output, _mems = filling_sequence(model, input_seq,
batch_size=min(args.batch_size, mbz),
strategy=strategy,
log_attention_weights=None,
get_masks_and_position_ids=get_func
) # mems are not used here; fill the generated tokens back into seq
if isinstance(output, torch.Tensor): # different strategies
output = list(output)
output_list.extend(output)
# clip trailing -1s and fill the generated tokens back into seq
for i in range(len(output_list)):
output = output_list[i].tolist()
try:
unfinished = output.index(-1)
except ValueError:
unfinished = len(output)
bog = output.index(tokenizer.get_command('sop').Id)
output_list[i] = output[:mask_position] + output[bog+1:unfinished] + output[mask_position+1:bog]
# prepare the next auto-regressive generation
if mp_idx < len(mask_positions) - 1:
# TODO: select the best candidate at this step (inverse prompting?)
seq = output_list[0]
output_list = []
# decoding
txts = []
for seq in output_list:
decode_tokens = tokenizer.DecodeIds(seq)
txts.append(decode_tokens)
# save
if args.with_id:
full_path = os.path.join(args.output_path, query_id + '.txt')
else:
prefix = raw_text.replace('/', '')[:20]
full_path = timed_name(prefix, '.txt', args.output_path)
print(txts[0]) # print the first.
with open(full_path, 'w') as fout:
for txt in txts:
fout.write(txt + '\n')
os.chmod(full_path, stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU)
os.makedirs(args.output_path, exist_ok=True)
generate_continually(process, args.input_source)
if __name__ == "__main__":
args = get_args()
py_parser = argparse.ArgumentParser(add_help=False)
known, args_list = py_parser.parse_known_args()
args = get_args(args_list)
args = argparse.Namespace(**vars(args), **vars(known))
with torch.no_grad():
main(args)
main(args)
\ No newline at end of file
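For reference, here is a small self-contained illustration of the 2-D position ids that get_masks_and_position_ids_glm builds above, using toy lengths chosen only for this example: row 0 is frozen at the mask position over the generated span, row 1 counts block positions from 1.

import torch

seq_len, context_length, mask_position = 9, 5, 2  # toy values for illustration

position_ids = torch.zeros(2, seq_len, dtype=torch.long)
torch.arange(0, context_length, out=position_ids[0, :context_length])
position_ids[0, context_length:] = mask_position
torch.arange(1, seq_len - context_length + 1, out=position_ids[1, context_length:])

print(position_ids[0].tolist())  # [0, 1, 2, 3, 4, 2, 2, 2, 2]  absolute positions, frozen at the mask
print(position_ids[1].tolist())  # [0, 0, 0, 0, 0, 1, 2, 3, 4]  block positions inside the filled span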
# -*- encoding: utf-8 -*-
'''
@File : inference_cogview.py
@Time : 2021/10/09 19:41:58
@Author : Ming Ding
@Contact : dm18@mail.tsinghua.edu.cn
'''
# here put the import lib
import os
import sys
import random
import time
from datetime import datetime
import torch
import torch.nn.functional as F
import mpu
from arguments import get_args
from model.glm_model import GLMModel
from training import load_checkpoint, initialize_distributed, set_random_seed, prepare_tokenizer
from generation.glm_sampling import filling_sequence_glm
from generation.sampling_strategies import BeamSearchStrategy, BaseStrategy
def read_context(tokenizer, args, output=None):
terminate_runs, skip_run = 0, 0
if mpu.get_model_parallel_rank() == 0:
while True:
raw_text = input("\nContext prompt (stop to exit) >>> ")
if not raw_text:
print('Prompt should not be empty!')
continue
if raw_text == "stop":
terminate_runs = 1
break
generation_mask = '[gMASK]' if args.task_mask else '[MASK]'
if args.block_lm and 'MASK]' not in raw_text:
raw_text += ' ' + generation_mask
if output is not None:
output.write(raw_text)
context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
if args.block_lm:
context_tokens = [tokenizer.get_command('ENC').Id] + context_tokens
if not raw_text.endswith('MASK]'):
context_tokens = context_tokens + [tokenizer.get_command('eos').Id]
context_length = len(context_tokens)
if context_length >= args.max_sequence_length:
print("\nContext length", context_length,
"\nPlease give smaller context than the window length!")
continue
break
else:
context_length = 0
terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
terminate_runs = terminate_runs_tensor[0].item()
if terminate_runs == 1:
return terminate_runs, None, None, None
context_length_tensor = torch.cuda.LongTensor([context_length])
torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
context_length = context_length_tensor[0].item()
if mpu.get_model_parallel_rank() == 0:
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
else:
context_tokens_tensor = torch.cuda.LongTensor([0] * context_length)
torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
if mpu.get_model_parallel_rank() != 0:
raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist())
return terminate_runs, raw_text, context_tokens_tensor, context_length
def get_batch(context_tokens, args):
tokens = context_tokens
tokens = tokens.view(1, -1).contiguous()
tokens = tokens.to('cuda')
# Get the masks and position ids.
if args.block_lm:
attention_mask = torch.ones(tokens.size(1), tokens.size(1), device='cuda', dtype=torch.long)
if args.fp16:
attention_mask = attention_mask.half()
position_ids = torch.arange(tokens.size(1), device='cuda', dtype=torch.long)
if not args.no_block_position:
block_position_ids = torch.zeros(tokens.size(1), device='cuda', dtype=torch.long)
position_ids = torch.stack((position_ids, block_position_ids), dim=0)
position_ids = position_ids.unsqueeze(0)
else:
raise NotImplementedError
return tokens, attention_mask, position_ids
def generate_samples(model, tokenizer, args):
model.eval()
output_path = "./samples"
if not os.path.exists(output_path):
os.makedirs(output_path)
output_path = os.path.join(output_path, f"sample-{datetime.now().strftime('%m-%d-%H-%M')}.txt")
with torch.no_grad(), open(output_path, "w") as output:
while True:
torch.distributed.barrier(group=mpu.get_model_parallel_group())
terminate_runs, raw_text, context_tokens_tensor, context_length = read_context(tokenizer, args, output)
if terminate_runs == 1:
return
start_time = time.time()
if args.block_lm:
mems = []
tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args)
mask_tokens = ['MASK', 'sMASK', 'gMASK'] if args.task_mask else ['MASK']
mask_tokens = [tokenizer.get_command(token).Id for token in mask_tokens]
end_tokens = [tokenizer.get_command('eop').Id, tokenizer.get_command('eos').Id]
mask_positions = []
for token in mask_tokens:
mask_positions += (context_tokens_tensor == token).nonzero(as_tuple=True)[0].tolist()
mask_positions.sort()
if args.no_block_position:
for mask_position in mask_positions:
position_ids[0, mask_position + 1:] += args.out_seq_length
_, *mems = model(tokens, position_ids, attention_mask, *mems)
for mask_position in mask_positions:
if args.no_block_position:
position = position_ids[0, mask_position].item()
else:
position = mask_position
if args.num_beams > 1:
strategy = BeamSearchStrategy(num_beams=args.num_beams, max_length=args.out_seq_length,
length_penalty=args.length_penalty, end_tokens=end_tokens,
no_repeat_ngram_size=args.no_repeat_ngram_size,
min_tgt_length=args.min_tgt_length)
else:
strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, top_p=args.top_p,
end_tokens=end_tokens)
new_tokens, mems = filling_sequence_glm(model, tokenizer, position, strategy, args, mems=mems,
end_tokens=end_tokens)
tokens = torch.cat((tokens, new_tokens), dim=1)
output_tokens_list = tokens.view(-1).contiguous()
if mpu.get_model_parallel_rank() == 0:
os.system('clear')
print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
print("\nContext:", raw_text, flush=True)
decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
trim_decode_tokens = decode_tokens
print("\nGLM:", trim_decode_tokens, flush=True)
output.write(trim_decode_tokens + "\n")
torch.distributed.barrier(group=mpu.get_model_parallel_group())
def main(args):
initialize_distributed(args)
tokenizer = prepare_tokenizer(args)
# build model
model = GLMModel(args)
if args.fp16:
model = model.half()
model = model.to(args.device)
load_checkpoint(model, args)
set_random_seed(args.seed)
model.eval()
generate_samples(model, tokenizer, args)
if __name__ == "__main__":
args = get_args()
with torch.no_grad():
main(args)
......@@ -14,11 +14,13 @@ import random
import torch
from mpu import BaseTransformer
from .mixins import BaseMixin
class BaseModel(torch.nn.Module):
def __init__(self, args, transformer=None):
super(BaseModel, self).__init__()
self.hooks = self.collect_hooks()
self.mixins = torch.nn.ModuleDict()
self.collect_hooks_()
if transformer is not None:
self.transformer = transformer
else:
......@@ -37,12 +39,25 @@ class BaseModel(torch.nn.Module):
parallel_output=True,
hooks=self.hooks
)
self.mixins = torch.nn.ModuleList()
def reinit(self):
def reinit(self): # will be called when loading model
# if some mixins are loaded, override this function
for m in self.mixins:
for m in self.mixins.values():
m.reinit(self.transformer)
def add_mixin(self, name, new_mixin, reinit=False):
assert name not in self.mixins
assert isinstance(new_mixin, BaseMixin)
self.mixins[name] = new_mixin # will auto-register parameters
object.__setattr__(new_mixin, 'transformer', self.transformer) # bypass nn.Module.__setattr__ so the transformer is not registered as a submodule of the mixin
if reinit:
new_mixin.reinit(self.transformer, **self.mixins) # also pass current mixins
self.collect_hooks_()
def get_mixin(self, name):
return self.mixins[name]
def forward(self, *args, **kwargs):
# update hooks to those of the current model (overridden forwards)
......@@ -51,16 +66,28 @@ class BaseModel(torch.nn.Module):
self.transformer.hooks.update(self.hooks)
return self.transformer(*args, **kwargs)
def collect_hooks(self):
def collect_hooks_(self):
names = ['word_embedding_forward', 'position_embedding_forward',
'attention_forward', 'mlp_forward', 'final_forward', 'layer_forward',
'branch_embedding_forward', 'branch_final_forward'
]
hooks = {}
hook_origins = {}
for name in names:
for mixin_name, m in self.mixins.items():
if hasattr(m, name):
if name in hooks: # conflict
raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.')
hooks[name] = getattr(m, name)
hook_origins[name] = mixin_name
if hasattr(self, name):
# if name in hooks: # defined in mixins, can override
# print(f'Override {name} in {hook_origins[name]}...')
hooks[name] = getattr(self, name)
hook_origins[name] = 'model'
self.hooks = hooks
self.hook_origins = hook_origins
return hooks
def disable_untrainable_params(self):
pass
\ No newline at end of file
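To illustrate the new add_mixin / hook mechanism, here is a hypothetical sketch: only the hook name final_forward, the registration API, and hook_origins come from this commit; the mixin itself is invented, and model-parallel handling is ignored for brevity.

import torch.nn.functional as F
from model.glm_model import GLMModel
from model.mixins import BaseMixin

class ScaledFinalForwardMixin(BaseMixin):
    # hypothetical mixin: rescale the final hidden states before the LM head
    def __init__(self, scale=0.5):
        super().__init__()
        self.scale = scale

    # a method whose name matches one of the hook names above is picked up by collect_hooks_()
    def final_forward(self, logits, **kw_args):
        # same projection as the default path, minus the model-parallel copy/gather
        return F.linear(logits * self.scale, self.transformer.word_embeddings.weight)

model = GLMModel(args)  # `args` as produced by get_args(), as elsewhere in this commit
model.add_mixin('scaled-final', ScaledFinalForwardMixin())
assert model.hook_origins['final_forward'] == 'scaled-final'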
......@@ -13,15 +13,15 @@ import math
import random
import torch
from .mixins import BaseMixin
from .base_model import BaseModel
from mpu.transformer import standard_attention, split_tensor_along_last_dim
class CachedAutoregressiveModel(BaseModel):
def __init__(self, args, transformer=None):
super().__init__(args, transformer=transformer)
self.log_attention_weights = None
class CachedAutoregressiveMixin(BaseMixin):
def __init__(self):
super().__init__()
def attention_forward(self, hidden_states, mask, mems=None, layer_id=None, **kwargs):
def attention_forward(self, hidden_states, mask, mems=None, layer_id=None, log_attention_weights=None, **kwargs):
attn_module = self.transformer.layers[layer_id].attention
mem = mems[layer_id] if mems is not None else None
......@@ -40,7 +40,7 @@ class CachedAutoregressiveModel(BaseModel):
query_layer = attn_module._transpose_for_scores(mixed_query_layer)
key_layer = attn_module._transpose_for_scores(mixed_key_layer)
value_layer = attn_module._transpose_for_scores(mixed_value_layer)
context_layer = standard_attention(query_layer, key_layer, value_layer, mask, None, log_attention_weights=self.log_attention_weights)
context_layer = standard_attention(query_layer, key_layer, value_layer, mask, None, log_attention_weights=log_attention_weights)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (attn_module.hidden_size_per_partition,)
......@@ -51,3 +51,8 @@ class CachedAutoregressiveModel(BaseModel):
new_mem = mixed_raw_layer.detach()[..., -(mixed_raw_layer.shape[-1] // 3 * 2):].contiguous()
return output, new_mem
class CachedAutoregressiveModel(BaseModel):
def __init__(self, args, transformer=None):
super().__init__(args, transformer=transformer)
self.add_mixin('auto-regressive', CachedAutoregressiveMixin())
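Design note: with CachedAutoregressiveMixin, incremental decoding is now attached by composition rather than by inheriting from CachedAutoregressiveModel; the wrapper class above simply pre-registers the mixin. This mirrors the call already made in inference_glm.py earlier in this diff:

from model.glm_model import GLMModel
from model.cached_autoregressive_model import CachedAutoregressiveMixin

model = GLMModel(args)  # any BaseModel subclass works the same way
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())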
......@@ -28,10 +28,10 @@ class Cuda2dModel(BaseModel):
def __init__(self, args, transformer=None):
super().__init__(args, transformer=transformer)
additional_seqlen = args.new_sequence_length - args.max_sequence_length
self.mixins.append(PositionEmbeddingMixin(
self.add_mixin('extra_position_embedding', PositionEmbeddingMixin(
additional_seqlen, args.hidden_size
))
self.mixins.append(AttentionMixin(
self.add_mixin('attention_plus', AttentionMixin(
num_layers=args.num_layers,
hidden_size=args.hidden_size
))
......@@ -41,23 +41,24 @@ class Cuda2dModel(BaseModel):
self.kernel_size2 = args.kernel_size2
self.log_attention_weights = None
def position_embedding_forward(self, position_ids, **kw_tensors):
def position_embedding_forward(self, position_ids, **kw_args):
position = position_ids[..., :self.layout[1]]
position_plus = position_ids[..., self.layout[1]:]
position_embeddings = torch.cat(
(
self.transformer.position_embeddings(position),
self.mixins[0].position_embeddings(position_plus)
self.get_mixin('extra_position_embedding').position_embeddings(position_plus)
),
dim=-2
)
return position_embeddings
def attention_forward(self, hidden_states, mask, layer_id=None, **kw_tensors):
def attention_forward(self, hidden_states, mask,
layer_id=None, log_attention_weights=None, **kw_args):
attn_module = self.transformer.layers[layer_id].attention
# attention_plus on all layers
query_key_value_plus = self.mixins[1].query_key_value[layer_id]
dense_plus = self.mixins[1].dense[layer_id]
query_key_value_plus = self.get_mixin('attention_plus').query_key_value[layer_id]
dense_plus = self.get_mixin('attention_plus').dense[layer_id]
# split two parts
hidden_states_plus = hidden_states[:, self.layout[1]:]
......@@ -81,7 +82,7 @@ class Cuda2dModel(BaseModel):
kernel_size=self.kernel_size,
kernel_size2=self.kernel_size2,
attention_dropout=dropout_fn,
log_attention_weights=self.log_attention_weights
log_attention_weights=log_attention_weights
)
output_0 = attn_module.dense(context_layer0)
......
......@@ -3,16 +3,26 @@ import torch.nn as nn
from .base_model import BaseModel
from .cached_autoregressive_model import CachedAutoregressiveModel
from .mixins import BaseMixin
class GLMModel(CachedAutoregressiveModel):
def __init__(self, args, transformer=None):
super().__init__(args, transformer=transformer)
self.transformer.block_position_embeddings = torch.nn.Embedding(args.max_sequence_length, args.hidden_size)
torch.nn.init.normal_(self.transformer.block_position_embeddings.weight, mean=0.0, std=0.02)
def position_embedding_forward(self, position_ids, *other_tensors):
class BlockPositionEmbeddingMixin(BaseMixin):
def __init__(self, max_sequence_length, hidden_size, init_method_std=0.02):
super(BlockPositionEmbeddingMixin, self).__init__()
self.max_sequence_length = max_sequence_length
self.hidden_size = hidden_size
self.block_position_embeddings = torch.nn.Embedding(max_sequence_length, hidden_size)
torch.nn.init.normal_(self.block_position_embeddings.weight, mean=0.0, std=init_method_std)
def position_embedding_forward(self, position_ids, **kwargs):
position_ids, block_position_ids = position_ids[:, 0], position_ids[:, 1]
position_embeddings = self.transformer.position_embeddings(position_ids)
block_position_embeddings = self.transformer.block_position_embeddings(block_position_ids)
block_position_embeddings = self.block_position_embeddings(block_position_ids)
return position_embeddings + block_position_embeddings
class GLMModel(BaseModel):
def __init__(self, args, transformer=None):
super().__init__(args, transformer=transformer)
self.add_mixin('block_position_embedding',
BlockPositionEmbeddingMixin(args.max_sequence_length, args.hidden_size)
)
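A small sketch of what BlockPositionEmbeddingMixin's position_embedding_forward computes; the sizes below are made up for illustration, and the two Embedding tables stand in for transformer.position_embeddings and the mixin's block_position_embeddings.

import torch

hidden_size = 8  # toy value
position_embeddings = torch.nn.Embedding(16, hidden_size)        # stands in for transformer.position_embeddings
block_position_embeddings = torch.nn.Embedding(16, hidden_size)  # stands in for the mixin's table

# [batch, 2, seq_len]: row 0 = absolute positions, row 1 = block positions
position_ids = torch.tensor([[[0, 1, 2, 2, 2],
                              [0, 0, 0, 1, 2]]])
pos, block = position_ids[:, 0], position_ids[:, 1]
embeddings = position_embeddings(pos) + block_position_embeddings(block)
print(embeddings.shape)  # torch.Size([1, 5, 8])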
......@@ -20,9 +20,11 @@ class BaseMixin(torch.nn.Module):
def __init__(self):
super(BaseMixin, self).__init__()
# define new params
def reinit(self, transformer, *pre_mixins):
def reinit(self, *pre_mixins):
# reload the initial params from previously trained modules
pass
# can also define hook-functions here
# ...
class PositionEmbeddingMixin(BaseMixin):
def __init__(self, additional_sequence_length, hidden_size,
......@@ -32,8 +34,8 @@ class PositionEmbeddingMixin(BaseMixin):
self.reinit_slice = reinit_slice
self.position_embeddings = torch.nn.Embedding(additional_sequence_length, hidden_size)
torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)
def reinit(self, transformer, *pre_mixins):
old_weights = transformer.position_embeddings.weight.data[self.reinit_slice]
def reinit(self, *pre_mixins):
old_weights = self.transformer.position_embeddings.weight.data[self.reinit_slice]
old_len, hidden_size = old_weights.shape
assert hidden_size == self.position_embeddings.weight.shape[-1]
self.position_embeddings.weight.data.view(-1, old_len, hidden_size).copy_(old_weights)
......@@ -58,11 +60,11 @@ class AttentionMixin(BaseMixin):
init_method=output_layer_init_method)
for layer_id in range(num_layers)
])
def reinit(self, transformer, *pre_mixins):
start_layer = len(transformer.layers) - self.num_layers
def reinit(self, *pre_mixins):
start_layer = len(self.transformer.layers) - self.num_layers
assert start_layer >= 0
for layer_id in range(self.num_layers):
old_attention = transformer.layers[start_layer + layer_id].attention
old_attention = self.transformer.layers[start_layer + layer_id].attention
self.query_key_value[layer_id].weight.data.copy_(old_attention.query_key_value.weight.data)
self.query_key_value[layer_id].bias.data.copy_(old_attention.query_key_value.bias.data)
self.dense[layer_id].weight.data.copy_(old_attention.dense.weight.data)
......
......@@ -111,9 +111,9 @@ class SelfAttention(torch.nn.Module):
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)
def forward(self, hidden_states, mask, **kw_tensors):
def forward(self, hidden_states, mask, **kw_args):
if 'attention_forward' in self.hooks:
return self.hooks['attention_forward'](hidden_states, mask, **kw_tensors, layer_id=self.layer_id)
return self.hooks['attention_forward'](hidden_states, mask, **kw_args, layer_id=self.layer_id)
else:
mixed_raw_layer = self.query_key_value(hidden_states)
(mixed_query_layer,
......@@ -162,9 +162,9 @@ class MLP(torch.nn.Module):
)
self.dropout = torch.nn.Dropout(output_dropout_prob)
def forward(self, hidden_states, **kw_tensors):
def forward(self, hidden_states, **kw_args):
if 'mlp_forward' in self.hooks:
output = self.hooks['mlp_forward'](hidden_states, **kw_tensors, layer_id=self.layer_id)
output = self.hooks['mlp_forward'](hidden_states, **kw_args, layer_id=self.layer_id)
else:
intermediate_parallel = self.dense_h_to_4h(hidden_states)
intermediate_parallel = gelu(intermediate_parallel)
......@@ -227,7 +227,7 @@ class BaseTransformerLayer(torch.nn.Module):
hooks=hooks
)
def forward(self, hidden_states, mask, **kw_tensors):
def forward(self, hidden_states, mask, **kw_args):
'''
hidden_states: [batch, seq_len, hidden_size]
mask: [(1, 1), seq_len, seq_len]
......@@ -236,7 +236,7 @@ class BaseTransformerLayer(torch.nn.Module):
# Layer norm at the beginning of the transformer layer.
layernorm_output1 = self.input_layernorm(hidden_states)
# Self attention.
attention_output, output_this_layer = self.attention(layernorm_output1, mask, **kw_tensors)
attention_output, output_this_layer = self.attention(layernorm_output1, mask, **kw_args)
# Third LayerNorm
if self.sandwich_ln:
......@@ -247,7 +247,7 @@ class BaseTransformerLayer(torch.nn.Module):
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
# MLP.
mlp_output = self.mlp(layernorm_output, **kw_tensors)
mlp_output = self.mlp(layernorm_output, **kw_args)
# Fourth LayerNorm
if self.sandwich_ln:
......@@ -316,27 +316,25 @@ class BaseTransformer(torch.nn.Module):
# Final layer norm before output.
self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, **kw_tensors):
def forward(self, input_ids, position_ids, attention_mask, *, branch_input=None, **kw_args):
# sanity check
assert len(input_ids.shape) == 2
batch_size, query_length = input_ids.shape
assert len(attention_mask.shape) == 2 or \
len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1
assert branch_input is None or 'layer_forward' in self.hooks and isinstance(branch_input, torch.Tensor)
for k, v in kw_tensors.items():
assert isinstance(v, torch.Tensor)
# branch_input is a new part of the input that needs layer-by-layer updates,
# but with a different hidden_dim and computational routine.
# In most cases, you can just ignore it.
# embedding part
if 'word_embedding_forward' in self.hooks:
hidden_states = self.hooks['word_embedding_forward'](input_ids, **kw_tensors)
hidden_states = self.hooks['word_embedding_forward'](input_ids, **kw_args)
else: # default
hidden_states = self.word_embeddings(input_ids)
if 'position_embedding_forward' in self.hooks:
position_embeddings = self.hooks['position_embedding_forward'](position_ids, **kw_tensors)
position_embeddings = self.hooks['position_embedding_forward'](position_ids, **kw_args)
else:
assert len(position_ids.shape) <= 2
assert position_ids.shape[-1] == query_length
......@@ -346,7 +344,7 @@ class BaseTransformer(torch.nn.Module):
# branch related embedding
if branch_input is None and 'branch_embedding_forward' in self.hooks:
branch_input = self.hooks['branch_embedding_forward'](branch_input, **kw_tensors)
branch_input = self.hooks['branch_embedding_forward'](branch_input, **kw_args)
# define custom_forward for checkpointing
output_per_layers = []
......@@ -361,14 +359,14 @@ class BaseTransformer(torch.nn.Module):
for i, layer in enumerate(layers_):
if len(inputs) > 2:
x_, branch_, output_this_layer = self.hooks['layer_forward'](
x_, mask, layer_id=layer.layer_id, branch_input=branch_, **kw_tensors
x_, mask, layer_id=layer.layer_id, branch_input=branch_, **kw_args
)
elif 'layer_forward' in self.hooks:
x_, output_this_layer = self.hooks['layer_forward'](
x_, mask, layer_id=layer.layer_id, **kw_tensors
x_, mask, layer_id=layer.layer_id, **kw_args
)
else:
x_, output_this_layer = layer(x_, mask, **kw_tensors)
x_, output_this_layer = layer(x_, mask, **kw_args)
output_per_layers_part.append(output_this_layer)
return x_, output_per_layers_part
return custom_forward
......@@ -387,25 +385,25 @@ class BaseTransformer(torch.nn.Module):
for i, layer in enumerate(self.layers):
args = [hidden_states, attention_mask]
if branch_input is not None: # customized layer_forward with branch_input
hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), branch_input=branch_input, **kw_tensors)
hidden_states, branch_input, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), branch_input=branch_input, **kw_args)
elif 'layer_forward' in self.hooks: # customized layer_forward
hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), **kw_tensors)
hidden_states, output_this_layer = self.hooks['layer_forward'](*args, layer_id=torch.tensor(i), **kw_args)
else:
hidden_states, output_this_layer = layer(*args, **kw_tensors)
hidden_states, output_this_layer = layer(*args, **kw_args)
output_per_layers.append(output_this_layer)
# Final layer norm.
logits = self.final_layernorm(hidden_states)
if 'final_forward' in self.hooks:
logits_parallel = self.hooks['final_forward'](logits, **kw_tensors)
logits_parallel = self.hooks['final_forward'](logits, **kw_args)
else:
logits_parallel = copy_to_model_parallel_region(logits)
logits_parallel = F.linear(logits_parallel, self.word_embeddings.weight)
# branch related embedding
if branch_input is None and 'branch_final_forward' in self.hooks:
branch_input = self.hooks['branch_final_forward'](branch_input, **kw_tensors)
branch_input = self.hooks['branch_final_forward'](branch_input, **kw_args)
if self.parallel_output:
logits_parallel = gather_from_model_parallel_region(logits_parallel)
......
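Note on the rename kw_tensors -> kw_args: the isinstance(v, torch.Tensor) assertion is gone, so non-tensor keywords (e.g. None, or a sliced log_attention_weights) now flow unchanged from model(...) into every registered hook. A hypothetical helper showing the call shape, with arguments built as in filling_sequence above:

def forward_with_cache(model, tokens, position_ids, attention_mask, mems=None):
    # any keyword argument passed here is forwarded verbatim to the hooks
    logits, *mem_kv = model(
        tokens, position_ids, attention_mask,
        mems=mems,                    # consumed by CachedAutoregressiveMixin.attention_forward
        log_attention_weights=None,   # non-tensors are allowed now that the assert is removed
    )
    return logits, mem_kv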
......@@ -50,7 +50,7 @@ gpt_options="${gpt_options}
--deepspeed \
--deepspeed_config ${config_json} \
"
run_cmd="${OPTIONS_NCCL} deepspeed --num_nodes ${NUM_WORKERS} --num_gpus ${NUM_GPUS_PER_WORKER} --hostfile ${HOST_FILE_PATH} pretrain_cogview2.py $@ ${gpt_options}"
echo ${run_cmd}
......
#!/bin/bash
CHECKPOINT_PATH=/dataset/fd5061f6/english_data/checkpoints
CHECKPOINT_PATH=pretrained/glm
source $1
# MODEL_ARGS="--block-lm \
# --cloze-eval \
# --num-layers 24 \
# --hidden-size 1024 \
# --num-attention-heads 16 \
# --max-sequence-length 513 \
# --tokenizer-model-type roberta \
# --tokenizer-type glm_GPT2BPETokenizer \
# --load ${CHECKPOINT_PATH}/glm-roberta-large-blank"
MODEL_TYPE="blocklm-10B"
MODEL_ARGS="--block-lm \
--cloze-eval \
--task-mask \
--num-layers 48 \
--hidden-size 4096 \
--num-attention-heads 64 \
--max-sequence-length 1025 \
--tokenizer-model-type gpt2 \
--tokenizer-type glm_GPT2BPETokenizer \
--old-checkpoint \
--load ${CHECKPOINT_PATH}/glm-en-10b"
MPSIZE=1
MAXSEQLEN=512
......@@ -29,4 +50,7 @@ python -m torch.distributed.launch --nproc_per_node=$MPSIZE --master_port $MASTE
--out-seq-length $MAXSEQLEN \
--temperature $TEMP \
--top_k $TOPK \
--top_p $TOPP
--output-path glm_text \
--batch-size 1 \
--out-seq-length 100 \
--mode inference
......@@ -19,17 +19,17 @@ small_data="/dataset/fd5061f6/cogview/cogdata_new/cogdata_task_4leveltokens/ziji
config_json="$script_dir/ds_config_zero.json"
gpt_options=" \
--experiment-name pretrain-gpt2-cogview-test \
--experiment-name pretrain-gpt2-cogview-small \
--tokenizer-type cogview \
--img-tokenizer-path pretrained/vqvae/vqvae_hard_biggerset_011.pt \
--model-parallel-size ${MP_SIZE} \
--mode pretrain \
--num-layers 12 \
--hidden-size 1024 \
--num-attention-heads 16 \
--num-layers 40 \
--hidden-size 2048 \
--num-attention-heads 32 \
--train-iters 200000 \
--resume-dataloader \
--train-data ${small_data} \
--train-data ${full_data} \
--split 949,50,1 \
--distributed-backend nccl \
--lr-decay-style cosine \
......@@ -38,9 +38,9 @@ gpt_options=" \
--max-sequence-length 1089 \
--sandwich-ln \
--fp16 \
--save-interval 2000 \
--save-interval 5000 \
--eval-interval 1000 \
--save $main_dir/checkpoints \
--save /root/checkpoints \
"
# --load pretrained/cogview/cogview-base
......
......@@ -33,7 +33,7 @@ MASTER_PORT=${MASTER_PORT} python inference_cogview.py \
--sandwich-ln \
--input-source ./input.txt \
--output-path samples_text2image \
--batch-size 8 \
--batch-size 4 \
--max-inference-batch-size 8 \
$@
......
......@@ -15,8 +15,6 @@ import torch
def get_tokenizer(args=None):
kwargs = {"add_block_symbols": args.block_lm, "add_task_mask": args.task_mask,
"add_decoder_mask": args.block_mask_prob > 0.0 or args.context_mask_ratio > 0.0}
if not hasattr(get_tokenizer, 'tokenizer'):
# the first time to load the tokenizer
if args.tokenizer_type == 'cogview':
......@@ -25,15 +23,18 @@ def get_tokenizer(args=None):
args.img_tokenizer_path,
device=torch.cuda.current_device()
)
elif args.tokenizer_type == "BertWordPieceTokenizer":
from .text import BertWordPieceTokenizer
get_tokenizer.tokenizer = BertWordPieceTokenizer(args.tokenizer_model_type, **kwargs)
elif args.tokenizer_type == "GPT2BPETokenizer":
from .text import GPT2BPETokenizer
get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs)
elif args.tokenizer_type == "ChineseSPTokenizer":
from .text import ChineseSPTokenizer
get_tokenizer.tokenizer = ChineseSPTokenizer(**kwargs)
elif args.tokenizer_type.startswith('glm_'):
kwargs = {"add_block_symbols": args.block_lm, "add_task_mask": args.task_mask,
"add_decoder_mask": False} #args.block_mask_prob > 0.0 or args.context_mask_ratio > 0.0}
if args.tokenizer_type == "glm_BertWordPieceTokenizer":
from .text import BertWordPieceTokenizer
get_tokenizer.tokenizer = BertWordPieceTokenizer(args.tokenizer_model_type, **kwargs)
elif args.tokenizer_type == "glm_GPT2BPETokenizer":
from .text import GPT2BPETokenizer
get_tokenizer.tokenizer = GPT2BPETokenizer(args.tokenizer_model_type, **kwargs)
elif args.tokenizer_type == "glm_ChineseSPTokenizer":
from .text import ChineseSPTokenizer
get_tokenizer.tokenizer = ChineseSPTokenizer(**kwargs)
else:
assert args.vocab_size > 0
get_tokenizer.tokenizer = FakeTokenizer(args.vocab_size)
......
......@@ -36,12 +36,12 @@ from ..file_utils import cached_path
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'gpt2': "pretrained/pytorch_pretrained_bert/gpt2-vocab.json",
"roberta": "pretrained/pytorch_pretrained_bert/roberta-vocab.json"
'gpt2': "pretrained/english_tokenizer/gpt2-vocab.json",
"roberta": "pretrained/english_tokenizer/roberta-vocab.json"
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
'gpt2': "pretrained/pytorch_pretrained_bert/gpt2-merges.txt",
"roberta": "pretrained/pytorch_pretrained_bert/roberta-merges.txt"
'gpt2': "pretrained/english_tokenizer/gpt2-merges.txt",
"roberta": "pretrained/english_tokenizer/roberta-merges.txt"
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'gpt2': 1024,
......
......@@ -131,11 +131,6 @@ def load_checkpoint(model, args):
else: # inference without deepspeed
module = model
# Process the checkpoint for GLM
if args.block_lm and args.old_checkpoint:
sd['module']['transformer.word_embeddings.weight'] = sd['module']['word_embeddings.weight']
del sd['module']['word_embeddings.weight']
# only load the module; other hyperparameters are kept just for the record.
missing_keys, unexpected_keys = module.load_state_dict(sd['module'], strict=False)
if len(unexpected_keys) > 0:
......