Commit 3ab8731b authored by Ming Ding

Merge branch 'tmp_eval_support' into enc-dec

parents 00442c04 44333e0f
@@ -148,7 +148,8 @@ def add_evaluation_args(parser):
                            'validation/test for')
     group.add_argument('--eval-interval', type=int, default=1000,
                        help='interval between running evaluation on validation set')
+    group.add_argument('--strict-eval', action='store_true',
+                       help='won\'t enlarge or randomly map eval-data, and eval full eval-data.')
     return parser
...
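Note (not part of the commit): the new flag is a plain store_true switch. A minimal, self-contained sketch of how it parses; the 'validation' argument group name is assumed, not taken from the source.

import argparse

# hypothetical standalone reproduction of the flag added above
parser = argparse.ArgumentParser()
group = parser.add_argument_group('validation', 'validation configurations')  # group name assumed
group.add_argument('--eval-interval', type=int, default=1000,
                   help='interval between running evaluation on validation set')
group.add_argument('--strict-eval', action='store_true',
                   help='won\'t enlarge or randomly map eval-data, and eval full eval-data.')

args = parser.parse_args(['--strict-eval'])
print(args.strict_eval)  # True: evaluation data is neither enlarged nor randomly mapped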
@@ -24,14 +24,17 @@ from .samplers import DistributedBatchSampler
 from SwissArmyTransformer import mpu
-def make_data_loader(dataset, batch_size, num_iters, args):
+def make_data_loader(dataset, batch_size, args):
     world_size = torch.distributed.get_world_size(
         group=mpu.get_data_parallel_group())
     rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group())
     distributed = world_size > 1
     sampler = torch.utils.data.SequentialSampler(dataset)
-    drop_last = distributed
+    # drop_last = distributed
+    drop_last = True # TODO will always drop last to keep the consistency.
+    # or, how to avg in eval last batch?
     # the GPUs in the same model parallel group receive the same data
     if distributed: # TODO reformat this, but it is not urgent
         gradient_accumulation_steps = getattr(args, 'gradient_accumulation_steps', 1)
@@ -52,7 +55,7 @@ def make_data_loader(dataset, batch_size, num_iters, args):
     return data_loader
-def make_dataset_full(path, split, args, create_dataset_function, **kwargs):
+def make_dataset_full(path, split, args, create_dataset_function, random_mapping=True, **kwargs):
     """function to create datasets+tokenizers for common options"""
     print('make dataset ...', path)
     if split is None:
@@ -67,12 +70,8 @@ def make_dataset_full(path, split, args, create_dataset_function, **kwargs):
     ds = ConcatDataset(ds)
     if should_split(split):
         ds = split_ds(ds, split, block_size=args.block_size)
-    else:
+    elif random_mapping:
         ds = RandomMappingDataset(ds)
-    # if should_split(split):
-    #     ds = split_ds(ds, split) # Large dataset, cannot shuffle, randomly mapping
-    # FIXME this will merge valid set and train set.
     return ds
 def make_loaders(args, create_dataset_function):
@@ -115,25 +114,25 @@ def make_loaders(args, create_dataset_function):
     # make training and val dataset if necessary
     if valid is None and args.valid_data is not None:
         eval_set_args['path'] = args.valid_data
-        valid = make_dataset(**eval_set_args, args=args)
+        valid = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)
     if test is None and args.test_data is not None:
         eval_set_args['path'] = args.test_data
-        test = make_dataset(**eval_set_args, args=args)
+        test = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)
     # wrap datasets with data loader
     if train is not None and args.batch_size > 0:
-        train = make_data_loader(train, batch_size, args.train_iters, args)
+        train = make_data_loader(train, batch_size, args)
         args.do_train = True
     else:
         args.do_train = False
     eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size
     if valid is not None:
-        valid = make_data_loader(valid, eval_batch_size, args.train_iters, args)
+        valid = make_data_loader(valid, eval_batch_size, args)
         args.do_valid = True
     else:
         args.do_valid = False
     if test is not None:
-        test = make_data_loader(test, eval_batch_size, len(test) // eval_batch_size + 1, args)
+        test = make_data_loader(test, eval_batch_size, args)
         args.do_test = True
     else:
         args.do_test = False
...
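Note (illustrative, not from the commit): the `random_mapping=not args.strict_eval` wiring means a strict eval set is kept verbatim and consumed in one sequential pass, while the default path still enlarges it by random index mapping. A rough stand-in sketch of that control flow; ToyRandomMappingDataset and wrap_eval_dataset are made-up names, not the library's implementation.

import random
from torch.utils.data import Dataset

class ToyRandomMappingDataset(Dataset):
    # stand-in for the library's RandomMappingDataset: presents the dataset as
    # much larger by mapping each index back to a deterministic random position
    def __init__(self, ds, scale=200):
        self.ds, self.scale = ds, scale
    def __len__(self):
        return len(self.ds) * self.scale
    def __getitem__(self, index):
        rng = random.Random(index)
        return self.ds[rng.randrange(len(self.ds))]

def wrap_eval_dataset(ds, strict_eval):
    # strict eval: use the untouched dataset; otherwise enlarge via random mapping,
    # mirroring the elif random_mapping branch in make_dataset_full above
    return ds if strict_eval else ToyRandomMappingDataset(ds)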
@@ -65,3 +65,18 @@ class BinaryDataset(Dataset):
     def __getitem__(self, index):
         return self.process_fn(self.bin[index])
+class TSVDataset(Dataset):
+    def __init__(self, path, process_fn, with_heads=True, **kwargs):
+        self.process_fn = process_fn
+        with open(path, 'r') as fin:
+            if with_heads:
+                self.heads = fin.readline().split('\t')
+            else:
+                self.heads = None
+            self.items = [line.split('\t') for line in fin]
+    def __len__(self):
+        return len(self.items)
+    def __getitem__(self, index):
+        return self.process_fn(self.items[index])
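Note (hypothetical usage, not from the commit): TSVDataset hands each tab-split row to process_fn; since rows come straight from the file, the last field keeps its trailing newline, so a typical process_fn strips it. Path and schema below are made up.

# hypothetical usage of the TSVDataset added above
def process_fn(fields):
    # fields: list of tab-separated strings from one TSV row
    return {'text': fields[0], 'label': fields[1].strip()}

# ds = TSVDataset('data/train.tsv', process_fn, with_heads=True)  # path is illustrative
# print(len(ds), ds[0])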
@@ -56,7 +56,8 @@ def filling_sequence(
         max_memory_length=100000,
         log_attention_weights=None,
         get_masks_and_position_ids=get_masks_and_position_ids_default,
-        mems=None
+        mems=None,
+        **kw_args
         ):
     '''
     seq: [2, 3, 5, ..., -1(to be generated), -1, ...]
@@ -104,7 +105,8 @@ def filling_sequence(
             position_ids[..., index: counter+1],
             attention_mask[..., index: counter+1, :counter+1], # TODO memlen
             mems=mems,
-            log_attention_weights=log_attention_weights_part
+            log_attention_weights=log_attention_weights_part,
+            **kw_args
         )
         mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
         counter += 1
...
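Note (illustrative): the `**kw_args` passthrough lets callers of filling_sequence forward extra keyword arguments (e.g. encoder outputs for the enc-dec branch) straight through to the model call. A minimal sketch of the pattern; the function and argument names here are made up, not the real API.

# made-up signatures that mirror the passthrough pattern
def model_forward(tokens, position_ids, attention_mask, mems=None,
                  log_attention_weights=None, **kw_args):
    # extra keyword arguments arrive here untouched
    return kw_args

def filling_sequence_like(model_fn, tokens, **kw_args):
    return model_fn(tokens, None, None, **kw_args)

print(filling_sequence_like(model_forward, [2, 3, 5], encoder_outputs='enc'))
# -> {'encoder_outputs': 'enc'}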
@@ -51,7 +51,7 @@ def get_tokenizer(args=None, outer_tokenizer=None):
             from .cogview import UnifiedTokenizer
             get_tokenizer.tokenizer = UnifiedTokenizer(
                 args.img_tokenizer_path,
-                # txt_tokenizer_type=args.tokenizer_type,
+                txt_tokenizer_type='cogview',
                 device=torch.cuda.current_device()
             )
         elif args.tokenizer_type.startswith('glm'):
...
@@ -22,6 +22,8 @@ python setup.py install
 PRETRAINED_MODEL_FILE = os.path.join(os.path.dirname(os.path.dirname(__file__)),
                                      'embed_assets', 'chinese_sentencepiece/cog-pretrain.model')
+PRETRAINED_MODEL_FILE_ICE = os.path.join(os.path.dirname(os.path.dirname(__file__)),
+                                         'embed_assets', 'chinese_sentencepiece/ice.model') # merge xlnet 3,2000 En tokens
 def get_pairs(word):
@@ -148,5 +150,10 @@ def get_encoder(encoder_file, bpe_file):
     )
-def from_pretrained():
-    return get_encoder(PRETRAINED_MODEL_FILE, "")
+def from_pretrained(tokenizer_type='cogview'):
+    if tokenizer_type == 'cogview_ICE':
+        return get_encoder(PRETRAINED_MODEL_FILE_ICE, "")
+    elif tokenizer_type == 'cogview':
+        return get_encoder(PRETRAINED_MODEL_FILE, "")
+    else:
+        raise ValueError('Unknown cogview tokenizer.')
\ No newline at end of file
@@ -20,10 +20,10 @@ from .sp_tokenizer import from_pretrained
 from .vqvae_tokenizer import VQVAETokenizer, sqrt_int
 class UnifiedTokenizer(object):
-    def __init__(self, img_tokenizer_path, device):
+    def __init__(self, img_tokenizer_path, txt_tokenizer_type, device):
         self.device = device
         self.img_tokenizer = VQVAETokenizer(model_path=img_tokenizer_path, device=self.device)
-        self.txt_tokenizer = from_pretrained()
+        self.txt_tokenizer = from_pretrained(txt_tokenizer_type)
         self.num_tokens = self.img_tokenizer.num_tokens + self.txt_tokenizer.num_tokens
         self.raw_command_tokens = [
             ('[PAD]', 0),
...
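Note (illustrative): UnifiedTokenizer now requires a txt_tokenizer_type, get_tokenizer currently hard-codes 'cogview', and from_pretrained dispatches on that string, raising on anything else. A standalone sketch of the same dispatch; the paths and the load_text_tokenizer name below are placeholders, not the real asset locations or API.

# placeholder paths; the real models live under embed_assets/chinese_sentencepiece/
PRETRAINED_MODEL_FILE = '/path/to/cog-pretrain.model'
PRETRAINED_MODEL_FILE_ICE = '/path/to/ice.model'

def load_text_tokenizer(tokenizer_type='cogview'):
    # stand-in for from_pretrained(): returns the chosen model path instead of an encoder
    if tokenizer_type == 'cogview_ICE':
        return PRETRAINED_MODEL_FILE_ICE
    elif tokenizer_type == 'cogview':
        return PRETRAINED_MODEL_FILE
    else:
        raise ValueError('Unknown cogview tokenizer.')

print(load_text_tokenizer('cogview_ICE'))  # /path/to/ice.model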
@@ -100,14 +100,6 @@ def training_main(args, model_cls, forward_step_function, create_dataset_function
     if val_data is not None:
         start_iter_val = (args.train_iters // args.save_interval) * args.eval_interval
         val_data.batch_sampler.start_iter = start_iter_val % len(val_data)
-    if train_data is not None:
-        train_data_iterator = iter(train_data)
-    else:
-        train_data_iterator = None
-    if val_data is not None:
-        val_data_iterator = iter(val_data)
-    else:
-        val_data_iterator = None
     # init hook before training
     if hooks['init_function'] is not None:
@@ -122,15 +114,11 @@ def training_main(args, model_cls, forward_step_function, create_dataset_function
             save_checkpoint(args_.iteration, model_, optimizer_, lr_scheduler_, args_)
         iteration, skipped = train(model, optimizer,
                                    lr_scheduler,
-                                   train_data_iterator,
-                                   val_data_iterator,
+                                   train_data,
+                                   val_data,
                                    timers, args, summary_writer=summary_writer,
                                    hooks=hooks
                                    )
-    if args.do_valid:
-        prefix = 'the end of training for val data'
-        val_loss = evaluate_and_print_results(prefix, val_data_iterator,
-                                              model, args, timers, False, hooks=hooks)
     # final save
     if args.save and iteration != 0: # TODO save
@@ -140,7 +128,7 @@ def training_main(args, model_cls, forward_step_function, create_dataset_function
     if args.do_test and test_data is not None:
         prefix = 'the end of training for test data'
         evaluate_and_print_results(prefix, iter(test_data),
-                                   model, args, timers, True, hooks=hooks)
+                                   model, len(test_data) if args.strict_eval else args.eval_iters, args, timers, True, hooks=hooks)
 def get_model(args, model_cls):
@@ -258,12 +246,22 @@ def get_learning_rate_scheduler(optimizer, iteration, args,
 def train(model, optimizer, lr_scheduler,
-          train_data_iterator, val_data_iterator, timers, args,
+          train_data, val_data, timers, args,
           summary_writer=None, hooks={}):
     """Train the model."""
+    if train_data is not None:
+        train_data_iterator = iter(train_data)
+    else:
+        train_data_iterator = None
+    if val_data is not None:
+        val_data_iterator = iter(val_data)
+    else:
+        val_data_iterator = None
     # Turn on training mode which enables dropout.
     model.train()
     # Tracking loss.
     total_lm_loss = 0.0
     total_metrics = defaultdict(float)
@@ -316,10 +314,14 @@ def train(model, optimizer, lr_scheduler,
         # Evaluation
         if args.eval_interval and args.iteration % args.eval_interval == 0 and args.do_valid:
+            if args.strict_eval:
+                val_data_iterator = iter(val_data)
+                eval_iters = len(val_data)
+            else:
+                eval_iters = args.eval_iters
             prefix = 'iteration {}'.format(args.iteration)
             evaluate_and_print_results(
-                prefix, val_data_iterator, model, args, timers, False, step=args.iteration,
-                summary_writer=summary_writer, hooks=hooks)
+                prefix, val_data_iterator, model, eval_iters, args, timers, False, step=args.iteration, summary_writer=summary_writer, hooks=hooks)
         if args.exit_interval and args.iteration % args.exit_interval == 0:
             torch.distributed.barrier()
@@ -412,21 +414,19 @@ def backward_step(optimizer, model, loss, args, timers):
     return
-def evaluate(data_iterator, model, args, timers, verbose=False, hooks={}):
+def evaluate(data_iterator, model, eval_iters, args, timers, verbose=False, hooks={}):
     """Evaluation."""
     forward_step = hooks['forward_step']
     # Turn on evaluation mode which disables dropout.
     model.eval()
-    total_lm_loss = 0
+    total_lm_loss, metrics_total = 0, {}
     with torch.no_grad():
         iteration = 0
-        while iteration < args.eval_iters:
+        while iteration < eval_iters:
             iteration += 1
             if verbose and iteration % args.log_interval == 0:
-                print_rank_0('Evaluating iter {}/{}'.format(iteration, args.eval_iters))
+                print_rank_0('Evaluating iter {}/{}'.format(iteration, eval_iters))
             # Forward evaluation.
             lm_loss, metrics = forward_step(data_iterator, model, args, timers)
             '''when contiguous memory optimizations are enabled, the buffers
@@ -436,27 +436,31 @@ def evaluate(data_iterator, model, args, timers, verbose=False, hooks={}):
             if args.deepspeed and args.deepspeed_activation_checkpointing:
                 deepspeed.checkpointing.reset()
             total_lm_loss += lm_loss.data.detach().float().item()
+            for name in metrics:
+                if name not in metrics_total:
+                    metrics_total[name] = 0.0
+                metrics_total[name] += metrics[name]
     # Move model back to the train mode.
     model.train()
-    total_lm_loss /= args.eval_iters
-    return total_lm_loss
+    total_lm_loss /= eval_iters
+    metrics_avg = {key: value / eval_iters for key, value in metrics_total.items()}
+    return total_lm_loss, metrics_avg
-def evaluate_and_print_results(prefix, data_iterator, model,
+def evaluate_and_print_results(prefix, data_iterator, model, eval_iters,
                                args, timers, verbose=False, step=None, summary_writer=None, hooks={}):
     """Helper function to evaluate and dump results on screen."""
     # import line_profiler
     # profile = line_profiler.LineProfiler(model.module.module.transformer.layers[0].forward)
     # profile.enable()
     # torch.cuda.empty_cache()
-    lm_loss = evaluate(data_iterator, model, args, timers, verbose, hooks=hooks)
+    lm_loss, metrics = evaluate(data_iterator, model, eval_iters, args, timers, verbose, hooks=hooks)
     # profile.disable()
     # import sys
     # profile.print_stats(sys.stdout)
     lm_ppl = math.exp(min(20, lm_loss))
-    report_evaluate_metrics(summary_writer, prefix, lm_loss, lm_ppl, step)
+    report_evaluate_metrics(summary_writer, prefix, lm_loss, lm_ppl, step, metrics)
     return lm_loss
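Note (self-contained sketch): the metric handling added to evaluate just sums each named metric over eval_iters batches and divides at the end. With dummy per-iteration metrics standing in for what forward_step returns:

# dummy per-iteration metric dicts, standing in for forward_step's second return value
batches = [{'acc': 0.50, 'f1': 0.25}, {'acc': 0.75, 'f1': 0.75}]

metrics_total = {}
for metrics in batches:
    for name in metrics:
        if name not in metrics_total:
            metrics_total[name] = 0.0
        metrics_total[name] += metrics[name]

eval_iters = len(batches)
metrics_avg = {key: value / eval_iters for key, value in metrics_total.items()}
print(metrics_avg)  # -> {'acc': 0.625, 'f1': 0.5}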
@@ -480,10 +484,12 @@ def report_iteration_metrics(summary_writer, optimizer, lr, loss, elapsed_time,
             summary_writer.add_scalar('Train/'+key, avg_metrics[key], step)
-def report_evaluate_metrics(summary_writer, prefix, loss, ppl, step):
+def report_evaluate_metrics(summary_writer, prefix, loss, ppl, step, avg_metrics):
     string = ' validation loss at {} | '.format(prefix)
     string += 'LM loss: {:.6E} | '.format(loss)
     string += 'LM PPL: {:.6E}'.format(ppl)
+    for key in avg_metrics:
+        string += ' {} {:.6E} |'.format(key, avg_metrics[key])
     length = len(string) + 1
     print_rank_0('-' * 100)
     print_rank_0('-' * length)
@@ -492,6 +498,8 @@ def report_evaluate_metrics(summary_writer, prefix, loss, ppl, step):
     if summary_writer is not None:
         summary_writer.add_scalar(f'Train/valid_ppl', ppl, step)
         summary_writer.add_scalar(f'Train/valid_loss', loss, step)
+        for key in avg_metrics:
+            summary_writer.add_scalar('Train/valid_'+key, avg_metrics[key], step)
 '''
...
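Note (made-up values): with the extra avg_metrics loop, the report line built by report_evaluate_metrics comes out roughly like this.

# reproduce the string construction with illustrative numbers
prefix, loss, ppl = 'iteration 1000', 2.345678, 10.44
avg_metrics = {'acc': 0.625}

string = ' validation loss at {} | '.format(prefix)
string += 'LM loss: {:.6E} | '.format(loss)
string += 'LM PPL: {:.6E}'.format(ppl)
for key in avg_metrics:
    string += ' {} {:.6E} |'.format(key, avg_metrics[key])
print(string)
# " validation loss at iteration 1000 | LM loss: 2.345678E+00 | LM PPL: 1.044000E+01 acc 6.250000E-01 |"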